feat: Add base project structure with API and web components
This commit is contained in:
0
api/app/core/rag/__init__.py
Normal file
0
api/app/core/rag/__init__.py
Normal file
0
api/app/core/rag/app/__init__.py
Normal file
0
api/app/core/rag/app/__init__.py
Normal file
42
api/app/core/rag/app/audio.py
Normal file
42
api/app/core/rag/app/audio.py
Normal file
@@ -0,0 +1,42 @@
|
||||
import os
|
||||
import re
|
||||
import tempfile
|
||||
|
||||
from app.core.rag.nlp import rag_tokenizer, tokenize
|
||||
|
||||
|
||||
def chunk(filename, binary, lang, callback=None, seq2txt_mdl=None, **kwargs):
|
||||
doc = {"docnm_kwd": filename, "title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", filename))}
|
||||
doc["title_sm_tks"] = rag_tokenizer.fine_grained_tokenize(doc["title_tks"])
|
||||
|
||||
# is it English
|
||||
eng = lang.lower() == "english" # is_english(sections)
|
||||
try:
|
||||
_, ext = os.path.splitext(filename)
|
||||
if not ext:
|
||||
raise RuntimeError("No extension detected.")
|
||||
|
||||
if ext not in [".da", ".wave", ".wav", ".mp3", ".aac", ".flac", ".ogg", ".aiff", ".au", ".midi", ".wma", ".realaudio", ".vqf", ".oggvorbis", ".ape"]:
|
||||
raise RuntimeError(f"Extension {ext} is not supported yet.")
|
||||
|
||||
tmp_path = ""
|
||||
with tempfile.NamedTemporaryFile(suffix=ext, delete=False) as tmpf:
|
||||
tmpf.write(binary)
|
||||
tmpf.flush()
|
||||
tmp_path = os.path.abspath(tmpf.name)
|
||||
|
||||
callback(0.1, "USE Sequence2Txt LLM to transcription the audio")
|
||||
ans = seq2txt_mdl.transcription(tmp_path)
|
||||
callback(0.8, "Sequence2Txt LLM respond: %s ..." % ans[:32])
|
||||
|
||||
tokenize(doc, ans, eng)
|
||||
return [doc]
|
||||
except Exception as e:
|
||||
callback(prog=-1, msg=str(e))
|
||||
finally:
|
||||
if tmp_path and os.path.exists(tmp_path):
|
||||
try:
|
||||
os.unlink(tmp_path)
|
||||
except Exception:
|
||||
pass
|
||||
return []
|
||||
170
api/app/core/rag/app/book.py
Normal file
170
api/app/core/rag/app/book.py
Normal file
@@ -0,0 +1,170 @@
|
||||
import logging
|
||||
import re
|
||||
from io import BytesIO
|
||||
|
||||
from app.core.rag.deepdoc.parser.utils import get_text
|
||||
from . import naive
|
||||
from .naive import by_plaintext, PARSERS
|
||||
from app.core.rag.nlp import bullets_category, is_english,remove_contents_table, \
|
||||
hierarchical_merge, make_colon_as_title, naive_merge, random_choices, tokenize_table, \
|
||||
tokenize_chunks
|
||||
from app.core.rag.nlp import rag_tokenizer
|
||||
from app.core.rag.deepdoc.parser import PdfParser, HtmlParser
|
||||
from app.core.rag.deepdoc.parser.figure_parser import vision_figure_parser_docx_wrapper
|
||||
from PIL import Image
|
||||
|
||||
|
||||
class Pdf(PdfParser):
|
||||
def __call__(self, filename, binary=None, from_page=0,
|
||||
to_page=100000, zoomin=3, callback=None):
|
||||
from timeit import default_timer as timer
|
||||
start = timer()
|
||||
callback(msg="OCR started")
|
||||
self.__images__(
|
||||
filename if not binary else binary,
|
||||
zoomin,
|
||||
from_page,
|
||||
to_page,
|
||||
callback)
|
||||
callback(msg="OCR finished ({:.2f}s)".format(timer() - start))
|
||||
|
||||
start = timer()
|
||||
self._layouts_rec(zoomin)
|
||||
callback(0.67, "Layout analysis ({:.2f}s)".format(timer() - start))
|
||||
logging.debug("layouts: {}".format(timer() - start))
|
||||
|
||||
start = timer()
|
||||
self._table_transformer_job(zoomin)
|
||||
callback(0.68, "Table analysis ({:.2f}s)".format(timer() - start))
|
||||
|
||||
start = timer()
|
||||
self._text_merge()
|
||||
tbls = self._extract_table_figure(True, zoomin, True, True)
|
||||
self._naive_vertical_merge()
|
||||
self._filter_forpages()
|
||||
self._merge_with_same_bullet()
|
||||
callback(0.8, "Text extraction ({:.2f}s)".format(timer() - start))
|
||||
|
||||
return [(b["text"] + self._line_tag(b, zoomin), b.get("layoutno", ""))
|
||||
for b in self.boxes], tbls
|
||||
|
||||
|
||||
def chunk(filename, binary=None, from_page=0, to_page=100000,
|
||||
lang="Chinese", callback=None, **kwargs):
|
||||
"""
|
||||
Supported file formats are docx, pdf, txt.
|
||||
Since a book is long and not all the parts are useful, if it's a PDF,
|
||||
please setup the page ranges for every book in order eliminate negative effects and save elapsed computing time.
|
||||
"""
|
||||
parser_config = kwargs.get(
|
||||
"parser_config", {
|
||||
"chunk_token_num": 512, "delimiter": "\n!?。;!?", "layout_recognize": "DeepDOC"})
|
||||
doc = {
|
||||
"docnm_kwd": filename,
|
||||
"title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", filename))
|
||||
}
|
||||
doc["title_sm_tks"] = rag_tokenizer.fine_grained_tokenize(doc["title_tks"])
|
||||
pdf_parser = None
|
||||
sections, tbls = [], []
|
||||
if re.search(r"\.docx$", filename, re.IGNORECASE):
|
||||
callback(0.1, "Start to parse.")
|
||||
doc_parser = naive.Docx()
|
||||
# TODO: table of contents need to be removed
|
||||
sections, tbls = doc_parser(
|
||||
filename, binary=binary, from_page=from_page, to_page=to_page)
|
||||
remove_contents_table(sections, eng=is_english(
|
||||
random_choices([t for t, _ in sections], k=200)))
|
||||
tbls=vision_figure_parser_docx_wrapper(sections=sections,tbls=tbls,callback=callback,**kwargs)
|
||||
# tbls = [((None, lns), None) for lns in tbls]
|
||||
sections=[(item[0],item[1] if item[1] is not None else "") for item in sections if not isinstance(item[1], Image.Image)]
|
||||
callback(0.8, "Finish parsing.")
|
||||
|
||||
elif re.search(r"\.pdf$", filename, re.IGNORECASE):
|
||||
layout_recognizer = parser_config.get("layout_recognize", "DeepDOC")
|
||||
|
||||
if isinstance(layout_recognizer, bool):
|
||||
layout_recognizer = "DeepDOC" if layout_recognizer else "Plain Text"
|
||||
|
||||
name = layout_recognizer.strip().lower()
|
||||
parser = PARSERS.get(name, by_plaintext)
|
||||
callback(0.1, "Start to parse.")
|
||||
|
||||
sections, tables, pdf_parser = parser(
|
||||
filename = filename,
|
||||
binary = binary,
|
||||
from_page = from_page,
|
||||
to_page = to_page,
|
||||
lang = lang,
|
||||
callback = callback,
|
||||
pdf_cls = Pdf,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
if not sections and not tables:
|
||||
return []
|
||||
|
||||
if name in ["tcadp", "docling", "mineru"]:
|
||||
parser_config["chunk_token_num"] = 0
|
||||
|
||||
callback(0.8, "Finish parsing.")
|
||||
elif re.search(r"\.txt$", filename, re.IGNORECASE):
|
||||
callback(0.1, "Start to parse.")
|
||||
txt = get_text(filename, binary)
|
||||
sections = txt.split("\n")
|
||||
sections = [(line, "") for line in sections if line]
|
||||
remove_contents_table(sections, eng=is_english(
|
||||
random_choices([t for t, _ in sections], k=200)))
|
||||
callback(0.8, "Finish parsing.")
|
||||
|
||||
elif re.search(r"\.(htm|html)$", filename, re.IGNORECASE):
|
||||
callback(0.1, "Start to parse.")
|
||||
sections = HtmlParser()(filename, binary)
|
||||
sections = [(line, "") for line in sections if line]
|
||||
remove_contents_table(sections, eng=is_english(
|
||||
random_choices([t for t, _ in sections], k=200)))
|
||||
callback(0.8, "Finish parsing.")
|
||||
|
||||
elif re.search(r"\.doc$", filename, re.IGNORECASE):
|
||||
callback(0.1, "Start to parse.")
|
||||
binary = BytesIO(binary)
|
||||
doc_parsed = parser.from_buffer(binary)
|
||||
sections = doc_parsed['content'].split('\n')
|
||||
sections = [(line, "") for line in sections if line]
|
||||
remove_contents_table(sections, eng=is_english(
|
||||
random_choices([t for t, _ in sections], k=200)))
|
||||
callback(0.8, "Finish parsing.")
|
||||
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"file type not supported yet(doc, docx, pdf, txt supported)")
|
||||
|
||||
make_colon_as_title(sections)
|
||||
bull = bullets_category(
|
||||
[t for t in random_choices([t for t, _ in sections], k=100)])
|
||||
if bull >= 0:
|
||||
chunks = ["\n".join(ck)
|
||||
for ck in hierarchical_merge(bull, sections, 5)]
|
||||
else:
|
||||
sections = [s.split("@") for s, _ in sections]
|
||||
sections = [(pr[0], "@" + pr[1]) if len(pr) == 2 else (pr[0], '') for pr in sections ]
|
||||
chunks = naive_merge(
|
||||
sections, kwargs.get(
|
||||
"chunk_token_num", 256), kwargs.get(
|
||||
"delimer", "\n。;!?"))
|
||||
|
||||
# is it English
|
||||
# is_english(random_choices([t for t, _ in sections], k=218))
|
||||
eng = lang.lower() == "english"
|
||||
|
||||
res = tokenize_table(tbls, doc, eng)
|
||||
res.extend(tokenize_chunks(chunks, doc, eng, pdf_parser))
|
||||
|
||||
return res
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
|
||||
def dummy(prog=None, msg=""):
|
||||
pass
|
||||
chunk(sys.argv[1], from_page=1, to_page=10, callback=dummy)
|
||||
219
api/app/core/rag/app/laws.py
Normal file
219
api/app/core/rag/app/laws.py
Normal file
@@ -0,0 +1,219 @@
|
||||
import logging
|
||||
import re
|
||||
from io import BytesIO
|
||||
from docx import Document
|
||||
|
||||
from app.core.rag.common.constants import ParserType
|
||||
from app.core.rag.deepdoc.parser.utils import get_text
|
||||
from app.core.rag.nlp import bullets_category, remove_contents_table, \
|
||||
make_colon_as_title, tokenize_chunks, docx_question_level, tree_merge
|
||||
from app.core.rag.nlp import rag_tokenizer, Node
|
||||
from app.core.rag.deepdoc.parser import PdfParser, DocxParser, HtmlParser
|
||||
from app.core.rag.app.naive import by_plaintext, PARSERS
|
||||
|
||||
|
||||
|
||||
|
||||
class Docx(DocxParser):
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def __clean(self, line):
|
||||
line = re.sub(r"\u3000", " ", line).strip()
|
||||
return line
|
||||
|
||||
def old_call(self, filename, binary=None, from_page=0, to_page=100000):
|
||||
self.doc = Document(
|
||||
filename) if not binary else Document(BytesIO(binary))
|
||||
pn = 0
|
||||
lines = []
|
||||
for p in self.doc.paragraphs:
|
||||
if pn > to_page:
|
||||
break
|
||||
if from_page <= pn < to_page and p.text.strip():
|
||||
lines.append(self.__clean(p.text))
|
||||
for run in p.runs:
|
||||
if 'lastRenderedPageBreak' in run._element.xml:
|
||||
pn += 1
|
||||
continue
|
||||
if 'w:br' in run._element.xml and 'type="page"' in run._element.xml:
|
||||
pn += 1
|
||||
return [line for line in lines if line]
|
||||
|
||||
def __call__(self, filename, binary=None, from_page=0, to_page=100000):
|
||||
self.doc = Document(
|
||||
filename) if not binary else Document(BytesIO(binary))
|
||||
pn = 0
|
||||
lines = []
|
||||
level_set = set()
|
||||
bull = bullets_category([p.text for p in self.doc.paragraphs])
|
||||
for p in self.doc.paragraphs:
|
||||
if pn > to_page:
|
||||
break
|
||||
question_level, p_text = docx_question_level(p, bull)
|
||||
if not p_text.strip("\n"):
|
||||
continue
|
||||
lines.append((question_level, p_text))
|
||||
level_set.add(question_level)
|
||||
for run in p.runs:
|
||||
if 'lastRenderedPageBreak' in run._element.xml:
|
||||
pn += 1
|
||||
continue
|
||||
if 'w:br' in run._element.xml and 'type="page"' in run._element.xml:
|
||||
pn += 1
|
||||
|
||||
sorted_levels = sorted(level_set)
|
||||
|
||||
h2_level = sorted_levels[1] if len(sorted_levels) > 1 else 1
|
||||
h2_level = sorted_levels[-2] if h2_level == sorted_levels[-1] and len(sorted_levels) > 2 else h2_level
|
||||
|
||||
root = Node(level=0, depth=h2_level, texts=[])
|
||||
root.build_tree(lines)
|
||||
|
||||
return [element for element in root.get_tree() if element]
|
||||
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f'''
|
||||
question:{self.question},
|
||||
answer:{self.answer},
|
||||
level:{self.level},
|
||||
childs:{self.childs}
|
||||
'''
|
||||
|
||||
|
||||
class Pdf(PdfParser):
|
||||
def __init__(self):
|
||||
self.model_speciess = ParserType.LAWS.value
|
||||
super().__init__()
|
||||
|
||||
def __call__(self, filename, binary=None, from_page=0,
|
||||
to_page=100000, zoomin=3, callback=None):
|
||||
from timeit import default_timer as timer
|
||||
start = timer()
|
||||
callback(msg="OCR started")
|
||||
self.__images__(
|
||||
filename if not binary else binary,
|
||||
zoomin,
|
||||
from_page,
|
||||
to_page,
|
||||
callback
|
||||
)
|
||||
callback(msg="OCR finished ({:.2f}s)".format(timer() - start))
|
||||
|
||||
start = timer()
|
||||
self._layouts_rec(zoomin)
|
||||
callback(0.67, "Layout analysis ({:.2f}s)".format(timer() - start))
|
||||
logging.debug("layouts:".format(
|
||||
))
|
||||
self._naive_vertical_merge()
|
||||
|
||||
callback(0.8, "Text extraction ({:.2f}s)".format(timer() - start))
|
||||
|
||||
return [(b["text"], self._line_tag(b, zoomin))
|
||||
for b in self.boxes], None
|
||||
|
||||
|
||||
def chunk(filename, binary=None, from_page=0, to_page=100000,
|
||||
lang="Chinese", callback=None, **kwargs):
|
||||
"""
|
||||
Supported file formats are docx, pdf, txt.
|
||||
"""
|
||||
parser_config = kwargs.get(
|
||||
"parser_config", {
|
||||
"chunk_token_num": 512, "delimiter": "\n!?。;!?", "layout_recognize": "DeepDOC"})
|
||||
doc = {
|
||||
"docnm_kwd": filename,
|
||||
"title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", filename))
|
||||
}
|
||||
doc["title_sm_tks"] = rag_tokenizer.fine_grained_tokenize(doc["title_tks"])
|
||||
pdf_parser = None
|
||||
sections = []
|
||||
# is it English
|
||||
eng = lang.lower() == "english" # is_english(sections)
|
||||
|
||||
if re.search(r"\.docx$", filename, re.IGNORECASE):
|
||||
callback(0.1, "Start to parse.")
|
||||
chunks = Docx()(filename, binary)
|
||||
callback(0.7, "Finish parsing.")
|
||||
return tokenize_chunks(chunks, doc, eng, None)
|
||||
|
||||
elif re.search(r"\.pdf$", filename, re.IGNORECASE):
|
||||
layout_recognizer = parser_config.get("layout_recognize", "DeepDOC")
|
||||
|
||||
if isinstance(layout_recognizer, bool):
|
||||
layout_recognizer = "DeepDOC" if layout_recognizer else "Plain Text"
|
||||
|
||||
name = layout_recognizer.strip().lower()
|
||||
parser = PARSERS.get(name, by_plaintext)
|
||||
callback(0.1, "Start to parse.")
|
||||
|
||||
raw_sections, tables, pdf_parser = parser(
|
||||
filename = filename,
|
||||
binary = binary,
|
||||
from_page = from_page,
|
||||
to_page = to_page,
|
||||
lang = lang,
|
||||
callback = callback,
|
||||
pdf_cls = Pdf,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
if not raw_sections and not tables:
|
||||
return []
|
||||
|
||||
if name in ["tcadp", "docling", "mineru"]:
|
||||
parser_config["chunk_token_num"] = 0
|
||||
|
||||
for txt, poss in raw_sections:
|
||||
sections.append(txt + poss)
|
||||
|
||||
callback(0.8, "Finish parsing.")
|
||||
elif re.search(r"\.(txt|md|markdown|mdx)$", filename, re.IGNORECASE):
|
||||
callback(0.1, "Start to parse.")
|
||||
txt = get_text(filename, binary)
|
||||
sections = txt.split("\n")
|
||||
sections = [s for s in sections if s]
|
||||
callback(0.8, "Finish parsing.")
|
||||
|
||||
elif re.search(r"\.(htm|html)$", filename, re.IGNORECASE):
|
||||
callback(0.1, "Start to parse.")
|
||||
sections = HtmlParser()(filename, binary)
|
||||
sections = [s for s in sections if s]
|
||||
callback(0.8, "Finish parsing.")
|
||||
|
||||
elif re.search(r"\.doc$", filename, re.IGNORECASE):
|
||||
callback(0.1, "Start to parse.")
|
||||
binary = BytesIO(binary)
|
||||
doc_parsed = parser.from_buffer(binary)
|
||||
sections = doc_parsed['content'].split('\n')
|
||||
sections = [s for s in sections if s]
|
||||
callback(0.8, "Finish parsing.")
|
||||
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"file type not supported yet(doc, docx, pdf, txt supported)")
|
||||
|
||||
|
||||
# Remove 'Contents' part
|
||||
remove_contents_table(sections, eng)
|
||||
|
||||
make_colon_as_title(sections)
|
||||
bull = bullets_category(sections)
|
||||
res = tree_merge(bull, sections, 2)
|
||||
|
||||
|
||||
if not res:
|
||||
callback(0.99, "No chunk parsed out.")
|
||||
|
||||
return tokenize_chunks(res, doc, eng, pdf_parser)
|
||||
|
||||
# chunks = hierarchical_merge(bull, sections, 5)
|
||||
# return tokenize_chunks(["\n".join(ck)for ck in chunks], doc, eng, pdf_parser)
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
|
||||
def dummy(prog=None, msg=""):
|
||||
pass
|
||||
chunk(sys.argv[1], callback=dummy)
|
||||
114
api/app/core/rag/app/mail.py
Normal file
114
api/app/core/rag/app/mail.py
Normal file
@@ -0,0 +1,114 @@
|
||||
import logging
|
||||
from email import policy
|
||||
from email.parser import BytesParser
|
||||
from .naive import chunk as naive_chunk
|
||||
import re
|
||||
from app.core.rag.nlp import rag_tokenizer, naive_merge, tokenize_chunks
|
||||
from app.core.rag.deepdoc.parser import HtmlParser, TxtParser
|
||||
from timeit import default_timer as timer
|
||||
import io
|
||||
|
||||
|
||||
def chunk(
|
||||
filename,
|
||||
binary=None,
|
||||
from_page=0,
|
||||
to_page=100000,
|
||||
lang="Chinese",
|
||||
callback=None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Only eml is supported
|
||||
"""
|
||||
eng = lang.lower() == "english" # is_english(cks)
|
||||
parser_config = kwargs.get(
|
||||
"parser_config",
|
||||
{"chunk_token_num": 512, "delimiter": "\n!?。;!?", "layout_recognize": "DeepDOC"},
|
||||
)
|
||||
doc = {
|
||||
"docnm_kwd": filename,
|
||||
"title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", filename)),
|
||||
}
|
||||
doc["title_sm_tks"] = rag_tokenizer.fine_grained_tokenize(doc["title_tks"])
|
||||
main_res = []
|
||||
attachment_res = []
|
||||
|
||||
if binary:
|
||||
msg = BytesParser(policy=policy.default).parse(io.BytesIO(binary))
|
||||
else:
|
||||
msg = BytesParser(policy=policy.default).parse(open(filename, "rb"))
|
||||
|
||||
text_txt, html_txt = [], []
|
||||
# get the email header info
|
||||
for header, value in msg.items():
|
||||
text_txt.append(f"{header}: {value}")
|
||||
|
||||
# get the email main info
|
||||
def _add_content(msg, content_type):
|
||||
def _decode_payload(payload, charset, target_list):
|
||||
try:
|
||||
target_list.append(payload.decode(charset))
|
||||
except (UnicodeDecodeError, LookupError):
|
||||
for enc in ["utf-8", "gb2312", "gbk", "gb18030", "latin1"]:
|
||||
try:
|
||||
target_list.append(payload.decode(enc))
|
||||
break
|
||||
except UnicodeDecodeError:
|
||||
continue
|
||||
else:
|
||||
target_list.append(payload.decode("utf-8", errors="ignore"))
|
||||
|
||||
if content_type == "text/plain":
|
||||
payload = msg.get_payload(decode=True)
|
||||
charset = msg.get_content_charset() or "utf-8"
|
||||
_decode_payload(payload, charset, text_txt)
|
||||
elif content_type == "text/html":
|
||||
payload = msg.get_payload(decode=True)
|
||||
charset = msg.get_content_charset() or "utf-8"
|
||||
_decode_payload(payload, charset, html_txt)
|
||||
elif "multipart" in content_type:
|
||||
if msg.is_multipart():
|
||||
for part in msg.iter_parts():
|
||||
_add_content(part, part.get_content_type())
|
||||
|
||||
_add_content(msg, msg.get_content_type())
|
||||
|
||||
sections = TxtParser.parser_txt("\n".join(text_txt)) + [
|
||||
(line, "") for line in HtmlParser.parser_txt("\n".join(html_txt), chunk_token_num=parser_config["chunk_token_num"]) if line
|
||||
]
|
||||
|
||||
st = timer()
|
||||
chunks = naive_merge(
|
||||
sections,
|
||||
int(parser_config.get("chunk_token_num", 128)),
|
||||
parser_config.get("delimiter", "\n!?。;!?"),
|
||||
)
|
||||
|
||||
main_res.extend(tokenize_chunks(chunks, doc, eng, None))
|
||||
logging.debug("naive_merge({}): {}".format(filename, timer() - st))
|
||||
# get the attachment info
|
||||
for part in msg.iter_attachments():
|
||||
content_disposition = part.get("Content-Disposition")
|
||||
if content_disposition:
|
||||
dispositions = content_disposition.strip().split(";")
|
||||
if dispositions[0].lower() == "attachment":
|
||||
filename = part.get_filename()
|
||||
payload = part.get_payload(decode=True)
|
||||
try:
|
||||
attachment_res.extend(
|
||||
naive_chunk(filename, payload, callback=callback, **kwargs)
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return main_res + attachment_res
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
|
||||
def dummy(prog=None, msg=""):
|
||||
pass
|
||||
|
||||
chunk(sys.argv[1], callback=dummy)
|
||||
299
api/app/core/rag/app/manual.py
Normal file
299
api/app/core/rag/app/manual.py
Normal file
@@ -0,0 +1,299 @@
|
||||
import logging
|
||||
import copy
|
||||
import re
|
||||
|
||||
from app.core.rag.common.constants import ParserType
|
||||
from io import BytesIO
|
||||
from app.core.rag.nlp import rag_tokenizer, tokenize, tokenize_table, bullets_category, title_frequency, tokenize_chunks, docx_question_level
|
||||
from app.core.rag.common.token_utils import num_tokens_from_string
|
||||
from app.core.rag.deepdoc.parser import PdfParser, DocxParser
|
||||
from app.core.rag.deepdoc.parser.figure_parser import vision_figure_parser_pdf_wrapper,vision_figure_parser_docx_wrapper
|
||||
from docx import Document
|
||||
from PIL import Image
|
||||
from .naive import by_plaintext, PARSERS
|
||||
|
||||
class Pdf(PdfParser):
|
||||
def __init__(self):
|
||||
self.model_speciess = ParserType.MANUAL.value
|
||||
super().__init__()
|
||||
|
||||
def __call__(self, filename, binary=None, from_page=0,
|
||||
to_page=100000, zoomin=3, callback=None):
|
||||
from timeit import default_timer as timer
|
||||
start = timer()
|
||||
callback(msg="OCR started")
|
||||
self.__images__(
|
||||
filename if not binary else binary,
|
||||
zoomin,
|
||||
from_page,
|
||||
to_page,
|
||||
callback
|
||||
)
|
||||
callback(msg="OCR finished ({:.2f}s)".format(timer() - start))
|
||||
logging.debug("OCR: {}".format(timer() - start))
|
||||
|
||||
start = timer()
|
||||
self._layouts_rec(zoomin)
|
||||
callback(0.65, "Layout analysis ({:.2f}s)".format(timer() - start))
|
||||
logging.debug("layouts: {}".format(timer() - start))
|
||||
|
||||
start = timer()
|
||||
self._table_transformer_job(zoomin)
|
||||
callback(0.67, "Table analysis ({:.2f}s)".format(timer() - start))
|
||||
|
||||
start = timer()
|
||||
self._text_merge()
|
||||
tbls = self._extract_table_figure(True, zoomin, True, True)
|
||||
self._concat_downward()
|
||||
self._filter_forpages()
|
||||
callback(0.68, "Text merged ({:.2f}s)".format(timer() - start))
|
||||
|
||||
# clean mess
|
||||
for b in self.boxes:
|
||||
b["text"] = re.sub(r"([\t ]|\u3000){2,}", " ", b["text"].strip())
|
||||
|
||||
return [(b["text"], b.get("layoutno", ""), self.get_position(b, zoomin))
|
||||
for i, b in enumerate(self.boxes)], tbls
|
||||
|
||||
|
||||
class Docx(DocxParser):
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def get_picture(self, document, paragraph):
|
||||
img = paragraph._element.xpath('.//pic:pic')
|
||||
if not img:
|
||||
return None
|
||||
try:
|
||||
img = img[0]
|
||||
embed = img.xpath('.//a:blip/@r:embed')[0]
|
||||
related_part = document.part.related_parts[embed]
|
||||
image = related_part.image
|
||||
if image is not None:
|
||||
image = Image.open(BytesIO(image.blob))
|
||||
return image
|
||||
elif related_part.blob is not None:
|
||||
image = Image.open(BytesIO(related_part.blob))
|
||||
return image
|
||||
else:
|
||||
return None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def concat_img(self, img1, img2):
|
||||
if img1 and not img2:
|
||||
return img1
|
||||
if not img1 and img2:
|
||||
return img2
|
||||
if not img1 and not img2:
|
||||
return None
|
||||
width1, height1 = img1.size
|
||||
width2, height2 = img2.size
|
||||
|
||||
new_width = max(width1, width2)
|
||||
new_height = height1 + height2
|
||||
new_image = Image.new('RGB', (new_width, new_height))
|
||||
|
||||
new_image.paste(img1, (0, 0))
|
||||
new_image.paste(img2, (0, height1))
|
||||
|
||||
return new_image
|
||||
|
||||
def __call__(self, filename, binary=None, from_page=0, to_page=100000, callback=None):
|
||||
self.doc = Document(
|
||||
filename) if not binary else Document(BytesIO(binary))
|
||||
pn = 0
|
||||
last_answer, last_image = "", None
|
||||
question_stack, level_stack = [], []
|
||||
ti_list = []
|
||||
for p in self.doc.paragraphs:
|
||||
if pn > to_page:
|
||||
break
|
||||
question_level, p_text = 0, ''
|
||||
if from_page <= pn < to_page and p.text.strip():
|
||||
question_level, p_text = docx_question_level(p)
|
||||
if not question_level or question_level > 6: # not a question
|
||||
last_answer = f'{last_answer}\n{p_text}'
|
||||
current_image = self.get_picture(self.doc, p)
|
||||
last_image = self.concat_img(last_image, current_image)
|
||||
else: # is a question
|
||||
if last_answer or last_image:
|
||||
sum_question = '\n'.join(question_stack)
|
||||
if sum_question:
|
||||
ti_list.append((f'{sum_question}\n{last_answer}', last_image))
|
||||
last_answer, last_image = '', None
|
||||
|
||||
i = question_level
|
||||
while question_stack and i <= level_stack[-1]:
|
||||
question_stack.pop()
|
||||
level_stack.pop()
|
||||
question_stack.append(p_text)
|
||||
level_stack.append(question_level)
|
||||
for run in p.runs:
|
||||
if 'lastRenderedPageBreak' in run._element.xml:
|
||||
pn += 1
|
||||
continue
|
||||
if 'w:br' in run._element.xml and 'type="page"' in run._element.xml:
|
||||
pn += 1
|
||||
if last_answer:
|
||||
sum_question = '\n'.join(question_stack)
|
||||
if sum_question:
|
||||
ti_list.append((f'{sum_question}\n{last_answer}', last_image))
|
||||
|
||||
tbls = []
|
||||
for tb in self.doc.tables:
|
||||
html= "<table>"
|
||||
for r in tb.rows:
|
||||
html += "<tr>"
|
||||
i = 0
|
||||
while i < len(r.cells):
|
||||
span = 1
|
||||
c = r.cells[i]
|
||||
for j in range(i+1, len(r.cells)):
|
||||
if c.text == r.cells[j].text:
|
||||
span += 1
|
||||
i = j
|
||||
else:
|
||||
break
|
||||
i += 1
|
||||
html += f"<td>{c.text}</td>" if span == 1 else f"<td colspan='{span}'>{c.text}</td>"
|
||||
html += "</tr>"
|
||||
html += "</table>"
|
||||
tbls.append(((None, html), ""))
|
||||
return ti_list, tbls
|
||||
|
||||
|
||||
def chunk(filename, binary=None, from_page=0, to_page=100000,
|
||||
lang="Chinese", callback=None, **kwargs):
|
||||
"""
|
||||
Only pdf is supported.
|
||||
"""
|
||||
parser_config = kwargs.get(
|
||||
"parser_config", {
|
||||
"chunk_token_num": 512, "delimiter": "\n!?。;!?", "layout_recognize": "DeepDOC"})
|
||||
pdf_parser = None
|
||||
doc = {
|
||||
"docnm_kwd": filename
|
||||
}
|
||||
doc["title_tks"] = rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", doc["docnm_kwd"]))
|
||||
doc["title_sm_tks"] = rag_tokenizer.fine_grained_tokenize(doc["title_tks"])
|
||||
# is it English
|
||||
eng = lang.lower() == "english" # pdf_parser.is_english
|
||||
if re.search(r"\.pdf$", filename, re.IGNORECASE):
|
||||
layout_recognizer = parser_config.get("layout_recognize", "DeepDOC")
|
||||
|
||||
if isinstance(layout_recognizer, bool):
|
||||
layout_recognizer = "DeepDOC" if layout_recognizer else "Plain Text"
|
||||
|
||||
name = layout_recognizer.strip().lower()
|
||||
pdf_parser = PARSERS.get(name, by_plaintext)
|
||||
callback(0.1, "Start to parse.")
|
||||
|
||||
sections, tbls, pdf_parser = pdf_parser(
|
||||
filename = filename,
|
||||
binary = binary,
|
||||
from_page = from_page,
|
||||
to_page = to_page,
|
||||
lang = lang,
|
||||
callback = callback,
|
||||
pdf_cls = Pdf,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
if not sections and not tbls:
|
||||
return []
|
||||
|
||||
if name in ["tcadp", "docling", "mineru"]:
|
||||
parser_config["chunk_token_num"] = 0
|
||||
|
||||
callback(0.8, "Finish parsing.")
|
||||
|
||||
if len(sections) > 0 and len(pdf_parser.outlines) / len(sections) > 0.03:
|
||||
max_lvl = max([lvl for _, lvl in pdf_parser.outlines])
|
||||
most_level = max(0, max_lvl - 1)
|
||||
levels = []
|
||||
for txt, _, _ in sections:
|
||||
for t, lvl in pdf_parser.outlines:
|
||||
tks = set([t[i] + t[i + 1] for i in range(len(t) - 1)])
|
||||
tks_ = set([txt[i] + txt[i + 1]
|
||||
for i in range(min(len(t), len(txt) - 1))])
|
||||
if len(set(tks & tks_)) / max([len(tks), len(tks_), 1]) > 0.8:
|
||||
levels.append(lvl)
|
||||
break
|
||||
else:
|
||||
levels.append(max_lvl + 1)
|
||||
|
||||
else:
|
||||
bull = bullets_category([txt for txt, _, _ in sections])
|
||||
most_level, levels = title_frequency(
|
||||
bull, [(txt, lvl) for txt, lvl, _ in sections])
|
||||
|
||||
assert len(sections) == len(levels)
|
||||
sec_ids = []
|
||||
sid = 0
|
||||
for i, lvl in enumerate(levels):
|
||||
if lvl <= most_level and i > 0 and lvl != levels[i - 1]:
|
||||
sid += 1
|
||||
sec_ids.append(sid)
|
||||
|
||||
sections = [(txt, sec_ids[i], poss)
|
||||
for i, (txt, _, poss) in enumerate(sections)]
|
||||
for (img, rows), poss in tbls:
|
||||
if not rows:
|
||||
continue
|
||||
sections.append((rows if isinstance(rows, str) else rows[0], -1,
|
||||
[(p[0] + 1 - from_page, p[1], p[2], p[3], p[4]) for p in poss]))
|
||||
|
||||
def tag(pn, left, right, top, bottom):
|
||||
if pn + left + right + top + bottom == 0:
|
||||
return ""
|
||||
return "@@{}\t{:.1f}\t{:.1f}\t{:.1f}\t{:.1f}##" \
|
||||
.format(pn, left, right, top, bottom)
|
||||
|
||||
chunks = []
|
||||
last_sid = -2
|
||||
tk_cnt = 0
|
||||
for txt, sec_id, poss in sorted(sections, key=lambda x: (
|
||||
x[-1][0][0], x[-1][0][3], x[-1][0][1])):
|
||||
poss = "\t".join([tag(*pos) for pos in poss])
|
||||
if tk_cnt < 32 or (tk_cnt < 1024 and (sec_id == last_sid or sec_id == -1)):
|
||||
if chunks:
|
||||
chunks[-1] += "\n" + txt + poss
|
||||
tk_cnt += num_tokens_from_string(txt)
|
||||
continue
|
||||
chunks.append(txt + poss)
|
||||
tk_cnt = num_tokens_from_string(txt)
|
||||
if sec_id > -1:
|
||||
last_sid = sec_id
|
||||
tbls=vision_figure_parser_pdf_wrapper(tbls=tbls,callback=callback,**kwargs)
|
||||
res = tokenize_table(tbls, doc, eng)
|
||||
res.extend(tokenize_chunks(chunks, doc, eng, pdf_parser))
|
||||
return res
|
||||
|
||||
elif re.search(r"\.docx?$", filename, re.IGNORECASE):
|
||||
docx_parser = Docx()
|
||||
ti_list, tbls = docx_parser(filename, binary,
|
||||
from_page=0, to_page=10000, callback=callback)
|
||||
tbls=vision_figure_parser_docx_wrapper(sections=ti_list,tbls=tbls,callback=callback,**kwargs)
|
||||
res = tokenize_table(tbls, doc, eng)
|
||||
for text, image in ti_list:
|
||||
d = copy.deepcopy(doc)
|
||||
if image:
|
||||
d['image'] = image
|
||||
d["doc_type_kwd"] = "image"
|
||||
tokenize(d, text, eng)
|
||||
res.append(d)
|
||||
return res
|
||||
else:
|
||||
raise NotImplementedError("file type not supported yet(pdf and docx supported)")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
|
||||
|
||||
def dummy(prog=None, msg=""):
|
||||
pass
|
||||
|
||||
|
||||
chunk(sys.argv[1], callback=dummy)
|
||||
849
api/app/core/rag/app/naive.py
Normal file
849
api/app/core/rag/app/naive.py
Normal file
@@ -0,0 +1,849 @@
|
||||
import logging
|
||||
import re
|
||||
import os
|
||||
from functools import reduce
|
||||
from io import BytesIO
|
||||
from timeit import default_timer as timer
|
||||
from docx import Document
|
||||
from docx.image.exceptions import InvalidImageStreamError, UnexpectedEndOfFileError, UnrecognizedImageError
|
||||
from docx.opc.pkgreader import _SerializedRelationships, _SerializedRelationship
|
||||
from docx.opc.oxml import parse_xml
|
||||
from markdown import markdown
|
||||
from PIL import Image
|
||||
import copy
|
||||
|
||||
from app.core.rag.llm.cv_model import AzureGptV4, QWenCV
|
||||
from app.core.rag.common.file_utils import get_project_base_directory
|
||||
from app.core.rag.utils.file_utils import extract_embed_file, extract_links_from_pdf, extract_links_from_docx, extract_html
|
||||
from app.core.rag.deepdoc.parser import DocxParser, ExcelParser, HtmlParser, JsonParser, MarkdownElementExtractor, MarkdownParser, PdfParser, TxtParser
|
||||
from app.core.rag.deepdoc.parser.figure_parser import VisionFigureParser,vision_figure_parser_docx_wrapper,vision_figure_parser_pdf_wrapper
|
||||
from app.core.rag.deepdoc.parser.pdf_parser import PlainParser, VisionParser
|
||||
from app.core.rag.deepdoc.parser.mineru_parser import MinerUParser
|
||||
from app.core.rag.nlp import concat_img, find_codec, naive_merge, naive_merge_with_images, naive_merge_docx, tokenize, rag_tokenizer, tokenize_chunks, tokenize_chunks_with_images, tokenize_table
|
||||
|
||||
def by_deepdoc(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, vision_model=None, pdf_cls = None ,**kwargs):
|
||||
callback = callback
|
||||
binary = binary
|
||||
pdf_parser = pdf_cls() if pdf_cls else Pdf()
|
||||
sections, tables = pdf_parser(
|
||||
filename if not binary else binary,
|
||||
from_page=from_page,
|
||||
to_page=to_page,
|
||||
callback=callback
|
||||
)
|
||||
|
||||
tables = vision_figure_parser_pdf_wrapper(tbls=tables,
|
||||
callback=callback,
|
||||
vision_model=vision_model,
|
||||
**kwargs)
|
||||
return sections, tables, pdf_parser
|
||||
|
||||
|
||||
def by_mineru(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, vision_model=None, pdf_cls = None ,**kwargs):
|
||||
mineru_executable = os.environ.get("MINERU_EXECUTABLE", "mineru")
|
||||
mineru_api = os.environ.get("MINERU_APISERVER", "http://host.docker.internal:9987")
|
||||
pdf_parser = MinerUParser(mineru_path=mineru_executable, mineru_api=mineru_api)
|
||||
|
||||
if not pdf_parser.check_installation():
|
||||
callback(-1, "MinerU not found.")
|
||||
return None, None, pdf_parser
|
||||
|
||||
sections, tables = pdf_parser.parse_pdf(
|
||||
filepath=filename,
|
||||
binary=binary,
|
||||
callback=callback,
|
||||
output_dir=os.environ.get("MINERU_OUTPUT_DIR", ""),
|
||||
backend=os.environ.get("MINERU_BACKEND", "pipeline"),
|
||||
delete_output=bool(int(os.environ.get("MINERU_DELETE_OUTPUT", 1))),
|
||||
)
|
||||
return sections, tables, pdf_parser
|
||||
|
||||
|
||||
def by_textln(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, vision_model=None, pdf_cls = None ,**kwargs):
|
||||
textln_app_id = os.environ.get("TEXTLN_APP_ID", "")
|
||||
textln_secret_code = os.environ.get("TEXTLN_SECRET_CODE", "")
|
||||
textln_api = os.environ.get("TEXTLN_APISERVER", "https://api.textin.com/ai/service/v1/pdf_to_markdown")
|
||||
pdf_parser = MinerUParser(mineru_path=textln_app_id, mineru_api=textln_api)
|
||||
|
||||
if not pdf_parser.check_installation():
|
||||
callback(-1, "MinerU not found.")
|
||||
return None, None, pdf_parser
|
||||
|
||||
sections, tables = pdf_parser.parse_pdf(
|
||||
filepath=filename,
|
||||
binary=binary,
|
||||
callback=callback,
|
||||
output_dir=os.environ.get("MINERU_OUTPUT_DIR", ""),
|
||||
backend=os.environ.get("MINERU_BACKEND", "pipeline"),
|
||||
delete_output=bool(int(os.environ.get("MINERU_DELETE_OUTPUT", 1))),
|
||||
)
|
||||
return sections, tables, pdf_parser
|
||||
|
||||
|
||||
def by_plaintext(filename, binary=None, from_page=0, to_page=100000, callback=None, vision_model=None, **kwargs):
|
||||
if kwargs.get("layout_recognizer", "") == "Plain Text":
|
||||
pdf_parser = PlainParser()
|
||||
else:
|
||||
pdf_parser = VisionParser(vision_model=vision_model, **kwargs)
|
||||
|
||||
sections, tables = pdf_parser(
|
||||
filename if not binary else binary,
|
||||
from_page=from_page,
|
||||
to_page=to_page,
|
||||
callback=callback
|
||||
)
|
||||
return sections, tables, pdf_parser
|
||||
|
||||
|
||||
PARSERS = {
|
||||
"deepdoc": by_deepdoc,
|
||||
"mineru": by_mineru,
|
||||
"textln": by_textln,
|
||||
"plaintext": by_plaintext, # default
|
||||
}
|
||||
|
||||
|
||||
class Docx(DocxParser):
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def get_picture(self, document, paragraph):
|
||||
imgs = paragraph._element.xpath('.//pic:pic')
|
||||
if not imgs:
|
||||
return None
|
||||
res_img = None
|
||||
for img in imgs:
|
||||
embed = img.xpath('.//a:blip/@r:embed')
|
||||
if not embed:
|
||||
continue
|
||||
embed = embed[0]
|
||||
try:
|
||||
related_part = document.part.related_parts[embed]
|
||||
image_blob = related_part.image.blob
|
||||
except UnrecognizedImageError:
|
||||
logging.info("Unrecognized image format. Skipping image.")
|
||||
continue
|
||||
except UnexpectedEndOfFileError:
|
||||
logging.info("EOF was unexpectedly encountered while reading an image stream. Skipping image.")
|
||||
continue
|
||||
except InvalidImageStreamError:
|
||||
logging.info("The recognized image stream appears to be corrupted. Skipping image.")
|
||||
continue
|
||||
except UnicodeDecodeError:
|
||||
logging.info("The recognized image stream appears to be corrupted. Skipping image.")
|
||||
continue
|
||||
except Exception:
|
||||
logging.info("The recognized image stream appears to be corrupted. Skipping image.")
|
||||
continue
|
||||
try:
|
||||
image = Image.open(BytesIO(image_blob)).convert('RGB')
|
||||
if res_img is None:
|
||||
res_img = image
|
||||
else:
|
||||
res_img = concat_img(res_img, image)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
return res_img
|
||||
|
||||
def __clean(self, line):
|
||||
line = re.sub(r"\u3000", " ", line).strip()
|
||||
return line
|
||||
|
||||
def __get_nearest_title(self, table_index, filename):
|
||||
"""Get the hierarchical title structure before the table"""
|
||||
import re
|
||||
from docx.text.paragraph import Paragraph
|
||||
|
||||
titles = []
|
||||
blocks = []
|
||||
|
||||
# Get document name from filename parameter
|
||||
doc_name = re.sub(r"\.[a-zA-Z]+$", "", filename)
|
||||
if not doc_name:
|
||||
doc_name = "Untitled Document"
|
||||
|
||||
# Collect all document blocks while maintaining document order
|
||||
try:
|
||||
# Iterate through all paragraphs and tables in document order
|
||||
for i, block in enumerate(self.doc._element.body):
|
||||
if block.tag.endswith('p'): # Paragraph
|
||||
p = Paragraph(block, self.doc)
|
||||
blocks.append(('p', i, p))
|
||||
elif block.tag.endswith('tbl'): # Table
|
||||
blocks.append(('t', i, None)) # Table object will be retrieved later
|
||||
except Exception as e:
|
||||
logging.error(f"Error collecting blocks: {e}")
|
||||
return ""
|
||||
|
||||
# Find the target table position
|
||||
target_table_pos = -1
|
||||
table_count = 0
|
||||
for i, (block_type, pos, _) in enumerate(blocks):
|
||||
if block_type == 't':
|
||||
if table_count == table_index:
|
||||
target_table_pos = pos
|
||||
break
|
||||
table_count += 1
|
||||
|
||||
if target_table_pos == -1:
|
||||
return "" # Target table not found
|
||||
|
||||
# Find the nearest heading paragraph in reverse order
|
||||
nearest_title = None
|
||||
for i in range(len(blocks)-1, -1, -1):
|
||||
block_type, pos, block = blocks[i]
|
||||
if pos >= target_table_pos: # Skip blocks after the table
|
||||
continue
|
||||
|
||||
if block_type != 'p':
|
||||
continue
|
||||
|
||||
if block.style and block.style.name and re.search(r"Heading\s*(\d+)", block.style.name, re.I):
|
||||
try:
|
||||
level_match = re.search(r"(\d+)", block.style.name)
|
||||
if level_match:
|
||||
level = int(level_match.group(1))
|
||||
if level <= 7: # Support up to 7 heading levels
|
||||
title_text = block.text.strip()
|
||||
if title_text: # Avoid empty titles
|
||||
nearest_title = (level, title_text)
|
||||
break
|
||||
except Exception as e:
|
||||
logging.error(f"Error parsing heading level: {e}")
|
||||
|
||||
if nearest_title:
|
||||
# Add current title
|
||||
titles.append(nearest_title)
|
||||
current_level = nearest_title[0]
|
||||
|
||||
# Find all parent headings, allowing cross-level search
|
||||
while current_level > 1:
|
||||
found = False
|
||||
for i in range(len(blocks)-1, -1, -1):
|
||||
block_type, pos, block = blocks[i]
|
||||
if pos >= target_table_pos: # Skip blocks after the table
|
||||
continue
|
||||
|
||||
if block_type != 'p':
|
||||
continue
|
||||
|
||||
if block.style and re.search(r"Heading\s*(\d+)", block.style.name, re.I):
|
||||
try:
|
||||
level_match = re.search(r"(\d+)", block.style.name)
|
||||
if level_match:
|
||||
level = int(level_match.group(1))
|
||||
# Find any heading with a higher level
|
||||
if level < current_level:
|
||||
title_text = block.text.strip()
|
||||
if title_text: # Avoid empty titles
|
||||
titles.append((level, title_text))
|
||||
current_level = level
|
||||
found = True
|
||||
break
|
||||
except Exception as e:
|
||||
logging.error(f"Error parsing parent heading: {e}")
|
||||
|
||||
if not found: # Break if no parent heading is found
|
||||
break
|
||||
|
||||
# Sort by level (ascending, from highest to lowest)
|
||||
titles.sort(key=lambda x: x[0])
|
||||
# Organize titles (from highest to lowest)
|
||||
hierarchy = [doc_name] + [t[1] for t in titles]
|
||||
return " > ".join(hierarchy)
|
||||
|
||||
return ""
|
||||
|
||||
def __call__(self, filename, binary=None, from_page=0, to_page=100000):
|
||||
self.doc = Document(
|
||||
filename) if not binary else Document(BytesIO(binary))
|
||||
pn = 0
|
||||
lines = []
|
||||
last_image = None
|
||||
for p in self.doc.paragraphs:
|
||||
if pn > to_page:
|
||||
break
|
||||
if from_page <= pn < to_page:
|
||||
if p.text.strip():
|
||||
if p.style and p.style.name == 'Caption':
|
||||
former_image = None
|
||||
if lines and lines[-1][1] and lines[-1][2] != 'Caption':
|
||||
former_image = lines[-1][1].pop()
|
||||
elif last_image:
|
||||
former_image = last_image
|
||||
last_image = None
|
||||
lines.append((self.__clean(p.text), [former_image], p.style.name))
|
||||
else:
|
||||
current_image = self.get_picture(self.doc, p)
|
||||
image_list = [current_image]
|
||||
if last_image:
|
||||
image_list.insert(0, last_image)
|
||||
last_image = None
|
||||
lines.append((self.__clean(p.text), image_list, p.style.name if p.style else ""))
|
||||
else:
|
||||
if current_image := self.get_picture(self.doc, p):
|
||||
if lines:
|
||||
lines[-1][1].append(current_image)
|
||||
else:
|
||||
last_image = current_image
|
||||
for run in p.runs:
|
||||
if 'lastRenderedPageBreak' in run._element.xml:
|
||||
pn += 1
|
||||
continue
|
||||
if 'w:br' in run._element.xml and 'type="page"' in run._element.xml:
|
||||
pn += 1
|
||||
new_line = [(line[0], reduce(concat_img, line[1]) if line[1] else None) for line in lines]
|
||||
|
||||
tbls = []
|
||||
for i, tb in enumerate(self.doc.tables):
|
||||
title = self.__get_nearest_title(i, filename)
|
||||
html = "<table>"
|
||||
if title:
|
||||
html += f"<caption>Table Location: {title}</caption>"
|
||||
for r in tb.rows:
|
||||
html += "<tr>"
|
||||
i = 0
|
||||
try:
|
||||
while i < len(r.cells):
|
||||
span = 1
|
||||
c = r.cells[i]
|
||||
for j in range(i + 1, len(r.cells)):
|
||||
if c.text == r.cells[j].text:
|
||||
span += 1
|
||||
i = j
|
||||
else:
|
||||
break
|
||||
i += 1
|
||||
html += f"<td>{c.text}</td>" if span == 1 else f"<td colspan='{span}'>{c.text}</td>"
|
||||
except Exception as e:
|
||||
logging.warning(f"Error parsing table, ignore: {e}")
|
||||
html += "</tr>"
|
||||
html += "</table>"
|
||||
tbls.append(((None, html), ""))
|
||||
return new_line, tbls
|
||||
|
||||
def to_markdown(self, filename=None, binary=None, inline_images: bool = True):
|
||||
"""
|
||||
This function uses mammoth, licensed under the BSD 2-Clause License.
|
||||
"""
|
||||
|
||||
import base64
|
||||
import uuid
|
||||
|
||||
import mammoth
|
||||
from markdownify import markdownify
|
||||
|
||||
docx_file = BytesIO(binary) if binary else open(filename, "rb")
|
||||
|
||||
def _convert_image_to_base64(image):
|
||||
try:
|
||||
with image.open() as image_file:
|
||||
image_bytes = image_file.read()
|
||||
encoded = base64.b64encode(image_bytes).decode("utf-8")
|
||||
base64_url = f"data:{image.content_type};base64,{encoded}"
|
||||
|
||||
alt_name = "image"
|
||||
alt_name = f"img_{uuid.uuid4().hex[:8]}"
|
||||
|
||||
return {"src": base64_url, "alt": alt_name}
|
||||
except Exception as e:
|
||||
logging.warning(f"Failed to convert image to base64: {e}")
|
||||
return {"src": "", "alt": "image"}
|
||||
|
||||
try:
|
||||
if inline_images:
|
||||
result = mammoth.convert_to_html(docx_file, convert_image=mammoth.images.img_element(_convert_image_to_base64))
|
||||
else:
|
||||
result = mammoth.convert_to_html(docx_file)
|
||||
|
||||
html = result.value
|
||||
|
||||
markdown_text = markdownify(html)
|
||||
return markdown_text
|
||||
|
||||
finally:
|
||||
if not binary:
|
||||
docx_file.close()
|
||||
|
||||
|
||||
class Pdf(PdfParser):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def __call__(self, filename, binary=None, from_page=0,
|
||||
to_page=100000, zoomin=3, callback=None, separate_tables_figures=False):
|
||||
start = timer()
|
||||
first_start = start
|
||||
callback(msg="OCR started")
|
||||
self.__images__(
|
||||
filename if not binary else binary,
|
||||
zoomin,
|
||||
from_page,
|
||||
to_page,
|
||||
callback
|
||||
)
|
||||
callback(msg="OCR finished ({:.2f}s)".format(timer() - start))
|
||||
logging.info("OCR({}~{}): {:.2f}s".format(from_page, to_page, timer() - start))
|
||||
|
||||
start = timer()
|
||||
self._layouts_rec(zoomin)
|
||||
callback(0.63, "Layout analysis ({:.2f}s)".format(timer() - start))
|
||||
|
||||
start = timer()
|
||||
self._table_transformer_job(zoomin)
|
||||
callback(0.65, "Table analysis ({:.2f}s)".format(timer() - start))
|
||||
|
||||
start = timer()
|
||||
self._text_merge(zoomin=zoomin)
|
||||
callback(0.67, "Text merged ({:.2f}s)".format(timer() - start))
|
||||
|
||||
if separate_tables_figures:
|
||||
tbls, figures = self._extract_table_figure(True, zoomin, True, True, True)
|
||||
self._concat_downward()
|
||||
logging.info("layouts cost: {}s".format(timer() - first_start))
|
||||
return [(b["text"], self._line_tag(b, zoomin)) for b in self.boxes], tbls, figures
|
||||
else:
|
||||
tbls = self._extract_table_figure(True, zoomin, True, True)
|
||||
self._naive_vertical_merge()
|
||||
self._concat_downward()
|
||||
self._final_reading_order_merge()
|
||||
# self._filter_forpages()
|
||||
logging.info("layouts cost: {}s".format(timer() - first_start))
|
||||
return [(b["text"], self._line_tag(b, zoomin)) for b in self.boxes], tbls
|
||||
|
||||
|
||||
class Markdown(MarkdownParser):
|
||||
def md_to_html(self, sections):
|
||||
if not sections:
|
||||
return []
|
||||
if isinstance(sections, type("")):
|
||||
text = sections
|
||||
elif isinstance(sections[0], type("")):
|
||||
text = sections[0]
|
||||
else:
|
||||
return []
|
||||
|
||||
from bs4 import BeautifulSoup
|
||||
html_content = markdown(text)
|
||||
soup = BeautifulSoup(html_content, 'html.parser')
|
||||
return soup
|
||||
|
||||
def get_picture_urls(self, soup):
|
||||
if soup:
|
||||
return [img.get('src') for img in soup.find_all('img') if img.get('src')]
|
||||
return []
|
||||
|
||||
def get_hyperlink_urls(self, soup):
|
||||
if soup:
|
||||
return set([a.get('href') for a in soup.find_all('a') if a.get('href')])
|
||||
return []
|
||||
|
||||
def get_pictures(self, text):
|
||||
"""Download and open all images from markdown text."""
|
||||
import requests
|
||||
soup = self.md_to_html(text)
|
||||
image_urls = self.get_picture_urls(soup)
|
||||
images = []
|
||||
# Find all image URLs in text
|
||||
for url in image_urls:
|
||||
if not url:
|
||||
continue
|
||||
try:
|
||||
# check if the url is a local file or a remote URL
|
||||
if url.startswith(('http://', 'https://')):
|
||||
# For remote URLs, download the image
|
||||
response = requests.get(url, stream=True, timeout=30)
|
||||
if response.status_code == 200 and response.headers['Content-Type'] and response.headers['Content-Type'].startswith('image/'):
|
||||
img = Image.open(BytesIO(response.content)).convert('RGB')
|
||||
images.append(img)
|
||||
else:
|
||||
# For local file paths, open the image directly
|
||||
from pathlib import Path
|
||||
local_path = Path(url)
|
||||
if not local_path.exists():
|
||||
logging.warning(f"Local image file not found: {url}")
|
||||
continue
|
||||
img = Image.open(url).convert('RGB')
|
||||
images.append(img)
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to download/open image from {url}: {e}")
|
||||
continue
|
||||
|
||||
return images if images else None
|
||||
|
||||
def __call__(self, filename, binary=None, separate_tables=True,delimiter=None):
|
||||
if binary:
|
||||
encoding = find_codec(binary)
|
||||
txt = binary.decode(encoding, errors="ignore")
|
||||
else:
|
||||
with open(filename, "r") as f:
|
||||
txt = f.read()
|
||||
|
||||
remainder, tables = self.extract_tables_and_remainder(f'{txt}\n', separate_tables=separate_tables)
|
||||
# To eliminate duplicate tables in chunking result, uncomment code below and set separate_tables to True in line 410.
|
||||
# extractor = MarkdownElementExtractor(remainder)
|
||||
extractor = MarkdownElementExtractor(txt)
|
||||
element_sections = extractor.extract_elements(delimiter)
|
||||
sections = [(element, "") for element in element_sections]
|
||||
tbls = []
|
||||
for table in tables:
|
||||
tbls.append(((None, markdown(table, extensions=['markdown.extensions.tables'])), ""))
|
||||
return sections, tbls
|
||||
|
||||
def load_from_xml_v2(baseURI, rels_item_xml):
|
||||
"""
|
||||
Return |_SerializedRelationships| instance loaded with the
|
||||
relationships contained in *rels_item_xml*. Returns an empty
|
||||
collection if *rels_item_xml* is |None|.
|
||||
"""
|
||||
srels = _SerializedRelationships()
|
||||
if rels_item_xml is not None:
|
||||
rels_elm = parse_xml(rels_item_xml)
|
||||
for rel_elm in rels_elm.Relationship_lst:
|
||||
if rel_elm.target_ref in ('../NULL', 'NULL'):
|
||||
continue
|
||||
srels._srels.append(_SerializedRelationship(baseURI, rel_elm))
|
||||
return srels
|
||||
|
||||
def chunk(filename, binary=None, from_page=0, to_page=100000,
|
||||
lang="Chinese", callback=None, vision_model=None, **kwargs):
|
||||
"""
|
||||
Supported file formats are docx, doc, pdf, excel, txt, markdown, html, json.
|
||||
This method apply the naive ways to chunk files.
|
||||
Successive text will be sliced into pieces using 'delimiter'.
|
||||
Next, these successive pieces are merge into chunks whose token number is no more than 'Max token number'.
|
||||
"""
|
||||
urls = set()
|
||||
url_res = []
|
||||
|
||||
|
||||
is_english = lang.lower() == "english" # is_english(cks)
|
||||
parser_config = kwargs.get(
|
||||
"parser_config", {
|
||||
"layout_recognize": "DeepDOC", "chunk_token_num": 512, "delimiter": "\n!?。;!?", "analyze_hyperlink": True})
|
||||
doc = {
|
||||
"docnm_kwd": filename,
|
||||
"title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", filename))
|
||||
}
|
||||
doc["title_sm_tks"] = rag_tokenizer.fine_grained_tokenize(doc["title_tks"])
|
||||
res = []
|
||||
pdf_parser = None
|
||||
section_images = None
|
||||
|
||||
is_root = kwargs.get("is_root", True)
|
||||
embed_res = []
|
||||
if is_root:
|
||||
# Only extract embedded files at the root call
|
||||
embeds = []
|
||||
if binary is not None:
|
||||
embeds = extract_embed_file(binary)
|
||||
else:
|
||||
raise Exception("Embedding extraction from file path is not supported.")
|
||||
|
||||
# Recursively chunk each embedded file and collect results
|
||||
for embed_filename, embed_bytes in embeds:
|
||||
try:
|
||||
sub_res = chunk(embed_filename, binary=embed_bytes, lang=lang, callback=callback, vision_model=vision_model, is_root=False, **kwargs) or []
|
||||
embed_res.extend(sub_res)
|
||||
except Exception as e:
|
||||
if callback:
|
||||
callback(0.05, f"Failed to chunk embed {embed_filename}: {e}")
|
||||
continue
|
||||
|
||||
if re.search(r"\.docx$", filename, re.IGNORECASE):
|
||||
callback(0.1, "Start to parse.")
|
||||
if parser_config.get("analyze_hyperlink", False) and is_root:
|
||||
urls = extract_links_from_docx(binary)
|
||||
for index, url in enumerate(urls):
|
||||
html_bytes, metadata = extract_html(url)
|
||||
if not html_bytes:
|
||||
continue
|
||||
try:
|
||||
sub_url_res = chunk(url, html_bytes, lang=lang, callback=callback, vision_model=vision_model, is_root=False, **kwargs)
|
||||
except Exception as e:
|
||||
logging.info(f"Failed to chunk url in registered file type {url}: {e}")
|
||||
sub_url_res = chunk(f"{index}.html", html_bytes, lang=lang, callback=callback, vision_model=vision_model, is_root=False, **kwargs)
|
||||
url_res.extend(sub_url_res)
|
||||
|
||||
# fix "There is no item named 'word/NULL' in the archive", referring to https://github.com/python-openxml/python-docx/issues/1105#issuecomment-1298075246
|
||||
_SerializedRelationships.load_from_xml = load_from_xml_v2
|
||||
sections, tables = Docx()(filename, binary)
|
||||
|
||||
tables=vision_figure_parser_docx_wrapper(sections=sections,tbls=tables,callback=callback, vision_model=vision_model, **kwargs)
|
||||
|
||||
res = tokenize_table(tables, doc, is_english)
|
||||
callback(0.8, "Finish parsing.")
|
||||
|
||||
st = timer()
|
||||
|
||||
chunks, images = naive_merge_docx(
|
||||
sections, int(parser_config.get(
|
||||
"chunk_token_num", 128)), parser_config.get(
|
||||
"delimiter", "\n!?。;!?"))
|
||||
|
||||
if kwargs.get("section_only", False):
|
||||
chunks.extend(embed_res)
|
||||
chunks.extend(url_res)
|
||||
return chunks
|
||||
|
||||
res.extend(tokenize_chunks_with_images(chunks, doc, is_english, images))
|
||||
logging.info("naive_merge({}): {}".format(filename, timer() - st))
|
||||
res.extend(embed_res)
|
||||
res.extend(url_res)
|
||||
return res
|
||||
|
||||
elif re.search(r"\.pdf$", filename, re.IGNORECASE):
|
||||
layout_recognizer = parser_config.get("layout_recognize", "DeepDOC")
|
||||
if parser_config.get("analyze_hyperlink", False) and is_root:
|
||||
urls = extract_links_from_pdf(binary)
|
||||
|
||||
if isinstance(layout_recognizer, bool):
|
||||
layout_recognizer = "DeepDOC" if layout_recognizer else "Plain Text"
|
||||
|
||||
name = layout_recognizer.strip().lower()
|
||||
parser = PARSERS.get(name, by_plaintext)
|
||||
callback(0.1, "Start to parse.")
|
||||
|
||||
sections, tables, pdf_parser = parser(
|
||||
filename=filename,
|
||||
binary=binary,
|
||||
from_page=from_page,
|
||||
to_page=to_page,
|
||||
lang=lang,
|
||||
callback=callback,
|
||||
vision_model=vision_model,
|
||||
layout_recognizer=layout_recognizer,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
if not sections and not tables:
|
||||
return []
|
||||
|
||||
if name in ["mineru", "textln"]:
|
||||
parser_config["chunk_token_num"] = 0
|
||||
|
||||
res = tokenize_table(tables, doc, is_english)
|
||||
callback(0.8, "Finish parsing.")
|
||||
|
||||
elif re.search(r"\.pptx?$", filename, re.IGNORECASE):
|
||||
if not binary:
|
||||
with open(filename, "rb") as f:
|
||||
binary = f.read()
|
||||
from app.core.rag.app.presentation import Ppt
|
||||
ppt_parser = Ppt()
|
||||
for pn, (txt, img) in enumerate(ppt_parser(
|
||||
filename if not binary else binary, from_page, to_page, callback)):
|
||||
d = copy.deepcopy(doc)
|
||||
pn += from_page
|
||||
d["image"] = img
|
||||
d["doc_type_kwd"] = "image"
|
||||
d["page_num_int"] = [pn + 1]
|
||||
d["top_int"] = [0]
|
||||
d["position_int"] = [(pn + 1, 0, img.size[0], 0, img.size[1])]
|
||||
tokenize(d, txt, is_english)
|
||||
res.append(d)
|
||||
return res
|
||||
|
||||
elif re.search(r"\.(da|wave|wav|mp3|aac|flac|ogg|aiff|au|midi|wma|realaudio|vqf|oggvorbis|ape?)$", filename, re.IGNORECASE):
|
||||
if not binary:
|
||||
with open(filename, "rb") as f:
|
||||
binary = f.read()
|
||||
from app.core.rag.app.audio import chunk as parser
|
||||
return parser(filename, binary, lang=lang, callback=callback, seq2txt_mdl=vision_model, **kwargs)
|
||||
|
||||
elif re.search(r"\.(png|jpeg|jpg|gif|bmp|svg|mp4|mov|avi|flv|mpeg|mpg|webm|wmv|3gp|3gpp|mkv?)$", filename, re.IGNORECASE):
|
||||
if not binary:
|
||||
with open(filename, "rb") as f:
|
||||
binary = f.read()
|
||||
from app.core.rag.app.picture import chunk as parser
|
||||
return parser(filename, binary, lang=lang, callback=callback, vision_model=vision_model, **kwargs)
|
||||
|
||||
elif re.search(r"\.(csv|xlsx?)$", filename, re.IGNORECASE):
|
||||
callback(0.1, "Start to parse.")
|
||||
if not binary:
|
||||
with open(filename, "rb") as f:
|
||||
binary = f.read()
|
||||
excel_parser = ExcelParser()
|
||||
if parser_config.get("html4excel"):
|
||||
sections = [(_, "") for _ in excel_parser.html(binary, 12) if _]
|
||||
else:
|
||||
sections = [(_, "") for _ in excel_parser(binary) if _]
|
||||
parser_config["chunk_token_num"] = 12800
|
||||
|
||||
elif re.search(r"\.(txt|py|js|java|c|cpp|h|php|go|ts|sh|cs|kt|sql)$", filename, re.IGNORECASE):
|
||||
callback(0.1, "Start to parse.")
|
||||
sections = TxtParser()(filename, binary,
|
||||
parser_config.get("chunk_token_num", 128),
|
||||
parser_config.get("delimiter", "\n!?;。;!?"))
|
||||
callback(0.8, "Finish parsing.")
|
||||
|
||||
elif re.search(r"\.(md|markdown)$", filename, re.IGNORECASE):
|
||||
callback(0.1, "Start to parse.")
|
||||
markdown_parser = Markdown(int(parser_config.get("chunk_token_num", 128)))
|
||||
sections, tables = markdown_parser(filename, binary, separate_tables=False,delimiter=parser_config.get("delimiter", "\n!?;。;!?"))
|
||||
|
||||
if vision_model:
|
||||
# Process images for each section
|
||||
section_images = []
|
||||
for idx, (section_text, _) in enumerate(sections):
|
||||
images = markdown_parser.get_pictures(section_text) if section_text else None
|
||||
|
||||
if images:
|
||||
# If multiple images found, combine them using concat_img
|
||||
combined_image = reduce(concat_img, images) if len(images) > 1 else images[0]
|
||||
section_images.append(combined_image)
|
||||
markdown_vision_parser = VisionFigureParser(vision_model=vision_model, figures_data= [((combined_image, ["markdown image"]), [(0, 0, 0, 0, 0)])], **kwargs)
|
||||
boosted_figures = markdown_vision_parser(callback=callback)
|
||||
sections[idx] = (section_text + "\n\n" + "\n\n".join([fig[0][1][0] for fig in boosted_figures]), sections[idx][1])
|
||||
else:
|
||||
section_images.append(None)
|
||||
|
||||
else:
|
||||
logging.warning("No visual model detected. Skipping figure parsing enhancement.")
|
||||
|
||||
if parser_config.get("hyperlink_urls", False) and is_root:
|
||||
for idx, (section_text, _) in enumerate(sections):
|
||||
soup = markdown_parser.md_to_html(section_text)
|
||||
hyperlink_urls = markdown_parser.get_hyperlink_urls(soup)
|
||||
urls.update(hyperlink_urls)
|
||||
res = tokenize_table(tables, doc, is_english)
|
||||
callback(0.8, "Finish parsing.")
|
||||
|
||||
elif re.search(r"\.(htm|html)$", filename, re.IGNORECASE):
|
||||
callback(0.1, "Start to parse.")
|
||||
chunk_token_num = int(parser_config.get("chunk_token_num", 128))
|
||||
sections = HtmlParser()(filename, binary, chunk_token_num)
|
||||
sections = [(_, "") for _ in sections if _]
|
||||
callback(0.8, "Finish parsing.")
|
||||
|
||||
elif re.search(r"\.(json|jsonl|ldjson)$", filename, re.IGNORECASE):
|
||||
callback(0.1, "Start to parse.")
|
||||
chunk_token_num = int(parser_config.get("chunk_token_num", 128))
|
||||
sections = JsonParser(chunk_token_num)(filename)
|
||||
sections = [(_, "") for _ in sections if _]
|
||||
callback(0.8, "Finish parsing.")
|
||||
|
||||
elif re.search(r"\.doc$", filename, re.IGNORECASE):
|
||||
callback(0.1, "Start to parse.")
|
||||
|
||||
try:
|
||||
import tika
|
||||
os.environ['TIKA_SERVER_JAR'] = "/tmp/tika-server.jar"
|
||||
os.environ['TIKA_SERVER_PORT'] = '9998'
|
||||
# java11 Initialize Tika 3.1.0.jar service url:http://localhost:9998 view process:lsof -i :9998
|
||||
tika.initVM()
|
||||
from tika import parser as tika_parser
|
||||
except Exception as e:
|
||||
callback(0.8, f"tika not available: {e}. Unsupported .doc parsing.")
|
||||
logging.warning(f"tika not available: {e}. Unsupported .doc parsing for {filename}.")
|
||||
return []
|
||||
|
||||
doc_parsed = tika_parser.from_file(filename)
|
||||
if doc_parsed.get('content', None) is not None:
|
||||
sections = doc_parsed['content'].split('\n')
|
||||
sections = [(_, "") for _ in sections if _]
|
||||
callback(0.8, "Finish parsing.")
|
||||
else:
|
||||
callback(0.8, f"tika.parser got empty content from {filename}.")
|
||||
logging.warning(f"tika.parser got empty content from {filename}.")
|
||||
return []
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"file type not supported yet(pdf, xlsx, doc, docx, txt supported)")
|
||||
|
||||
st = timer()
|
||||
if section_images:
|
||||
# if all images are None, set section_images to None
|
||||
if all(image is None for image in section_images):
|
||||
section_images = None
|
||||
|
||||
if section_images:
|
||||
chunks, images = naive_merge_with_images(sections, section_images,
|
||||
int(parser_config.get(
|
||||
"chunk_token_num", 128)), parser_config.get(
|
||||
"delimiter", "\n!?。;!?"))
|
||||
if kwargs.get("section_only", False):
|
||||
chunks.extend(embed_res)
|
||||
return chunks
|
||||
|
||||
res.extend(tokenize_chunks_with_images(chunks, doc, is_english, images))
|
||||
else:
|
||||
chunks = naive_merge(
|
||||
sections, int(parser_config.get(
|
||||
"chunk_token_num", 128)), parser_config.get(
|
||||
"delimiter", "\n!?。;!?"))
|
||||
if kwargs.get("section_only", False):
|
||||
chunks.extend(embed_res)
|
||||
return chunks
|
||||
|
||||
res.extend(tokenize_chunks(chunks, doc, is_english, pdf_parser))
|
||||
|
||||
if urls and parser_config.get("analyze_hyperlink", False) and is_root:
|
||||
for index, url in enumerate(urls):
|
||||
html_bytes, metadata = extract_html(url)
|
||||
if not html_bytes:
|
||||
continue
|
||||
try:
|
||||
sub_url_res = chunk(url, html_bytes, callback=callback, lang=lang, is_root=False, **kwargs)
|
||||
except Exception as e:
|
||||
logging.info(f"Failed to chunk url in registered file type {url}: {e}")
|
||||
sub_url_res = chunk(f"{index}.html", html_bytes, lang=lang, callback=callback, vision_model=vision_model, is_root=False, **kwargs)
|
||||
url_res.extend(sub_url_res)
|
||||
|
||||
logging.info("naive_merge({}): {}".format(filename, timer() - st))
|
||||
|
||||
if embed_res:
|
||||
res.extend(embed_res)
|
||||
if url_res:
|
||||
res.extend(url_res)
|
||||
return res
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# import sys
|
||||
# chunk(sys.argv[1], from_page=0, to_page=10, callback=dummy)
|
||||
|
||||
# Prepare to configure vision_model information
|
||||
vision_model = QWenCV(
|
||||
key="sk-8e9e40cd171749858ce2d3722ea75669",
|
||||
model_name="qwen-vl-max",
|
||||
lang="chinese", # 默认使用中文
|
||||
base_url="https://dashscope.aliyuncs.com/compatible-mode/v1"
|
||||
)
|
||||
|
||||
def progress_callback(prog=None, msg=None):
|
||||
print(f"prog: {prog} msg: {msg}\n")
|
||||
|
||||
# file_path = "/Users/sbtjfdn/Downloads/记忆科学/files/1.txt"
|
||||
# file_path = "/Users/sbtjfdn/Downloads/记忆科学/files/2.md"
|
||||
# file_path = "/Users/sbtjfdn/Downloads/记忆科学/files/3.md" # 带图url
|
||||
file_path = "/Users/sbtjfdn/Downloads/记忆科学/files/义务教育教科书·中国历史七年级上册 (2)_Compressed.md"
|
||||
# file_path = "/Users/sbtjfdn/Downloads/记忆科学/files/4.doc"
|
||||
# file_path = "/Users/sbtjfdn/Downloads/记忆科学/files/5.json"
|
||||
# file_path = "/Users/sbtjfdn/Downloads/记忆科学/files/6.html"
|
||||
# file_path = "/Users/sbtjfdn/Downloads/记忆科学/files/7.xlsx"
|
||||
# file_path = "/Users/sbtjfdn/Downloads/记忆科学/files/8.pdf"
|
||||
# file_path = "/Users/sbtjfdn/Downloads/记忆科学/files/9.pptx"
|
||||
# file_path = "/Users/sbtjfdn/Downloads/记忆科学/files/10.png"
|
||||
# file_path = "/Users/sbtjfdn/Downloads/记忆科学/files/11.mp4"
|
||||
# file_path = "/Users/sbtjfdn/Downloads/记忆科学/files/12.mp3"
|
||||
res = chunk(filename=file_path,
|
||||
from_page=0,
|
||||
to_page=10,
|
||||
callback=progress_callback,
|
||||
vision_model=vision_model,
|
||||
parser_config={
|
||||
"layout_recognize": "DeepDOC",
|
||||
"chunk_token_num": 128,
|
||||
"delimiter": "\n",
|
||||
"analyze_hyperlink": True,
|
||||
"auto_keywords": 0,
|
||||
"auto_questions": 0,
|
||||
"html4excel": "false"
|
||||
},
|
||||
is_root=False)
|
||||
for index, item in enumerate(res):
|
||||
print(f"Index: {index}\n----")
|
||||
print(item)
|
||||
print("----")
|
||||
149
api/app/core/rag/app/one.py
Normal file
149
api/app/core/rag/app/one.py
Normal file
@@ -0,0 +1,149 @@
|
||||
import logging
|
||||
from io import BytesIO
|
||||
import re
|
||||
|
||||
from app.core.rag.deepdoc.parser.utils import get_text
|
||||
from . import naive
|
||||
from app.core.rag.nlp import rag_tokenizer, tokenize
|
||||
from app.core.rag.deepdoc.parser import PdfParser, ExcelParser, HtmlParser
|
||||
from app.core.rag.deepdoc.parser.figure_parser import vision_figure_parser_docx_wrapper
|
||||
from app.core.rag.app.naive import by_plaintext, PARSERS
|
||||
|
||||
class Pdf(PdfParser):
|
||||
def __call__(self, filename, binary=None, from_page=0,
|
||||
to_page=100000, zoomin=3, callback=None):
|
||||
from timeit import default_timer as timer
|
||||
start = timer()
|
||||
callback(msg="OCR started")
|
||||
self.__images__(
|
||||
filename if not binary else binary,
|
||||
zoomin,
|
||||
from_page,
|
||||
to_page,
|
||||
callback
|
||||
)
|
||||
callback(msg="OCR finished ({:.2f}s)".format(timer() - start))
|
||||
|
||||
start = timer()
|
||||
self._layouts_rec(zoomin, drop=False)
|
||||
callback(0.63, "Layout analysis ({:.2f}s)".format(timer() - start))
|
||||
logging.debug("layouts cost: {}s".format(timer() - start))
|
||||
|
||||
start = timer()
|
||||
self._table_transformer_job(zoomin)
|
||||
callback(0.65, "Table analysis ({:.2f}s)".format(timer() - start))
|
||||
|
||||
start = timer()
|
||||
self._text_merge()
|
||||
callback(0.67, "Text merged ({:.2f}s)".format(timer() - start))
|
||||
tbls = self._extract_table_figure(True, zoomin, True, True)
|
||||
self._concat_downward()
|
||||
|
||||
sections = [(b["text"], self.get_position(b, zoomin))
|
||||
for i, b in enumerate(self.boxes)]
|
||||
return [(txt, "") for txt, _ in sorted(sections, key=lambda x: (
|
||||
x[-1][0][0], x[-1][0][3], x[-1][0][1]))], tbls
|
||||
|
||||
|
||||
def chunk(filename, binary=None, from_page=0, to_page=100000,
|
||||
lang="Chinese", callback=None, **kwargs):
|
||||
"""
|
||||
Supported file formats are docx, pdf, excel, txt.
|
||||
One file forms a chunk which maintains original text order.
|
||||
"""
|
||||
parser_config = kwargs.get(
|
||||
"parser_config", {
|
||||
"chunk_token_num": 512, "delimiter": "\n!?。;!?", "layout_recognize": "DeepDOC"})
|
||||
eng = lang.lower() == "english" # is_english(cks)
|
||||
|
||||
if re.search(r"\.docx$", filename, re.IGNORECASE):
|
||||
callback(0.1, "Start to parse.")
|
||||
sections, tbls = naive.Docx()(filename, binary)
|
||||
tbls=vision_figure_parser_docx_wrapper(sections=sections,tbls=tbls,callback=callback,**kwargs)
|
||||
sections = [s for s, _ in sections if s]
|
||||
for (_, html), _ in tbls:
|
||||
sections.append(html)
|
||||
callback(0.8, "Finish parsing.")
|
||||
|
||||
elif re.search(r"\.pdf$", filename, re.IGNORECASE):
|
||||
layout_recognizer = parser_config.get("layout_recognize", "DeepDOC")
|
||||
|
||||
if isinstance(layout_recognizer, bool):
|
||||
layout_recognizer = "DeepDOC" if layout_recognizer else "Plain Text"
|
||||
|
||||
name = layout_recognizer.strip().lower()
|
||||
parser = PARSERS.get(name, by_plaintext)
|
||||
callback(0.1, "Start to parse.")
|
||||
|
||||
sections, tbls, pdf_parser = parser(
|
||||
filename = filename,
|
||||
binary = binary,
|
||||
from_page = from_page,
|
||||
to_page = to_page,
|
||||
lang = lang,
|
||||
callback = callback,
|
||||
pdf_cls = Pdf,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
if not sections and not tbls:
|
||||
return []
|
||||
|
||||
if name in ["tcadp", "docling", "mineru"]:
|
||||
parser_config["chunk_token_num"] = 0
|
||||
|
||||
callback(0.8, "Finish parsing.")
|
||||
|
||||
for (img, rows), poss in tbls:
|
||||
if not rows:
|
||||
continue
|
||||
sections.append((rows if isinstance(rows, str) else rows[0],
|
||||
[(p[0] + 1 - from_page, p[1], p[2], p[3], p[4]) for p in poss]))
|
||||
sections = [s for s, _ in sections if s]
|
||||
|
||||
elif re.search(r"\.xlsx?$", filename, re.IGNORECASE):
|
||||
callback(0.1, "Start to parse.")
|
||||
excel_parser = ExcelParser()
|
||||
sections = excel_parser.html(binary, 1000000000)
|
||||
|
||||
elif re.search(r"\.(txt|md|markdown)$", filename, re.IGNORECASE):
|
||||
callback(0.1, "Start to parse.")
|
||||
txt = get_text(filename, binary)
|
||||
sections = txt.split("\n")
|
||||
sections = [s for s in sections if s]
|
||||
callback(0.8, "Finish parsing.")
|
||||
|
||||
elif re.search(r"\.(htm|html)$", filename, re.IGNORECASE):
|
||||
callback(0.1, "Start to parse.")
|
||||
sections = HtmlParser()(filename, binary)
|
||||
sections = [s for s in sections if s]
|
||||
callback(0.8, "Finish parsing.")
|
||||
|
||||
elif re.search(r"\.doc$", filename, re.IGNORECASE):
|
||||
callback(0.1, "Start to parse.")
|
||||
binary = BytesIO(binary)
|
||||
doc_parsed = parser.from_buffer(binary)
|
||||
sections = doc_parsed['content'].split('\n')
|
||||
sections = [s for s in sections if s]
|
||||
callback(0.8, "Finish parsing.")
|
||||
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"file type not supported yet(doc, docx, pdf, txt supported)")
|
||||
|
||||
doc = {
|
||||
"docnm_kwd": filename,
|
||||
"title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", filename))
|
||||
}
|
||||
doc["title_sm_tks"] = rag_tokenizer.fine_grained_tokenize(doc["title_tks"])
|
||||
tokenize(doc, "\n".join(sections), eng)
|
||||
return [doc]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
|
||||
def dummy(prog=None, msg=""):
|
||||
pass
|
||||
|
||||
chunk(sys.argv[1], from_page=0, to_page=10, callback=dummy)
|
||||
284
api/app/core/rag/app/paper.py
Normal file
284
api/app/core/rag/app/paper.py
Normal file
@@ -0,0 +1,284 @@
|
||||
import logging
|
||||
import copy
|
||||
import re
|
||||
|
||||
from app.core.rag.deepdoc.parser.figure_parser import vision_figure_parser_pdf_wrapper
|
||||
from app.core.rag.common.constants import ParserType
|
||||
from app.core.rag.nlp import rag_tokenizer, tokenize, tokenize_table, add_positions, bullets_category, title_frequency, tokenize_chunks
|
||||
from app.core.rag.deepdoc.parser import PdfParser, PlainParser
|
||||
import numpy as np
|
||||
|
||||
class Pdf(PdfParser):
|
||||
def __init__(self):
|
||||
self.model_speciess = ParserType.PAPER.value
|
||||
super().__init__()
|
||||
|
||||
def __call__(self, filename, binary=None, from_page=0,
|
||||
to_page=100000, zoomin=3, callback=None):
|
||||
from timeit import default_timer as timer
|
||||
start = timer()
|
||||
callback(msg="OCR started")
|
||||
self.__images__(
|
||||
filename if not binary else binary,
|
||||
zoomin,
|
||||
from_page,
|
||||
to_page,
|
||||
callback
|
||||
)
|
||||
callback(msg="OCR finished ({:.2f}s)".format(timer() - start))
|
||||
|
||||
start = timer()
|
||||
self._layouts_rec(zoomin)
|
||||
callback(0.63, "Layout analysis ({:.2f}s)".format(timer() - start))
|
||||
logging.debug(f"layouts cost: {timer() - start}s")
|
||||
|
||||
start = timer()
|
||||
self._table_transformer_job(zoomin)
|
||||
callback(0.68, "Table analysis ({:.2f}s)".format(timer() - start))
|
||||
|
||||
start = timer()
|
||||
self._text_merge()
|
||||
tbls = self._extract_table_figure(True, zoomin, True, True)
|
||||
column_width = np.median([b["x1"] - b["x0"] for b in self.boxes])
|
||||
self._concat_downward()
|
||||
self._filter_forpages()
|
||||
callback(0.75, "Text merged ({:.2f}s)".format(timer() - start))
|
||||
|
||||
# clean mess
|
||||
if column_width < self.page_images[0].size[0] / zoomin / 2:
|
||||
logging.debug("two_column................... {} {}".format(column_width,
|
||||
self.page_images[0].size[0] / zoomin / 2))
|
||||
self.boxes = self.sort_X_by_page(self.boxes, column_width / 2)
|
||||
for b in self.boxes:
|
||||
b["text"] = re.sub(r"([\t ]|\u3000){2,}", " ", b["text"].strip())
|
||||
|
||||
def _begin(txt):
|
||||
return re.match(
|
||||
"[0-9. 一、i]*(introduction|abstract|摘要|引言|keywords|key words|关键词|background|背景|目录|前言|contents)",
|
||||
txt.lower().strip())
|
||||
|
||||
if from_page > 0:
|
||||
return {
|
||||
"title": "",
|
||||
"authors": "",
|
||||
"abstract": "",
|
||||
"sections": [(b["text"] + self._line_tag(b, zoomin), b.get("layoutno", "")) for b in self.boxes if
|
||||
re.match(r"(text|title)", b.get("layoutno", "text"))],
|
||||
"tables": tbls
|
||||
}
|
||||
# get title and authors
|
||||
title = ""
|
||||
authors = []
|
||||
i = 0
|
||||
while i < min(32, len(self.boxes)-1):
|
||||
b = self.boxes[i]
|
||||
i += 1
|
||||
if b.get("layoutno", "").find("title") >= 0:
|
||||
title = b["text"]
|
||||
if _begin(title):
|
||||
title = ""
|
||||
break
|
||||
for j in range(3):
|
||||
if _begin(self.boxes[i + j]["text"]):
|
||||
break
|
||||
authors.append(self.boxes[i + j]["text"])
|
||||
break
|
||||
break
|
||||
# get abstract
|
||||
abstr = ""
|
||||
i = 0
|
||||
while i + 1 < min(32, len(self.boxes)):
|
||||
b = self.boxes[i]
|
||||
i += 1
|
||||
txt = b["text"].lower().strip()
|
||||
if re.match("(abstract|摘要)", txt):
|
||||
if len(txt.split()) > 32 or len(txt) > 64:
|
||||
abstr = txt + self._line_tag(b, zoomin)
|
||||
break
|
||||
txt = self.boxes[i]["text"].lower().strip()
|
||||
if len(txt.split()) > 32 or len(txt) > 64:
|
||||
abstr = txt + self._line_tag(self.boxes[i], zoomin)
|
||||
i += 1
|
||||
break
|
||||
if not abstr:
|
||||
i = 0
|
||||
|
||||
callback(
|
||||
0.8, "Page {}~{}: Text merging finished".format(
|
||||
from_page, min(
|
||||
to_page, self.total_page)))
|
||||
for b in self.boxes:
|
||||
logging.debug("{} {}".format(b["text"], b.get("layoutno")))
|
||||
logging.debug("{}".format(tbls))
|
||||
|
||||
return {
|
||||
"title": title,
|
||||
"authors": " ".join(authors),
|
||||
"abstract": abstr,
|
||||
"sections": [(b["text"] + self._line_tag(b, zoomin), b.get("layoutno", "")) for b in self.boxes[i:] if
|
||||
re.match(r"(text|title)", b.get("layoutno", "text"))],
|
||||
"tables": tbls
|
||||
}
|
||||
|
||||
|
||||
def chunk(filename, binary=None, from_page=0, to_page=100000,
|
||||
lang="Chinese", callback=None, **kwargs):
|
||||
"""
|
||||
Only pdf is supported.
|
||||
The abstract of the paper will be sliced as an entire chunk, and will not be sliced partly.
|
||||
"""
|
||||
parser_config = kwargs.get(
|
||||
"parser_config", {
|
||||
"chunk_token_num": 512, "delimiter": "\n!?。;!?", "layout_recognize": "DeepDOC"})
|
||||
if re.search(r"\.pdf$", filename, re.IGNORECASE):
|
||||
if parser_config.get("layout_recognize", "DeepDOC") == "Plain Text":
|
||||
pdf_parser = PlainParser()
|
||||
paper = {
|
||||
"title": filename,
|
||||
"authors": " ",
|
||||
"abstract": "",
|
||||
"sections": pdf_parser(filename if not binary else binary, from_page=from_page, to_page=to_page)[0],
|
||||
"tables": []
|
||||
}
|
||||
else:
|
||||
pdf_parser = Pdf()
|
||||
paper = pdf_parser(filename if not binary else binary,
|
||||
from_page=from_page, to_page=to_page, callback=callback)
|
||||
tbls=paper["tables"]
|
||||
tbls=vision_figure_parser_pdf_wrapper(tbls=tbls,callback=callback,**kwargs)
|
||||
paper["tables"] = tbls
|
||||
else:
|
||||
raise NotImplementedError("file type not supported yet(pdf supported)")
|
||||
|
||||
doc = {"docnm_kwd": filename, "authors_tks": rag_tokenizer.tokenize(paper["authors"]),
|
||||
"title_tks": rag_tokenizer.tokenize(paper["title"] if paper["title"] else filename)}
|
||||
doc["title_sm_tks"] = rag_tokenizer.fine_grained_tokenize(doc["title_tks"])
|
||||
doc["authors_sm_tks"] = rag_tokenizer.fine_grained_tokenize(doc["authors_tks"])
|
||||
# is it English
|
||||
eng = lang.lower() == "english" # pdf_parser.is_english
|
||||
logging.debug("It's English.....{}".format(eng))
|
||||
|
||||
res = tokenize_table(paper["tables"], doc, eng)
|
||||
|
||||
if paper["abstract"]:
|
||||
d = copy.deepcopy(doc)
|
||||
txt = pdf_parser.remove_tag(paper["abstract"])
|
||||
d["important_kwd"] = ["abstract", "总结", "概括", "summary", "summarize"]
|
||||
d["important_tks"] = " ".join(d["important_kwd"])
|
||||
d["image"], poss = pdf_parser.crop(
|
||||
paper["abstract"], need_position=True)
|
||||
add_positions(d, poss)
|
||||
tokenize(d, txt, eng)
|
||||
res.append(d)
|
||||
|
||||
sorted_sections = paper["sections"]
|
||||
# set pivot using the most frequent type of title,
|
||||
# then merge between 2 pivot
|
||||
bull = bullets_category([txt for txt, _ in sorted_sections])
|
||||
most_level, levels = title_frequency(bull, sorted_sections)
|
||||
assert len(sorted_sections) == len(levels)
|
||||
sec_ids = []
|
||||
sid = 0
|
||||
for i, lvl in enumerate(levels):
|
||||
if lvl <= most_level and i > 0 and lvl != levels[i - 1]:
|
||||
sid += 1
|
||||
sec_ids.append(sid)
|
||||
logging.debug("{} {} {} {}".format(lvl, sorted_sections[i][0], most_level, sid))
|
||||
|
||||
chunks = []
|
||||
last_sid = -2
|
||||
for (txt, _), sec_id in zip(sorted_sections, sec_ids):
|
||||
if sec_id == last_sid:
|
||||
if chunks:
|
||||
chunks[-1] += "\n" + txt
|
||||
continue
|
||||
chunks.append(txt)
|
||||
last_sid = sec_id
|
||||
res.extend(tokenize_chunks(chunks, doc, eng, pdf_parser))
|
||||
return res
|
||||
|
||||
|
||||
"""
|
||||
readed = [0] * len(paper["lines"])
|
||||
# find colon firstly
|
||||
i = 0
|
||||
while i + 1 < len(paper["lines"]):
|
||||
txt = pdf_parser.remove_tag(paper["lines"][i][0])
|
||||
j = i
|
||||
if txt.strip("\n").strip()[-1] not in "::":
|
||||
i += 1
|
||||
continue
|
||||
i += 1
|
||||
while i < len(paper["lines"]) and not paper["lines"][i][0]:
|
||||
i += 1
|
||||
if i >= len(paper["lines"]): break
|
||||
proj = [paper["lines"][i][0].strip()]
|
||||
i += 1
|
||||
while i < len(paper["lines"]) and paper["lines"][i][0].strip()[0] == proj[-1][0]:
|
||||
proj.append(paper["lines"][i])
|
||||
i += 1
|
||||
for k in range(j, i): readed[k] = True
|
||||
txt = txt[::-1]
|
||||
if eng:
|
||||
r = re.search(r"(.*?) ([\\.;?!]|$)", txt)
|
||||
txt = r.group(1)[::-1] if r else txt[::-1]
|
||||
else:
|
||||
r = re.search(r"(.*?) ([。?;!]|$)", txt)
|
||||
txt = r.group(1)[::-1] if r else txt[::-1]
|
||||
for p in proj:
|
||||
d = copy.deepcopy(doc)
|
||||
txt += "\n" + pdf_parser.remove_tag(p)
|
||||
d["image"], poss = pdf_parser.crop(p, need_position=True)
|
||||
add_positions(d, poss)
|
||||
tokenize(d, txt, eng)
|
||||
res.append(d)
|
||||
|
||||
i = 0
|
||||
chunk = []
|
||||
tk_cnt = 0
|
||||
def add_chunk():
|
||||
nonlocal chunk, res, doc, pdf_parser, tk_cnt
|
||||
d = copy.deepcopy(doc)
|
||||
ck = "\n".join(chunk)
|
||||
tokenize(d, pdf_parser.remove_tag(ck), pdf_parser.is_english)
|
||||
d["image"], poss = pdf_parser.crop(ck, need_position=True)
|
||||
add_positions(d, poss)
|
||||
res.append(d)
|
||||
chunk = []
|
||||
tk_cnt = 0
|
||||
|
||||
while i < len(paper["lines"]):
|
||||
if tk_cnt > 128:
|
||||
add_chunk()
|
||||
if readed[i]:
|
||||
i += 1
|
||||
continue
|
||||
readed[i] = True
|
||||
txt, layouts = paper["lines"][i]
|
||||
txt_ = pdf_parser.remove_tag(txt)
|
||||
i += 1
|
||||
cnt = num_tokens_from_string(txt_)
|
||||
if any([
|
||||
layouts.find("title") >= 0 and chunk,
|
||||
cnt + tk_cnt > 128 and tk_cnt > 32,
|
||||
]):
|
||||
add_chunk()
|
||||
chunk = [txt]
|
||||
tk_cnt = cnt
|
||||
else:
|
||||
chunk.append(txt)
|
||||
tk_cnt += cnt
|
||||
|
||||
if chunk: add_chunk()
|
||||
for i, d in enumerate(res):
|
||||
print(d)
|
||||
# d["image"].save(f"./logs/{i}.jpg")
|
||||
return res
|
||||
"""
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
|
||||
def dummy(prog=None, msg=""):
|
||||
pass
|
||||
chunk(sys.argv[1], callback=dummy)
|
||||
96
api/app/core/rag/app/picture.py
Normal file
96
api/app/core/rag/app/picture.py
Normal file
@@ -0,0 +1,96 @@
|
||||
import io
|
||||
import re
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from app.core.rag.deepdoc.vision import OCR
|
||||
from app.core.rag.nlp import rag_tokenizer, tokenize
|
||||
from app.core.rag.common.string_utils import clean_markdown_block
|
||||
|
||||
ocr = OCR()
|
||||
|
||||
# Gemini supported MIME types
|
||||
VIDEO_EXTS = [".mp4", ".mov", ".avi", ".flv", ".mpeg", ".mpg", ".webm", ".wmv", ".3gp", ".3gpp", ".mkv"]
|
||||
|
||||
|
||||
def chunk(filename, binary, lang, callback=None, vision_model=None, **kwargs):
|
||||
doc = {
|
||||
"docnm_kwd": filename,
|
||||
"title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", filename)),
|
||||
}
|
||||
eng = lang.lower() == "english"
|
||||
|
||||
if any(filename.lower().endswith(ext) for ext in VIDEO_EXTS):
|
||||
try:
|
||||
doc.update({"doc_type_kwd": "video"})
|
||||
ans = vision_model.chat(system="", history=[], gen_conf={}, video_bytes=binary, filename=filename)
|
||||
callback(0.8, "CV LLM respond: %s ..." % ans[:32])
|
||||
ans += "\n" + ans
|
||||
tokenize(doc, ans, eng)
|
||||
return [doc]
|
||||
except Exception as e:
|
||||
callback(prog=-1, msg=str(e))
|
||||
else:
|
||||
img = Image.open(io.BytesIO(binary)).convert("RGB")
|
||||
doc.update(
|
||||
{
|
||||
"image": img,
|
||||
"doc_type_kwd": "image",
|
||||
}
|
||||
)
|
||||
bxs = ocr(np.array(img))
|
||||
txt = "\n".join([t[0] for _, t in bxs if t[0]])
|
||||
callback(0.4, "Finish OCR: (%s ...)" % txt[:12])
|
||||
if (eng and len(txt.split()) > 32) or len(txt) > 32:
|
||||
tokenize(doc, txt, eng)
|
||||
callback(0.8, "OCR results is too long to use CV LLM.")
|
||||
return [doc]
|
||||
|
||||
try:
|
||||
callback(0.4, "Use CV LLM to describe the picture.")
|
||||
img_binary = io.BytesIO()
|
||||
img.save(img_binary, format="JPEG")
|
||||
img_binary.seek(0)
|
||||
ans = vision_model.describe(img_binary.read())
|
||||
callback(0.8, "CV LLM respond: %s ..." % ans[:32])
|
||||
txt += "\n" + ans
|
||||
tokenize(doc, txt, eng)
|
||||
return [doc]
|
||||
except Exception as e:
|
||||
callback(prog=-1, msg=str(e))
|
||||
|
||||
return []
|
||||
|
||||
|
||||
def vision_llm_chunk(binary, vision_model, prompt=None, callback=None):
|
||||
"""
|
||||
A simple wrapper to process image to markdown texts via VLM.
|
||||
|
||||
Returns:
|
||||
Simple markdown texts generated by VLM.
|
||||
"""
|
||||
callback = callback or (lambda prog, msg: None)
|
||||
|
||||
img = binary
|
||||
txt = ""
|
||||
|
||||
try:
|
||||
with io.BytesIO() as img_binary:
|
||||
try:
|
||||
img.save(img_binary, format="JPEG")
|
||||
except Exception:
|
||||
img_binary.seek(0)
|
||||
img_binary.truncate()
|
||||
img.save(img_binary, format="PNG")
|
||||
|
||||
img_binary.seek(0)
|
||||
description, token_count = vision_model.describe_with_prompt(img_binary.read(), prompt)
|
||||
ans = clean_markdown_block(description)
|
||||
txt += "\n" + ans
|
||||
return txt
|
||||
|
||||
except Exception as e:
|
||||
callback(-1, str(e))
|
||||
|
||||
return ""
|
||||
164
api/app/core/rag/app/presentation.py
Normal file
164
api/app/core/rag/app/presentation.py
Normal file
@@ -0,0 +1,164 @@
|
||||
import copy
|
||||
import re
|
||||
from io import BytesIO
|
||||
from PIL import Image
|
||||
|
||||
from app.core.rag.nlp import tokenize, is_english
|
||||
from app.core.rag.nlp import rag_tokenizer
|
||||
from app.core.rag.deepdoc.parser import PdfParser, PptParser, PlainParser
|
||||
from PyPDF2 import PdfReader as pdf2_read
|
||||
from app.core.rag.app.naive import by_plaintext, PARSERS
|
||||
|
||||
class Ppt(PptParser):
|
||||
def __call__(self, fnm, from_page, to_page, callback=None):
|
||||
txts = super().__call__(fnm, from_page, to_page)
|
||||
|
||||
callback(0.5, "Text extraction finished.")
|
||||
import aspose.slides as slides
|
||||
import aspose.pydrawing as drawing
|
||||
imgs = []
|
||||
with slides.Presentation(BytesIO(fnm)) as presentation:
|
||||
for i, slide in enumerate(presentation.slides[from_page: to_page]):
|
||||
try:
|
||||
with BytesIO() as buffered:
|
||||
slide.get_thumbnail(
|
||||
0.1, 0.1).save(
|
||||
buffered, drawing.imaging.ImageFormat.jpeg)
|
||||
buffered.seek(0)
|
||||
imgs.append(Image.open(buffered).copy())
|
||||
except RuntimeError as e:
|
||||
raise RuntimeError(f'ppt parse error at page {i+1}, original error: {str(e)}') from e
|
||||
assert len(imgs) == len(
|
||||
txts), "Slides text and image do not match: {} vs. {}".format(len(imgs), len(txts))
|
||||
callback(0.9, "Image extraction finished")
|
||||
self.is_english = is_english(txts)
|
||||
return [(txts[i], imgs[i]) for i in range(len(txts))]
|
||||
|
||||
class Pdf(PdfParser):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def __garbage(self, txt):
|
||||
txt = txt.lower().strip()
|
||||
if re.match(r"[0-9\.,%/-]+$", txt):
|
||||
return True
|
||||
if len(txt) < 3:
|
||||
return True
|
||||
return False
|
||||
|
||||
def __call__(self, filename, binary=None, from_page=0,
|
||||
to_page=100000, zoomin=3, callback=None):
|
||||
from timeit import default_timer as timer
|
||||
start = timer()
|
||||
callback(msg="OCR started")
|
||||
self.__images__(filename if not binary else binary,
|
||||
zoomin, from_page, to_page, callback)
|
||||
callback(msg="Page {}~{}: OCR finished ({:.2f}s)".format(from_page, min(to_page, self.total_page), timer() - start))
|
||||
assert len(self.boxes) == len(self.page_images), "{} vs. {}".format(
|
||||
len(self.boxes), len(self.page_images))
|
||||
res = []
|
||||
for i in range(len(self.boxes)):
|
||||
lines = "\n".join([b["text"] for b in self.boxes[i]
|
||||
if not self.__garbage(b["text"])])
|
||||
res.append((lines, self.page_images[i]))
|
||||
callback(0.9, "Page {}~{}: Parsing finished".format(
|
||||
from_page, min(to_page, self.total_page)))
|
||||
return res, []
|
||||
|
||||
|
||||
class PlainPdf(PlainParser):
|
||||
def __call__(self, filename, binary=None, from_page=0,
|
||||
to_page=100000, callback=None, **kwargs):
|
||||
self.pdf = pdf2_read(filename if not binary else BytesIO(binary))
|
||||
page_txt = []
|
||||
for page in self.pdf.pages[from_page: to_page]:
|
||||
page_txt.append(page.extract_text())
|
||||
callback(0.9, "Parsing finished")
|
||||
return [(txt, None) for txt in page_txt], []
|
||||
|
||||
|
||||
def chunk(filename, binary=None, from_page=0, to_page=100000,
|
||||
lang="Chinese", callback=None, vision_model=None, parser_config=None, **kwargs):
|
||||
"""
|
||||
The supported file formats are pdf, pptx.
|
||||
Every page will be treated as a chunk. And the thumbnail of every page will be stored.
|
||||
PPT file will be parsed by using this method automatically, setting-up for every PPT file is not necessary.
|
||||
"""
|
||||
if parser_config is None:
|
||||
parser_config = {}
|
||||
eng = lang.lower() == "english"
|
||||
doc = {
|
||||
"docnm_kwd": filename,
|
||||
"title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", filename))
|
||||
}
|
||||
doc["title_sm_tks"] = rag_tokenizer.fine_grained_tokenize(doc["title_tks"])
|
||||
res = []
|
||||
if re.search(r"\.pptx?$", filename, re.IGNORECASE):
|
||||
if not binary:
|
||||
with open(filename, "rb") as f:
|
||||
binary = f.read()
|
||||
ppt_parser = Ppt()
|
||||
for pn, (txt, img) in enumerate(ppt_parser(
|
||||
filename if not binary else binary, from_page, 1000000, callback)):
|
||||
d = copy.deepcopy(doc)
|
||||
pn += from_page
|
||||
d["image"] = img
|
||||
d["doc_type_kwd"] = "image"
|
||||
d["page_num_int"] = [pn + 1]
|
||||
d["top_int"] = [0]
|
||||
d["position_int"] = [(pn + 1, 0, img.size[0], 0, img.size[1])]
|
||||
tokenize(d, txt, eng)
|
||||
res.append(d)
|
||||
return res
|
||||
elif re.search(r"\.pdf$", filename, re.IGNORECASE):
|
||||
layout_recognizer = parser_config.get("layout_recognize", "DeepDOC")
|
||||
|
||||
if isinstance(layout_recognizer, bool):
|
||||
layout_recognizer = "DeepDOC" if layout_recognizer else "Plain Text"
|
||||
|
||||
name = layout_recognizer.strip().lower()
|
||||
parser = PARSERS.get(name, by_plaintext)
|
||||
callback(0.1, "Start to parse.")
|
||||
|
||||
sections, _, _ = parser(
|
||||
filename=filename,
|
||||
binary=binary,
|
||||
from_page=from_page,
|
||||
to_page=to_page,
|
||||
lang=lang,
|
||||
callback=callback,
|
||||
vision_model=vision_model,
|
||||
pdf_cls=Pdf,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
if not sections:
|
||||
return []
|
||||
|
||||
if name in ["tcadp", "docling", "mineru"]:
|
||||
parser_config["chunk_token_num"] = 0
|
||||
|
||||
callback(0.8, "Finish parsing.")
|
||||
|
||||
for pn, (txt, img) in enumerate(sections):
|
||||
d = copy.deepcopy(doc)
|
||||
pn += from_page
|
||||
if img:
|
||||
d["image"] = img
|
||||
d["page_num_int"] = [pn + 1]
|
||||
d["top_int"] = [0]
|
||||
d["position_int"] = [(pn + 1, 0, img.size[0] if img else 0, 0, img.size[1] if img else 0)]
|
||||
tokenize(d, txt, eng)
|
||||
res.append(d)
|
||||
return res
|
||||
|
||||
raise NotImplementedError(
|
||||
"file type not supported yet(pptx, pdf supported)")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
|
||||
def dummy(a, b):
|
||||
pass
|
||||
chunk(sys.argv[1], callback=dummy)
|
||||
455
api/app/core/rag/app/qa.py
Normal file
455
api/app/core/rag/app/qa.py
Normal file
@@ -0,0 +1,455 @@
|
||||
import logging
|
||||
import re
|
||||
import csv
|
||||
from copy import deepcopy
|
||||
from io import BytesIO
|
||||
from timeit import default_timer as timer
|
||||
from openpyxl import load_workbook
|
||||
|
||||
from app.core.rag.deepdoc.parser.utils import get_text
|
||||
from app.core.rag.nlp import is_english, random_choices, qbullets_category, add_positions, has_qbullet, docx_question_level
|
||||
from app.core.rag.nlp import rag_tokenizer, tokenize_table, concat_img
|
||||
from app.core.rag.deepdoc.parser import PdfParser, ExcelParser, DocxParser
|
||||
from docx import Document
|
||||
from PIL import Image
|
||||
from markdown import markdown
|
||||
|
||||
from app.core.rag.common.float_utils import get_float
|
||||
|
||||
|
||||
class Excel(ExcelParser):
|
||||
def __call__(self, fnm, binary=None, callback=None):
|
||||
if not binary:
|
||||
wb = load_workbook(fnm)
|
||||
else:
|
||||
wb = load_workbook(BytesIO(binary))
|
||||
total = 0
|
||||
for sheetname in wb.sheetnames:
|
||||
total += len(list(wb[sheetname].rows))
|
||||
|
||||
res, fails = [], []
|
||||
for sheetname in wb.sheetnames:
|
||||
ws = wb[sheetname]
|
||||
rows = list(ws.rows)
|
||||
for i, r in enumerate(rows):
|
||||
q, a = "", ""
|
||||
for cell in r:
|
||||
if not cell.value:
|
||||
continue
|
||||
if not q:
|
||||
q = str(cell.value)
|
||||
elif not a:
|
||||
a = str(cell.value)
|
||||
else:
|
||||
break
|
||||
if q and a:
|
||||
res.append((q, a))
|
||||
else:
|
||||
fails.append(str(i + 1))
|
||||
if len(res) % 999 == 0:
|
||||
callback(len(res) *
|
||||
0.6 /
|
||||
total, ("Extract pairs: {}".format(len(res)) +
|
||||
(f"{len(fails)} failure, line: %s..." %
|
||||
(",".join(fails[:3])) if fails else "")))
|
||||
|
||||
callback(0.6, ("Extract pairs: {}. ".format(len(res)) + (
|
||||
f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
|
||||
self.is_english = is_english(
|
||||
[rmPrefix(q) for q, _ in random_choices(res, k=30) if len(q) > 1])
|
||||
return res
|
||||
|
||||
|
||||
class Pdf(PdfParser):
|
||||
def __call__(self, filename, binary=None, from_page=0,
|
||||
to_page=100000, zoomin=3, callback=None):
|
||||
start = timer()
|
||||
callback(msg="OCR started")
|
||||
self.__images__(
|
||||
filename if not binary else binary,
|
||||
zoomin,
|
||||
from_page,
|
||||
to_page,
|
||||
callback
|
||||
)
|
||||
callback(msg="OCR finished ({:.2f}s)".format(timer() - start))
|
||||
logging.debug("OCR({}~{}): {:.2f}s".format(from_page, to_page, timer() - start))
|
||||
start = timer()
|
||||
self._layouts_rec(zoomin, drop=False)
|
||||
callback(0.63, "Layout analysis ({:.2f}s)".format(timer() - start))
|
||||
|
||||
start = timer()
|
||||
self._table_transformer_job(zoomin)
|
||||
callback(0.65, "Table analysis ({:.2f}s)".format(timer() - start))
|
||||
|
||||
start = timer()
|
||||
self._text_merge()
|
||||
callback(0.67, "Text merged ({:.2f}s)".format(timer() - start))
|
||||
tbls = self._extract_table_figure(True, zoomin, True, True)
|
||||
#self._naive_vertical_merge()
|
||||
# self._concat_downward()
|
||||
#self._filter_forpages()
|
||||
logging.debug("layouts: {}".format(timer() - start))
|
||||
sections = [b["text"] for b in self.boxes]
|
||||
bull_x0_list = []
|
||||
q_bull, reg = qbullets_category(sections)
|
||||
if q_bull == -1:
|
||||
raise ValueError("Unable to recognize Q&A structure.")
|
||||
qai_list = []
|
||||
last_q, last_a, last_tag = '', '', ''
|
||||
last_index = -1
|
||||
last_box = {'text':''}
|
||||
last_bull = None
|
||||
def sort_key(element):
|
||||
tbls_pn = element[1][0][0]
|
||||
tbls_top = element[1][0][3]
|
||||
return tbls_pn, tbls_top
|
||||
tbls.sort(key=sort_key)
|
||||
tbl_index = 0
|
||||
last_pn, last_bottom = 0, 0
|
||||
tbl_pn, tbl_left, tbl_right, tbl_top, tbl_bottom, tbl_tag, tbl_text = 1, 0, 0, 0, 0, '@@0\t0\t0\t0\t0##', ''
|
||||
for box in self.boxes:
|
||||
section, line_tag = box['text'], self._line_tag(box, zoomin)
|
||||
has_bull, index = has_qbullet(reg, box, last_box, last_index, last_bull, bull_x0_list)
|
||||
last_box, last_index, last_bull = box, index, has_bull
|
||||
line_pn = get_float(line_tag.lstrip('@@').split('\t')[0])
|
||||
line_top = get_float(line_tag.rstrip('##').split('\t')[3])
|
||||
tbl_pn, tbl_left, tbl_right, tbl_top, tbl_bottom, tbl_tag, tbl_text = self.get_tbls_info(tbls, tbl_index)
|
||||
if not has_bull: # No question bullet
|
||||
if not last_q:
|
||||
if tbl_pn < line_pn or (tbl_pn == line_pn and tbl_top <= line_top): # image passed
|
||||
tbl_index += 1
|
||||
continue
|
||||
else:
|
||||
sum_tag = line_tag
|
||||
sum_section = section
|
||||
while ((tbl_pn == last_pn and tbl_top>= last_bottom) or (tbl_pn > last_pn)) \
|
||||
and ((tbl_pn == line_pn and tbl_top <= line_top) or (tbl_pn < line_pn)): # add image at the middle of current answer
|
||||
sum_tag = f'{tbl_tag}{sum_tag}'
|
||||
sum_section = f'{tbl_text}{sum_section}'
|
||||
tbl_index += 1
|
||||
tbl_pn, tbl_left, tbl_right, tbl_top, tbl_bottom, tbl_tag, tbl_text = self.get_tbls_info(tbls, tbl_index)
|
||||
last_a = f'{last_a}{sum_section}'
|
||||
last_tag = f'{last_tag}{sum_tag}'
|
||||
else:
|
||||
if last_q:
|
||||
while ((tbl_pn == last_pn and tbl_top>= last_bottom) or (tbl_pn > last_pn)) \
|
||||
and ((tbl_pn == line_pn and tbl_top <= line_top) or (tbl_pn < line_pn)): # add image at the end of last answer
|
||||
last_tag = f'{last_tag}{tbl_tag}'
|
||||
last_a = f'{last_a}{tbl_text}'
|
||||
tbl_index += 1
|
||||
tbl_pn, tbl_left, tbl_right, tbl_top, tbl_bottom, tbl_tag, tbl_text = self.get_tbls_info(tbls, tbl_index)
|
||||
image, poss = self.crop(last_tag, need_position=True)
|
||||
qai_list.append((last_q, last_a, image, poss))
|
||||
last_q, last_a, last_tag = '', '', ''
|
||||
last_q = has_bull.group()
|
||||
_, end = has_bull.span()
|
||||
last_a = section[end:]
|
||||
last_tag = line_tag
|
||||
last_bottom = float(line_tag.rstrip('##').split('\t')[4])
|
||||
last_pn = line_pn
|
||||
if last_q:
|
||||
qai_list.append((last_q, last_a, *self.crop(last_tag, need_position=True)))
|
||||
return qai_list, tbls
|
||||
|
||||
def get_tbls_info(self, tbls, tbl_index):
|
||||
if tbl_index >= len(tbls):
|
||||
return 1, 0, 0, 0, 0, '@@0\t0\t0\t0\t0##', ''
|
||||
tbl_pn = tbls[tbl_index][1][0][0]+1
|
||||
tbl_left = tbls[tbl_index][1][0][1]
|
||||
tbl_right = tbls[tbl_index][1][0][2]
|
||||
tbl_top = tbls[tbl_index][1][0][3]
|
||||
tbl_bottom = tbls[tbl_index][1][0][4]
|
||||
tbl_tag = "@@{}\t{:.1f}\t{:.1f}\t{:.1f}\t{:.1f}##" \
|
||||
.format(tbl_pn, tbl_left, tbl_right, tbl_top, tbl_bottom)
|
||||
_tbl_text = ''.join(tbls[tbl_index][0][1])
|
||||
return tbl_pn, tbl_left, tbl_right, tbl_top, tbl_bottom, tbl_tag, _tbl_text
|
||||
|
||||
|
||||
class Docx(DocxParser):
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def get_picture(self, document, paragraph):
|
||||
img = paragraph._element.xpath('.//pic:pic')
|
||||
if not img:
|
||||
return None
|
||||
img = img[0]
|
||||
embed = img.xpath('.//a:blip/@r:embed')[0]
|
||||
related_part = document.part.related_parts[embed]
|
||||
image = related_part.image
|
||||
image = Image.open(BytesIO(image.blob)).convert('RGB')
|
||||
return image
|
||||
|
||||
def __call__(self, filename, binary=None, from_page=0, to_page=100000, callback=None):
|
||||
self.doc = Document(
|
||||
filename) if not binary else Document(BytesIO(binary))
|
||||
pn = 0
|
||||
last_answer, last_image = "", None
|
||||
question_stack, level_stack = [], []
|
||||
qai_list = []
|
||||
for p in self.doc.paragraphs:
|
||||
if pn > to_page:
|
||||
break
|
||||
question_level, p_text = 0, ''
|
||||
if from_page <= pn < to_page and p.text.strip():
|
||||
question_level, p_text = docx_question_level(p)
|
||||
if not question_level or question_level > 6: # not a question
|
||||
last_answer = f'{last_answer}\n{p_text}'
|
||||
current_image = self.get_picture(self.doc, p)
|
||||
last_image = concat_img(last_image, current_image)
|
||||
else: # is a question
|
||||
if last_answer or last_image:
|
||||
sum_question = '\n'.join(question_stack)
|
||||
if sum_question:
|
||||
qai_list.append((sum_question, last_answer, last_image))
|
||||
last_answer, last_image = '', None
|
||||
|
||||
i = question_level
|
||||
while question_stack and i <= level_stack[-1]:
|
||||
question_stack.pop()
|
||||
level_stack.pop()
|
||||
question_stack.append(p_text)
|
||||
level_stack.append(question_level)
|
||||
for run in p.runs:
|
||||
if 'lastRenderedPageBreak' in run._element.xml:
|
||||
pn += 1
|
||||
continue
|
||||
if 'w:br' in run._element.xml and 'type="page"' in run._element.xml:
|
||||
pn += 1
|
||||
if last_answer:
|
||||
sum_question = '\n'.join(question_stack)
|
||||
if sum_question:
|
||||
qai_list.append((sum_question, last_answer, last_image))
|
||||
|
||||
tbls = []
|
||||
for tb in self.doc.tables:
|
||||
html= "<table>"
|
||||
for r in tb.rows:
|
||||
html += "<tr>"
|
||||
i = 0
|
||||
while i < len(r.cells):
|
||||
span = 1
|
||||
c = r.cells[i]
|
||||
for j in range(i+1, len(r.cells)):
|
||||
if c.text == r.cells[j].text:
|
||||
span += 1
|
||||
i = j
|
||||
i += 1
|
||||
html += f"<td>{c.text}</td>" if span == 1 else f"<td colspan='{span}'>{c.text}</td>"
|
||||
html += "</tr>"
|
||||
html += "</table>"
|
||||
tbls.append(((None, html), ""))
|
||||
return qai_list, tbls
|
||||
|
||||
|
||||
def rmPrefix(txt):
|
||||
return re.sub(
|
||||
r"^(问题|答案|回答|user|assistant|Q|A|Question|Answer|问|答)[\t:: ]+", "", txt.strip(), flags=re.IGNORECASE)
|
||||
|
||||
|
||||
def beAdocPdf(d, q, a, eng, image, poss):
|
||||
qprefix = "Question: " if eng else "问题:"
|
||||
aprefix = "Answer: " if eng else "回答:"
|
||||
d["content_with_weight"] = "\t".join(
|
||||
[qprefix + rmPrefix(q), aprefix + rmPrefix(a)])
|
||||
d["content_ltks"] = rag_tokenizer.tokenize(q)
|
||||
d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
|
||||
if image:
|
||||
d["image"] = image
|
||||
d["doc_type_kwd"] = "image"
|
||||
add_positions(d, poss)
|
||||
return d
|
||||
|
||||
|
||||
def beAdocDocx(d, q, a, eng, image, row_num=-1):
|
||||
qprefix = "Question: " if eng else "问题:"
|
||||
aprefix = "Answer: " if eng else "回答:"
|
||||
d["content_with_weight"] = "\t".join(
|
||||
[qprefix + rmPrefix(q), aprefix + rmPrefix(a)])
|
||||
d["content_ltks"] = rag_tokenizer.tokenize(q)
|
||||
d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
|
||||
if image:
|
||||
d["image"] = image
|
||||
d["doc_type_kwd"] = "image"
|
||||
if row_num >= 0:
|
||||
d["top_int"] = [row_num]
|
||||
return d
|
||||
|
||||
|
||||
def beAdoc(d, q, a, eng, row_num=-1):
|
||||
qprefix = "Question: " if eng else "问题:"
|
||||
aprefix = "Answer: " if eng else "回答:"
|
||||
d["content_with_weight"] = "\t".join(
|
||||
[qprefix + rmPrefix(q), aprefix + rmPrefix(a)])
|
||||
d["content_ltks"] = rag_tokenizer.tokenize(q)
|
||||
d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
|
||||
if row_num >= 0:
|
||||
d["top_int"] = [row_num]
|
||||
return d
|
||||
|
||||
|
||||
def mdQuestionLevel(s):
|
||||
match = re.match(r'#*', s)
|
||||
return (len(match.group(0)), s.lstrip('#').lstrip()) if match else (0, s)
|
||||
|
||||
|
||||
def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, **kwargs):
|
||||
"""
|
||||
Excel and csv(txt) format files are supported.
|
||||
If the file is in excel format, there should be 2 column question and answer without header.
|
||||
And question column is ahead of answer column.
|
||||
And it's O.K if it has multiple sheets as long as the columns are rightly composed.
|
||||
|
||||
If it's in csv format, it should be UTF-8 encoded. Use TAB as delimiter to separate question and answer.
|
||||
|
||||
All the deformed lines will be ignored.
|
||||
Every pair of Q&A will be treated as a chunk.
|
||||
"""
|
||||
eng = lang.lower() == "english"
|
||||
res = []
|
||||
doc = {
|
||||
"docnm_kwd": filename,
|
||||
"title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", filename))
|
||||
}
|
||||
if re.search(r"\.xlsx?$", filename, re.IGNORECASE):
|
||||
callback(0.1, "Start to parse.")
|
||||
excel_parser = Excel()
|
||||
for ii, (q, a) in enumerate(excel_parser(filename, binary, callback)):
|
||||
res.append(beAdoc(deepcopy(doc), q, a, eng, ii))
|
||||
return res
|
||||
|
||||
elif re.search(r"\.(txt)$", filename, re.IGNORECASE):
|
||||
callback(0.1, "Start to parse.")
|
||||
txt = get_text(filename, binary)
|
||||
lines = txt.split("\n")
|
||||
comma, tab = 0, 0
|
||||
for line in lines:
|
||||
if len(line.split(",")) == 2:
|
||||
comma += 1
|
||||
if len(line.split("\t")) == 2:
|
||||
tab += 1
|
||||
delimiter = "\t" if tab >= comma else ","
|
||||
|
||||
fails = []
|
||||
question, answer = "", ""
|
||||
i = 0
|
||||
while i < len(lines):
|
||||
arr = lines[i].split(delimiter)
|
||||
if len(arr) != 2:
|
||||
if question:
|
||||
answer += "\n" + lines[i]
|
||||
else:
|
||||
fails.append(str(i+1))
|
||||
elif len(arr) == 2:
|
||||
if question and answer:
|
||||
res.append(beAdoc(deepcopy(doc), question, answer, eng, i))
|
||||
question, answer = arr
|
||||
i += 1
|
||||
if len(res) % 999 == 0:
|
||||
callback(len(res) * 0.6 / len(lines), ("Extract Q&A: {}".format(len(res)) + (
|
||||
f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
|
||||
|
||||
if question:
|
||||
res.append(beAdoc(deepcopy(doc), question, answer, eng, len(lines)))
|
||||
|
||||
callback(0.6, ("Extract Q&A: {}".format(len(res)) + (
|
||||
f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
|
||||
|
||||
return res
|
||||
|
||||
elif re.search(r"\.(csv)$", filename, re.IGNORECASE):
|
||||
callback(0.1, "Start to parse.")
|
||||
txt = get_text(filename, binary)
|
||||
lines = txt.split("\n")
|
||||
delimiter = "\t" if any("\t" in line for line in lines) else ","
|
||||
|
||||
fails = []
|
||||
question, answer = "", ""
|
||||
res = []
|
||||
reader = csv.reader(lines, delimiter=delimiter)
|
||||
|
||||
for i, row in enumerate(reader):
|
||||
if len(row) != 2:
|
||||
if question:
|
||||
answer += "\n" + lines[i]
|
||||
else:
|
||||
fails.append(str(i + 1))
|
||||
elif len(row) == 2:
|
||||
if question and answer:
|
||||
res.append(beAdoc(deepcopy(doc), question, answer, eng, i))
|
||||
question, answer = row
|
||||
if len(res) % 999 == 0:
|
||||
callback(len(res) * 0.6 / len(lines), ("Extract Q&A: {}".format(len(res)) + (
|
||||
f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
|
||||
|
||||
if question:
|
||||
res.append(beAdoc(deepcopy(doc), question, answer, eng, len(list(reader))))
|
||||
|
||||
callback(0.6, ("Extract Q&A: {}".format(len(res)) + (
|
||||
f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
|
||||
return res
|
||||
|
||||
elif re.search(r"\.pdf$", filename, re.IGNORECASE):
|
||||
callback(0.1, "Start to parse.")
|
||||
pdf_parser = Pdf()
|
||||
qai_list, tbls = pdf_parser(filename if not binary else binary,
|
||||
from_page=from_page, to_page=to_page, callback=callback)
|
||||
for q, a, image, poss in qai_list:
|
||||
res.append(beAdocPdf(deepcopy(doc), q, a, eng, image, poss))
|
||||
return res
|
||||
|
||||
elif re.search(r"\.(md|markdown)$", filename, re.IGNORECASE):
|
||||
callback(0.1, "Start to parse.")
|
||||
txt = get_text(filename, binary)
|
||||
lines = txt.split("\n")
|
||||
_last_question, last_answer = "", ""
|
||||
question_stack, level_stack = [], []
|
||||
code_block = False
|
||||
for index, line in enumerate(lines):
|
||||
if line.strip().startswith('```'):
|
||||
code_block = not code_block
|
||||
question_level, question = 0, ''
|
||||
if not code_block:
|
||||
question_level, question = mdQuestionLevel(line)
|
||||
|
||||
if not question_level or question_level > 6: # not a question
|
||||
last_answer = f'{last_answer}\n{line}'
|
||||
else: # is a question
|
||||
if last_answer.strip():
|
||||
sum_question = '\n'.join(question_stack)
|
||||
if sum_question:
|
||||
res.append(beAdoc(deepcopy(doc), sum_question, markdown(last_answer, extensions=['markdown.extensions.tables']), eng, index))
|
||||
last_answer = ''
|
||||
|
||||
i = question_level
|
||||
while question_stack and i <= level_stack[-1]:
|
||||
question_stack.pop()
|
||||
level_stack.pop()
|
||||
question_stack.append(question)
|
||||
level_stack.append(question_level)
|
||||
if last_answer.strip():
|
||||
sum_question = '\n'.join(question_stack)
|
||||
if sum_question:
|
||||
res.append(beAdoc(deepcopy(doc), sum_question, markdown(last_answer, extensions=['markdown.extensions.tables']), eng, index))
|
||||
return res
|
||||
|
||||
elif re.search(r"\.docx$", filename, re.IGNORECASE):
|
||||
docx_parser = Docx()
|
||||
qai_list, tbls = docx_parser(filename, binary,
|
||||
from_page=0, to_page=10000, callback=callback)
|
||||
res = tokenize_table(tbls, doc, eng)
|
||||
for i, (q, a, image) in enumerate(qai_list):
|
||||
res.append(beAdocDocx(deepcopy(doc), q, a, eng, image, i))
|
||||
return res
|
||||
|
||||
raise NotImplementedError(
|
||||
"Excel, csv(txt), pdf, markdown and docx format files are supported.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
|
||||
def dummy(prog=None, msg=""):
|
||||
pass
|
||||
chunk(sys.argv[1], from_page=0, to_page=10, callback=dummy)
|
||||
0
api/app/core/rag/common/__init__.py
Normal file
0
api/app/core/rag/common/__init__.py
Normal file
106
api/app/core/rag/common/connection_utils.py
Normal file
106
api/app/core/rag/common/connection_utils.py
Normal file
@@ -0,0 +1,106 @@
|
||||
import os
|
||||
import queue
|
||||
import threading
|
||||
from typing import Any, Callable, Coroutine, Optional, Type, Union
|
||||
import asyncio
|
||||
import trio
|
||||
from functools import wraps
|
||||
from flask import make_response, jsonify
|
||||
from .constants import RetCode
|
||||
|
||||
TimeoutException = Union[Type[BaseException], BaseException]
|
||||
OnTimeoutCallback = Union[Callable[..., Any], Coroutine[Any, Any, Any]]
|
||||
|
||||
|
||||
def timeout(seconds: float | int | str = None, attempts: int = 2, *, exception: Optional[TimeoutException] = None,
|
||||
on_timeout: Optional[OnTimeoutCallback] = None):
|
||||
if isinstance(seconds, str):
|
||||
seconds = float(seconds)
|
||||
|
||||
def decorator(func):
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
result_queue = queue.Queue(maxsize=1)
|
||||
|
||||
def target():
|
||||
try:
|
||||
result = func(*args, **kwargs)
|
||||
result_queue.put(result)
|
||||
except Exception as e:
|
||||
result_queue.put(e)
|
||||
|
||||
thread = threading.Thread(target=target)
|
||||
thread.daemon = True
|
||||
thread.start()
|
||||
|
||||
for a in range(attempts):
|
||||
try:
|
||||
if os.environ.get("ENABLE_TIMEOUT_ASSERTION"):
|
||||
result = result_queue.get(timeout=seconds)
|
||||
else:
|
||||
result = result_queue.get()
|
||||
if isinstance(result, Exception):
|
||||
raise result
|
||||
return result
|
||||
except queue.Empty:
|
||||
pass
|
||||
raise TimeoutError(f"Function '{func.__name__}' timed out after {seconds} seconds and {attempts} attempts.")
|
||||
|
||||
@wraps(func)
|
||||
async def async_wrapper(*args, **kwargs) -> Any:
|
||||
if seconds is None:
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
for a in range(attempts):
|
||||
try:
|
||||
if os.environ.get("ENABLE_TIMEOUT_ASSERTION"):
|
||||
with trio.fail_after(seconds):
|
||||
return await func(*args, **kwargs)
|
||||
else:
|
||||
return await func(*args, **kwargs)
|
||||
except trio.TooSlowError:
|
||||
if a < attempts - 1:
|
||||
continue
|
||||
if on_timeout is not None:
|
||||
if callable(on_timeout):
|
||||
result = on_timeout()
|
||||
if isinstance(result, Coroutine):
|
||||
return await result
|
||||
return result
|
||||
return on_timeout
|
||||
|
||||
if exception is None:
|
||||
raise TimeoutError(f"Operation timed out after {seconds} seconds and {attempts} attempts.")
|
||||
|
||||
if isinstance(exception, BaseException):
|
||||
raise exception
|
||||
|
||||
if isinstance(exception, type) and issubclass(exception, BaseException):
|
||||
raise exception(f"Operation timed out after {seconds} seconds and {attempts} attempts.")
|
||||
|
||||
raise RuntimeError("Invalid exception type provided")
|
||||
|
||||
if asyncio.iscoroutinefunction(func):
|
||||
return async_wrapper
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def construct_response(code=RetCode.SUCCESS, message="success", data=None, auth=None):
|
||||
result_dict = {"code": code, "message": message, "data": data}
|
||||
response_dict = {}
|
||||
for key, value in result_dict.items():
|
||||
if value is None and key != "code":
|
||||
continue
|
||||
else:
|
||||
response_dict[key] = value
|
||||
response = make_response(jsonify(response_dict))
|
||||
if auth:
|
||||
response.headers["Authorization"] = auth
|
||||
response.headers["Access-Control-Allow-Origin"] = "*"
|
||||
response.headers["Access-Control-Allow-Method"] = "*"
|
||||
response.headers["Access-Control-Allow-Headers"] = "*"
|
||||
response.headers["Access-Control-Allow-Headers"] = "*"
|
||||
response.headers["Access-Control-Expose-Headers"] = "Authorization"
|
||||
return response
|
||||
180
api/app/core/rag/common/constants.py
Normal file
180
api/app/core/rag/common/constants.py
Normal file
@@ -0,0 +1,180 @@
|
||||
from enum import Enum, IntEnum
|
||||
from strenum import StrEnum
|
||||
|
||||
SERVICE_CONF = "service_conf.yaml"
|
||||
RAG_SERVICE_NAME = "rag"
|
||||
|
||||
class CustomEnum(Enum):
|
||||
@classmethod
|
||||
def valid(cls, value):
|
||||
try:
|
||||
cls(value)
|
||||
return True
|
||||
except BaseException:
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def values(cls):
|
||||
return [member.value for member in cls.__members__.values()]
|
||||
|
||||
@classmethod
|
||||
def names(cls):
|
||||
return [member.name for member in cls.__members__.values()]
|
||||
|
||||
|
||||
class RetCode(IntEnum, CustomEnum):
|
||||
SUCCESS = 0
|
||||
NOT_EFFECTIVE = 10
|
||||
EXCEPTION_ERROR = 100
|
||||
ARGUMENT_ERROR = 101
|
||||
DATA_ERROR = 102
|
||||
OPERATING_ERROR = 103
|
||||
CONNECTION_ERROR = 105
|
||||
RUNNING = 106
|
||||
PERMISSION_ERROR = 108
|
||||
AUTHENTICATION_ERROR = 109
|
||||
UNAUTHORIZED = 401
|
||||
SERVER_ERROR = 500
|
||||
FORBIDDEN = 403
|
||||
NOT_FOUND = 404
|
||||
|
||||
|
||||
class StatusEnum(Enum):
|
||||
VALID = "1"
|
||||
INVALID = "0"
|
||||
|
||||
|
||||
class ActiveEnum(Enum):
|
||||
ACTIVE = "1"
|
||||
INACTIVE = "0"
|
||||
|
||||
|
||||
class LLMType(StrEnum):
|
||||
CHAT = 'chat'
|
||||
EMBEDDING = 'embedding'
|
||||
SPEECH2TEXT = 'speech2text'
|
||||
IMAGE2TEXT = 'image2text'
|
||||
RERANK = 'rerank'
|
||||
TTS = 'tts'
|
||||
|
||||
|
||||
class TaskStatus(StrEnum):
|
||||
UNSTART = "0"
|
||||
RUNNING = "1"
|
||||
CANCEL = "2"
|
||||
DONE = "3"
|
||||
FAIL = "4"
|
||||
SCHEDULE = "5"
|
||||
|
||||
|
||||
VALID_TASK_STATUS = {TaskStatus.UNSTART, TaskStatus.RUNNING, TaskStatus.CANCEL, TaskStatus.DONE, TaskStatus.FAIL,
|
||||
TaskStatus.SCHEDULE}
|
||||
|
||||
|
||||
class ParserType(StrEnum):
|
||||
PRESENTATION = "presentation"
|
||||
LAWS = "laws"
|
||||
MANUAL = "manual"
|
||||
PAPER = "paper"
|
||||
RESUME = "resume"
|
||||
BOOK = "book"
|
||||
QA = "qa"
|
||||
TABLE = "table"
|
||||
NAIVE = "naive"
|
||||
PICTURE = "picture"
|
||||
ONE = "one"
|
||||
AUDIO = "audio"
|
||||
EMAIL = "email"
|
||||
KG = "knowledge_graph"
|
||||
TAG = "tag"
|
||||
|
||||
|
||||
class FileSource(StrEnum):
|
||||
LOCAL = ""
|
||||
KNOWLEDGEBASE = "knowledgebase"
|
||||
S3 = "s3"
|
||||
NOTION = "notion"
|
||||
DISCORD = "discord"
|
||||
CONFLUENCE = "confluence"
|
||||
GMAIL = "gmail"
|
||||
GOOGLE_DRIVE = "google_drive"
|
||||
JIRA = "jira"
|
||||
SHAREPOINT = "sharepoint"
|
||||
SLACK = "slack"
|
||||
TEAMS = "teams"
|
||||
|
||||
|
||||
class PipelineTaskType(StrEnum):
|
||||
PARSE = "Parse"
|
||||
DOWNLOAD = "Download"
|
||||
RAPTOR = "RAPTOR"
|
||||
GRAPH_RAG = "GraphRAG"
|
||||
MINDMAP = "Mindmap"
|
||||
|
||||
|
||||
VALID_PIPELINE_TASK_TYPES = {PipelineTaskType.PARSE, PipelineTaskType.DOWNLOAD, PipelineTaskType.RAPTOR,
|
||||
PipelineTaskType.GRAPH_RAG, PipelineTaskType.MINDMAP}
|
||||
|
||||
class MCPServerType(StrEnum):
|
||||
SSE = "sse"
|
||||
STREAMABLE_HTTP = "streamable-http"
|
||||
|
||||
VALID_MCP_SERVER_TYPES = {MCPServerType.SSE, MCPServerType.STREAMABLE_HTTP}
|
||||
|
||||
class Storage(Enum):
|
||||
MINIO = 1
|
||||
AZURE_SPN = 2
|
||||
AZURE_SAS = 3
|
||||
AWS_S3 = 4
|
||||
OSS = 5
|
||||
OPENDAL = 6
|
||||
|
||||
# environment
|
||||
# ENV_STRONG_TEST_COUNT = "STRONG_TEST_COUNT"
|
||||
# ENV_RAG_SECRET_KEY = "RAG_SECRET_KEY"
|
||||
# ENV_REGISTER_ENABLED = "REGISTER_ENABLED"
|
||||
# ENV_DOC_ENGINE = "DOC_ENGINE"
|
||||
# ENV_SANDBOX_ENABLED = "SANDBOX_ENABLED"
|
||||
# ENV_SANDBOX_HOST = "SANDBOX_HOST"
|
||||
# ENV_MAX_CONTENT_LENGTH = "MAX_CONTENT_LENGTH"
|
||||
# ENV_COMPONENT_EXEC_TIMEOUT = "COMPONENT_EXEC_TIMEOUT"
|
||||
# ENV_TRINO_USE_TLS = "TRINO_USE_TLS"
|
||||
# ENV_MAX_FILE_NUM_PER_USER = "MAX_FILE_NUM_PER_USER"
|
||||
# ENV_MACOS = "MACOS"
|
||||
# ENV_RAG_DEBUGPY_LISTEN = "RAG_DEBUGPY_LISTEN"
|
||||
# ENV_WERKZEUG_RUN_MAIN = "WERKZEUG_RUN_MAIN"
|
||||
# ENV_DISABLE_SDK = "DISABLE_SDK"
|
||||
# ENV_ENABLE_TIMEOUT_ASSERTION = "ENABLE_TIMEOUT_ASSERTION"
|
||||
# ENV_LOG_LEVELS = "LOG_LEVELS"
|
||||
# ENV_TENSORRT_DLA_SVR = "TENSORRT_DLA_SVR"
|
||||
# ENV_OCR_GPU_MEM_LIMIT_MB = "OCR_GPU_MEM_LIMIT_MB"
|
||||
# ENV_OCR_ARENA_EXTEND_STRATEGY = "OCR_ARENA_EXTEND_STRATEGY"
|
||||
# ENV_MAX_CONCURRENT_PROCESS_AND_EXTRACT_CHUNK = "MAX_CONCURRENT_PROCESS_AND_EXTRACT_CHUNK"
|
||||
# ENV_MAX_MAX_CONCURRENT_CHATS = "MAX_CONCURRENT_CHATS"
|
||||
# ENV_RAG_MCP_BASE_URL = "RAG_MCP_BASE_URL"
|
||||
# ENV_RAG_MCP_HOST = "RAG_MCP_HOST"
|
||||
# ENV_RAG_MCP_PORT = "RAG_MCP_PORT"
|
||||
# ENV_RAG_MCP_LAUNCH_MODE = "RAG_MCP_LAUNCH_MODE"
|
||||
# ENV_RAG_MCP_HOST_API_KEY = "RAG_MCP_HOST_API_KEY"
|
||||
# ENV_MINERU_EXECUTABLE = "MINERU_EXECUTABLE"
|
||||
# ENV_MINERU_APISERVER = "MINERU_APISERVER"
|
||||
# ENV_MINERU_OUTPUT_DIR = "MINERU_OUTPUT_DIR"
|
||||
# ENV_MINERU_BACKEND = "MINERU_BACKEND"
|
||||
# ENV_MINERU_DELETE_OUTPUT = "MINERU_DELETE_OUTPUT"
|
||||
# ENV_TCADP_OUTPUT_DIR = "TCADP_OUTPUT_DIR"
|
||||
# ENV_LM_TIMEOUT_SECONDS = "LM_TIMEOUT_SECONDS"
|
||||
# ENV_LLM_MAX_RETRIES = "LLM_MAX_RETRIES"
|
||||
# ENV_LLM_BASE_DELAY = "LLM_BASE_DELAY"
|
||||
# ENV_OLLAMA_KEEP_ALIVE = "OLLAMA_KEEP_ALIVE"
|
||||
# ENV_DOC_BULK_SIZE = "DOC_BULK_SIZE"
|
||||
# ENV_EMBEDDING_BATCH_SIZE = "EMBEDDING_BATCH_SIZE"
|
||||
# ENV_MAX_CONCURRENT_TASKS = "MAX_CONCURRENT_TASKS"
|
||||
# ENV_MAX_CONCURRENT_CHUNK_BUILDERS = "MAX_CONCURRENT_CHUNK_BUILDERS"
|
||||
# ENV_MAX_CONCURRENT_MINIO = "MAX_CONCURRENT_MINIO"
|
||||
# ENV_WORKER_HEARTBEAT_TIMEOUT = "WORKER_HEARTBEAT_TIMEOUT"
|
||||
# ENV_TRACE_MALLOC_ENABLED = "TRACE_MALLOC_ENABLED"
|
||||
|
||||
PAGERANK_FLD = "pagerank_fea"
|
||||
SVR_QUEUE_NAME = "rag_svr_queue"
|
||||
SVR_CONSUMER_GROUP_NAME = "rag_svr_task_broker"
|
||||
TAG_FLD = "tag_feas"
|
||||
28
api/app/core/rag/common/file_utils.py
Normal file
28
api/app/core/rag/common/file_utils.py
Normal file
@@ -0,0 +1,28 @@
|
||||
import os
|
||||
|
||||
PROJECT_BASE = os.getenv("RAG_PROJECT_BASE") or os.getenv("RAG_DEPLOY_BASE")
|
||||
|
||||
|
||||
def get_project_base_directory(*args):
|
||||
global PROJECT_BASE
|
||||
if PROJECT_BASE is None:
|
||||
PROJECT_BASE = os.path.abspath(
|
||||
os.path.join(
|
||||
os.path.dirname(os.path.realpath(__file__)),
|
||||
os.pardir,
|
||||
os.pardir,
|
||||
os.pardir,
|
||||
os.pardir,
|
||||
)
|
||||
)
|
||||
|
||||
if args:
|
||||
return os.path.join(PROJECT_BASE, *args)
|
||||
return PROJECT_BASE
|
||||
|
||||
|
||||
def traversal_files(base):
|
||||
for root, ds, fs in os.walk(base):
|
||||
for f in fs:
|
||||
fullname = os.path.join(root, f)
|
||||
yield fullname
|
||||
30
api/app/core/rag/common/float_utils.py
Normal file
30
api/app/core/rag/common/float_utils.py
Normal file
@@ -0,0 +1,30 @@
|
||||
def get_float(v):
|
||||
"""
|
||||
Convert a value to float, handling None and exceptions gracefully.
|
||||
|
||||
Attempts to convert the input value to a float. If the value is None or
|
||||
cannot be converted to float, returns negative infinity as a default value.
|
||||
|
||||
Args:
|
||||
v: The value to convert to float. Can be any type that float() accepts,
|
||||
or None.
|
||||
|
||||
Returns:
|
||||
float: The converted float value if successful, otherwise float('-inf').
|
||||
|
||||
Examples:
|
||||
>>> get_float("3.14")
|
||||
3.14
|
||||
>>> get_float(None)
|
||||
-inf
|
||||
>>> get_float("invalid")
|
||||
-inf
|
||||
>>> get_float(42)
|
||||
42.0
|
||||
"""
|
||||
if v is None:
|
||||
return float('-inf')
|
||||
try:
|
||||
return float(v)
|
||||
except Exception:
|
||||
return float('-inf')
|
||||
92
api/app/core/rag/common/misc_utils.py
Normal file
92
api/app/core/rag/common/misc_utils.py
Normal file
@@ -0,0 +1,92 @@
|
||||
import base64
|
||||
import hashlib
|
||||
import uuid
|
||||
import requests
|
||||
import threading
|
||||
import subprocess
|
||||
import sys
|
||||
import os
|
||||
import logging
|
||||
|
||||
def get_uuid():
|
||||
return uuid.uuid1().hex
|
||||
|
||||
|
||||
def download_img(url):
|
||||
if not url:
|
||||
return ""
|
||||
response = requests.get(url)
|
||||
return "data:" + \
|
||||
response.headers.get('Content-Type', 'image/jpg') + ";" + \
|
||||
"base64," + base64.b64encode(response.content).decode("utf-8")
|
||||
|
||||
|
||||
def hash_str2int(line: str, mod: int = 10 ** 8) -> int:
|
||||
return int(hashlib.sha1(line.encode("utf-8")).hexdigest(), 16) % mod
|
||||
|
||||
def convert_bytes(size_in_bytes: int) -> str:
|
||||
"""
|
||||
Format size in bytes.
|
||||
"""
|
||||
if size_in_bytes == 0:
|
||||
return "0 B"
|
||||
|
||||
units = ['B', 'KB', 'MB', 'GB', 'TB', 'PB']
|
||||
i = 0
|
||||
size = float(size_in_bytes)
|
||||
|
||||
while size >= 1024 and i < len(units) - 1:
|
||||
size /= 1024
|
||||
i += 1
|
||||
|
||||
if i == 0 or size >= 100:
|
||||
return f"{size:.0f} {units[i]}"
|
||||
elif size >= 10:
|
||||
return f"{size:.1f} {units[i]}"
|
||||
else:
|
||||
return f"{size:.2f} {units[i]}"
|
||||
|
||||
|
||||
def once(func):
|
||||
"""
|
||||
A thread-safe decorator that ensures the decorated function runs exactly once,
|
||||
caching and returning its result for all subsequent calls. This prevents
|
||||
race conditions in multi-thread environments by using a lock to protect
|
||||
the execution state.
|
||||
|
||||
Args:
|
||||
func (callable): The function to be executed only once.
|
||||
|
||||
Returns:
|
||||
callable: A wrapper function that executes `func` on the first call
|
||||
and returns the cached result thereafter.
|
||||
|
||||
Example:
|
||||
@once
|
||||
def compute_expensive_value():
|
||||
print("Computing...")
|
||||
return 42
|
||||
|
||||
# First call: executes and prints
|
||||
# Subsequent calls: return 42 without executing
|
||||
"""
|
||||
executed = False
|
||||
result = None
|
||||
lock = threading.Lock()
|
||||
def wrapper(*args, **kwargs):
|
||||
nonlocal executed, result
|
||||
with lock:
|
||||
if not executed:
|
||||
executed = True
|
||||
result = func(*args, **kwargs)
|
||||
return result
|
||||
return wrapper
|
||||
|
||||
@once
|
||||
def pip_install_torch():
|
||||
device = os.getenv("DEVICE", "cpu")
|
||||
if device=="cpu":
|
||||
return
|
||||
logging.info("Installing pytorch")
|
||||
pkg_names = ["torch>=2.5.0,<3.0.0"]
|
||||
subprocess.check_call([sys.executable, "-m", "pip", "install", *pkg_names])
|
||||
2
api/app/core/rag/common/settings.py
Normal file
2
api/app/core/rag/common/settings.py
Normal file
@@ -0,0 +1,2 @@
|
||||
PARALLEL_DEVICES: int = 0
|
||||
|
||||
57
api/app/core/rag/common/string_utils.py
Normal file
57
api/app/core/rag/common/string_utils.py
Normal file
@@ -0,0 +1,57 @@
|
||||
import re
|
||||
|
||||
|
||||
def remove_redundant_spaces(txt: str):
|
||||
"""
|
||||
Remove redundant spaces around punctuation marks while preserving meaningful spaces.
|
||||
|
||||
This function performs two main operations:
|
||||
1. Remove spaces after left-boundary characters (opening brackets, etc.)
|
||||
2. Remove spaces before right-boundary characters (closing brackets, punctuation, etc.)
|
||||
|
||||
Args:
|
||||
txt (str): Input text to process
|
||||
|
||||
Returns:
|
||||
str: Text with redundant spaces removed
|
||||
"""
|
||||
# First pass: Remove spaces after left-boundary characters
|
||||
# Matches: [non-alphanumeric-and-specific-right-punctuation] + [non-space]
|
||||
# Removes spaces after characters like '(', '<', and other non-alphanumeric chars
|
||||
# Examples:
|
||||
# "( test" → "(test"
|
||||
txt = re.sub(r"([^a-z0-9.,\)>]) +([^ ])", r"\1\2", txt, flags=re.IGNORECASE)
|
||||
|
||||
# Second pass: Remove spaces before right-boundary characters
|
||||
# Matches: [non-space] + [non-alphanumeric-and-specific-left-punctuation]
|
||||
# Removes spaces before characters like non-')', non-',', non-'.', and non-alphanumeric chars
|
||||
# Examples:
|
||||
# "world !" → "world!"
|
||||
return re.sub(r"([^ ]) +([^a-z0-9.,\(<])", r"\1\2", txt, flags=re.IGNORECASE)
|
||||
|
||||
|
||||
def clean_markdown_block(text):
|
||||
"""
|
||||
Remove Markdown code block syntax from the beginning and end of text.
|
||||
|
||||
This function cleans Markdown code blocks by removing:
|
||||
- Opening ```Markdown tags (with optional whitespace and newlines)
|
||||
- Closing ``` tags (with optional whitespace and newlines)
|
||||
|
||||
Args:
|
||||
text (str): Input text that may be wrapped in Markdown code blocks
|
||||
|
||||
Returns:
|
||||
str: Cleaned text with Markdown code block syntax removed, and stripped of surrounding whitespace
|
||||
|
||||
"""
|
||||
# Remove opening ```markdown tag with optional whitespace and newlines
|
||||
# Matches: optional whitespace + ```markdown + optional whitespace + optional newline
|
||||
text = re.sub(r'^\s*```markdown\s*\n?', '', text)
|
||||
|
||||
# Remove closing ``` tag with optional whitespace and newlines
|
||||
# Matches: optional newline + optional whitespace + ``` + optional whitespace at end
|
||||
text = re.sub(r'\n?\s*```\s*$', '', text)
|
||||
|
||||
# Return text with surrounding whitespace removed
|
||||
return text.strip()
|
||||
59
api/app/core/rag/common/token_utils.py
Normal file
59
api/app/core/rag/common/token_utils.py
Normal file
@@ -0,0 +1,59 @@
|
||||
import os
|
||||
import tiktoken
|
||||
|
||||
from .file_utils import get_project_base_directory
|
||||
|
||||
tiktoken_cache_dir = os.path.join(get_project_base_directory(), "res")
|
||||
os.environ["TIKTOKEN_CACHE_DIR"] = tiktoken_cache_dir
|
||||
# encoder = tiktoken.encoding_for_model("gpt-3.5-turbo")
|
||||
encoder = tiktoken.get_encoding("cl100k_base")
|
||||
|
||||
|
||||
def num_tokens_from_string(string: str) -> int:
|
||||
"""Returns the number of tokens in a text string."""
|
||||
try:
|
||||
code_list = encoder.encode(string)
|
||||
return len(code_list)
|
||||
except Exception:
|
||||
return 0
|
||||
|
||||
def total_token_count_from_response(resp):
|
||||
if resp is None:
|
||||
return 0
|
||||
|
||||
if hasattr(resp, "usage") and hasattr(resp.usage, "total_tokens"):
|
||||
try:
|
||||
return resp.usage.total_tokens
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if hasattr(resp, "usage_metadata") and hasattr(resp.usage_metadata, "total_tokens"):
|
||||
try:
|
||||
return resp.usage_metadata.total_tokens
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if 'usage' in resp and 'total_tokens' in resp['usage']:
|
||||
try:
|
||||
return resp["usage"]["total_tokens"]
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if 'usage' in resp and 'input_tokens' in resp['usage'] and 'output_tokens' in resp['usage']:
|
||||
try:
|
||||
return resp["usage"]["input_tokens"] + resp["usage"]["output_tokens"]
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if 'meta' in resp and 'tokens' in resp['meta'] and 'input_tokens' in resp['meta']['tokens'] and 'output_tokens' in resp['meta']['tokens']:
|
||||
try:
|
||||
return resp["meta"]["tokens"]["input_tokens"] + resp["meta"]["tokens"]["output_tokens"]
|
||||
except Exception:
|
||||
pass
|
||||
return 0
|
||||
|
||||
|
||||
def truncate(string: str, max_len: int) -> str:
|
||||
"""Returns truncated text if the length of text exceed max_len."""
|
||||
return encoder.decode(encoder.encode(string)[:max_len])
|
||||
|
||||
122
api/app/core/rag/deepdoc/README.md
Normal file
122
api/app/core/rag/deepdoc/README.md
Normal file
@@ -0,0 +1,122 @@
|
||||
English | [简体中文](./README_zh.md)
|
||||
|
||||
# *Deep*Doc
|
||||
|
||||
- [1. Introduction](#1)
|
||||
- [2. Vision](#2)
|
||||
- [3. Parser](#3)
|
||||
|
||||
<a name="1"></a>
|
||||
## 1. Introduction
|
||||
|
||||
With a bunch of documents from various domains with various formats and along with diverse retrieval requirements,
|
||||
an accurate analysis becomes a very challenge task. *Deep*Doc is born for that purpose.
|
||||
There are 2 parts in *Deep*Doc so far: vision and parser.
|
||||
You can run the flowing test programs if you're interested in our results of OCR, layout recognition and TSR.
|
||||
```bash
|
||||
python deepdoc/vision/t_ocr.py -h
|
||||
usage: t_ocr.py [-h] --inputs INPUTS [--output_dir OUTPUT_DIR]
|
||||
|
||||
options:
|
||||
-h, --help show this help message and exit
|
||||
--inputs INPUTS Directory where to store images or PDFs, or a file path to a single image or PDF
|
||||
--output_dir OUTPUT_DIR
|
||||
Directory where to store the output images. Default: './ocr_outputs'
|
||||
```
|
||||
```bash
|
||||
python deepdoc/vision/t_recognizer.py -h
|
||||
usage: t_recognizer.py [-h] --inputs INPUTS [--output_dir OUTPUT_DIR] [--threshold THRESHOLD] [--mode {layout,tsr}]
|
||||
|
||||
options:
|
||||
-h, --help show this help message and exit
|
||||
--inputs INPUTS Directory where to store images or PDFs, or a file path to a single image or PDF
|
||||
--output_dir OUTPUT_DIR
|
||||
Directory where to store the output images. Default: './layouts_outputs'
|
||||
--threshold THRESHOLD
|
||||
A threshold to filter out detections. Default: 0.5
|
||||
--mode {layout,tsr} Task mode: layout recognition or table structure recognition
|
||||
```
|
||||
|
||||
Our models are served on HuggingFace. If you have trouble downloading HuggingFace models, this might help!!
|
||||
```bash
|
||||
export HF_ENDPOINT=https://hf-mirror.com
|
||||
```
|
||||
|
||||
<a name="2"></a>
|
||||
## 2. Vision
|
||||
|
||||
We use vision information to resolve problems as human being.
|
||||
- OCR. Since a lot of documents presented as images or at least be able to transform to image,
|
||||
OCR is a very essential and fundamental or even universal solution for text extraction.
|
||||
```bash
|
||||
python deepdoc/vision/t_ocr.py --inputs=path_to_images_or_pdfs --output_dir=path_to_store_result
|
||||
```
|
||||
The inputs could be directory to images or PDF, or a image or PDF.
|
||||
You can look into the folder 'path_to_store_result' where has images which demonstrate the positions of results,
|
||||
txt files which contain the OCR text.
|
||||
<div align="center" style="margin-top:20px;margin-bottom:20px;">
|
||||
<img src="https://github.com/infiniflow/ragflow/assets/12318111/f25bee3d-aaf7-4102-baf5-d5208361d110" width="900"/>
|
||||
</div>
|
||||
|
||||
- Layout recognition. Documents from different domain may have various layouts,
|
||||
like, newspaper, magazine, book and résumé are distinct in terms of layout.
|
||||
Only when machine have an accurate layout analysis, it can decide if these text parts are successive or not,
|
||||
or this part needs Table Structure Recognition(TSR) to process, or this part is a figure and described with this caption.
|
||||
We have 10 basic layout components which covers most cases:
|
||||
- Text
|
||||
- Title
|
||||
- Figure
|
||||
- Figure caption
|
||||
- Table
|
||||
- Table caption
|
||||
- Header
|
||||
- Footer
|
||||
- Reference
|
||||
- Equation
|
||||
|
||||
Have a try on the following command to see the layout detection results.
|
||||
```bash
|
||||
python deepdoc/vision/t_recognizer.py --inputs=path_to_images_or_pdfs --threshold=0.2 --mode=layout --output_dir=path_to_store_result
|
||||
```
|
||||
The inputs could be directory to images or PDF, or a image or PDF.
|
||||
You can look into the folder 'path_to_store_result' where has images which demonstrate the detection results as following:
|
||||
<div align="center" style="margin-top:20px;margin-bottom:20px;">
|
||||
<img src="https://github.com/infiniflow/ragflow/assets/12318111/07e0f625-9b28-43d0-9fbb-5bf586cd286f" width="1000"/>
|
||||
</div>
|
||||
|
||||
- Table Structure Recognition(TSR). Data table is a frequently used structure to present data including numbers or text.
|
||||
And the structure of a table might be very complex, like hierarchy headers, spanning cells and projected row headers.
|
||||
Along with TSR, we also reassemble the content into sentences which could be well comprehended by LLM.
|
||||
We have five labels for TSR task:
|
||||
- Column
|
||||
- Row
|
||||
- Column header
|
||||
- Projected row header
|
||||
- Spanning cell
|
||||
|
||||
Have a try on the following command to see the layout detection results.
|
||||
```bash
|
||||
python deepdoc/vision/t_recognizer.py --inputs=path_to_images_or_pdfs --threshold=0.2 --mode=tsr --output_dir=path_to_store_result
|
||||
```
|
||||
The inputs could be directory to images or PDF, or a image or PDF.
|
||||
You can look into the folder 'path_to_store_result' where has both images and html pages which demonstrate the detection results as following:
|
||||
<div align="center" style="margin-top:20px;margin-bottom:20px;">
|
||||
<img src="https://github.com/infiniflow/ragflow/assets/12318111/cb24e81b-f2ba-49f3-ac09-883d75606f4c" width="1000"/>
|
||||
</div>
|
||||
|
||||
<a name="3"></a>
|
||||
## 3. Parser
|
||||
|
||||
Four kinds of document formats as PDF, DOCX, EXCEL and PPT have their corresponding parser.
|
||||
The most complex one is PDF parser since PDF's flexibility. The output of PDF parser includes:
|
||||
- Text chunks with their own positions in PDF(page number and rectangular positions).
|
||||
- Tables with cropped image from the PDF, and contents which has already translated into natural language sentences.
|
||||
- Figures with caption and text in the figures.
|
||||
|
||||
### Résumé
|
||||
|
||||
The résumé is a very complicated kind of document. A résumé which is composed of unstructured text
|
||||
with various layouts could be resolved into structured data composed of nearly a hundred of fields.
|
||||
We haven't opened the parser yet, as we open the processing method after parsing procedure.
|
||||
|
||||
|
||||
116
api/app/core/rag/deepdoc/README_zh.md
Normal file
116
api/app/core/rag/deepdoc/README_zh.md
Normal file
@@ -0,0 +1,116 @@
|
||||
[English](./README.md) | 简体中文
|
||||
|
||||
# *Deep*Doc
|
||||
|
||||
- [*Deep*Doc](#deepdoc)
|
||||
- [1. 介绍](#1-介绍)
|
||||
- [2. 视觉处理](#2-视觉处理)
|
||||
- [3. 解析器](#3-解析器)
|
||||
- [简历](#简历)
|
||||
|
||||
<a name="1"></a>
|
||||
## 1. 介绍
|
||||
|
||||
对于来自不同领域、具有不同格式和不同检索要求的大量文档,准确的分析成为一项极具挑战性的任务。*Deep*Doc 就是为了这个目的而诞生的。到目前为止,*Deep*Doc 中有两个组成部分:视觉处理和解析器。如果您对我们的OCR、布局识别和TSR结果感兴趣,您可以运行下面的测试程序。
|
||||
|
||||
```bash
|
||||
python deepdoc/vision/t_ocr.py -h
|
||||
usage: t_ocr.py [-h] --inputs INPUTS [--output_dir OUTPUT_DIR]
|
||||
|
||||
options:
|
||||
-h, --help show this help message and exit
|
||||
--inputs INPUTS Directory where to store images or PDFs, or a file path to a single image or PDF
|
||||
--output_dir OUTPUT_DIR
|
||||
Directory where to store the output images. Default: './ocr_outputs'
|
||||
```
|
||||
|
||||
```bash
|
||||
python deepdoc/vision/t_recognizer.py -h
|
||||
usage: t_recognizer.py [-h] --inputs INPUTS [--output_dir OUTPUT_DIR] [--threshold THRESHOLD] [--mode {layout,tsr}]
|
||||
|
||||
options:
|
||||
-h, --help show this help message and exit
|
||||
--inputs INPUTS Directory where to store images or PDFs, or a file path to a single image or PDF
|
||||
--output_dir OUTPUT_DIR
|
||||
Directory where to store the output images. Default: './layouts_outputs'
|
||||
--threshold THRESHOLD
|
||||
A threshold to filter out detections. Default: 0.5
|
||||
--mode {layout,tsr} Task mode: layout recognition or table structure recognition
|
||||
```
|
||||
|
||||
HuggingFace为我们的模型提供服务。如果你在下载HuggingFace模型时遇到问题,这可能会有所帮助!!
|
||||
|
||||
```bash
|
||||
export HF_ENDPOINT=https://hf-mirror.com
|
||||
```
|
||||
|
||||
<a name="2"></a>
|
||||
## 2. 视觉处理
|
||||
|
||||
作为人类,我们使用视觉信息来解决问题。
|
||||
|
||||
- **OCR(Optical Character Recognition,光学字符识别)**。由于许多文档都是以图像形式呈现的,或者至少能够转换为图像,因此OCR是文本提取的一个非常重要、基本,甚至通用的解决方案。
|
||||
|
||||
```bash
|
||||
python deepdoc/vision/t_ocr.py --inputs=path_to_images_or_pdfs --output_dir=path_to_store_result
|
||||
```
|
||||
|
||||
输入可以是图像或PDF的目录,或者单个图像、PDF文件。您可以查看文件夹 `path_to_store_result` ,其中有演示结果位置的图像,以及包含OCR文本的txt文件。
|
||||
|
||||
<div align="center" style="margin-top:20px;margin-bottom:20px;">
|
||||
<img src="https://github.com/infiniflow/ragflow/assets/12318111/f25bee3d-aaf7-4102-baf5-d5208361d110" width="900"/>
|
||||
</div>
|
||||
|
||||
- 布局识别(Layout recognition)。来自不同领域的文件可能有不同的布局,如报纸、杂志、书籍和简历在布局方面是不同的。只有当机器有准确的布局分析时,它才能决定这些文本部分是连续的还是不连续的,或者这个部分需要表结构识别(Table Structure Recognition,TSR)来处理,或者这个部件是一个图形并用这个标题来描述。我们有10个基本布局组件,涵盖了大多数情况:
|
||||
- 文本
|
||||
- 标题
|
||||
- 配图
|
||||
- 配图标题
|
||||
- 表格
|
||||
- 表格标题
|
||||
- 页头
|
||||
- 页尾
|
||||
- 参考引用
|
||||
- 公式
|
||||
|
||||
请尝试以下命令以查看布局检测结果。
|
||||
|
||||
```bash
|
||||
python deepdoc/vision/t_recognizer.py --inputs=path_to_images_or_pdfs --threshold=0.2 --mode=layout --output_dir=path_to_store_result
|
||||
```
|
||||
|
||||
输入可以是图像或PDF的目录,或者单个图像、PDF文件。您可以查看文件夹 `path_to_store_result` ,其中有显示检测结果的图像,如下所示:
|
||||
<div align="center" style="margin-top:20px;margin-bottom:20px;">
|
||||
<img src="https://github.com/infiniflow/ragflow/assets/12318111/07e0f625-9b28-43d0-9fbb-5bf586cd286f" width="1000"/>
|
||||
</div>
|
||||
|
||||
- **TSR(Table Structure Recognition,表结构识别)**。数据表是一种常用的结构,用于表示包括数字或文本在内的数据。表的结构可能非常复杂,比如层次结构标题、跨单元格和投影行标题。除了TSR,我们还将内容重新组合成LLM可以很好理解的句子。TSR任务有五个标签:
|
||||
- 列
|
||||
- 行
|
||||
- 列标题
|
||||
- 行标题
|
||||
- 合并单元格
|
||||
|
||||
请尝试以下命令以查看布局检测结果。
|
||||
|
||||
```bash
|
||||
python deepdoc/vision/t_recognizer.py --inputs=path_to_images_or_pdfs --threshold=0.2 --mode=tsr --output_dir=path_to_store_result
|
||||
```
|
||||
|
||||
输入可以是图像或PDF的目录,或者单个图像、PDF文件。您可以查看文件夹 `path_to_store_result` ,其中包含图像和html页面,这些页面展示了以下检测结果:
|
||||
|
||||
<div align="center" style="margin-top:20px;margin-bottom:20px;">
|
||||
<img src="https://github.com/infiniflow/ragflow/assets/12318111/cb24e81b-f2ba-49f3-ac09-883d75606f4c" width="1000"/>
|
||||
</div>
|
||||
|
||||
<a name="3"></a>
|
||||
## 3. 解析器
|
||||
|
||||
PDF、DOCX、EXCEL和PPT四种文档格式都有相应的解析器。最复杂的是PDF解析器,因为PDF具有灵活性。PDF解析器的输出包括:
|
||||
- 在PDF中有自己位置的文本块(页码和矩形位置)。
|
||||
- 带有PDF裁剪图像的表格,以及已经翻译成自然语言句子的内容。
|
||||
- 图中带标题和文字的图。
|
||||
|
||||
### 简历
|
||||
|
||||
简历是一种非常复杂的文档。由各种格式的非结构化文本构成的简历可以被解析为包含近百个字段的结构化数据。我们还没有启用解析器,因为在解析过程之后才会启动处理方法。
|
||||
2
api/app/core/rag/deepdoc/__init__.py
Normal file
2
api/app/core/rag/deepdoc/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from beartype.claw import beartype_this_package
|
||||
beartype_this_package()
|
||||
24
api/app/core/rag/deepdoc/parser/__init__.py
Normal file
24
api/app/core/rag/deepdoc/parser/__init__.py
Normal file
@@ -0,0 +1,24 @@
|
||||
from .docx_parser import RAGDocxParser as DocxParser
|
||||
from .excel_parser import RAGExcelParser as ExcelParser
|
||||
from .html_parser import RAGHtmlParser as HtmlParser
|
||||
from .json_parser import RAGJsonParser as JsonParser
|
||||
from .markdown_parser import MarkdownElementExtractor
|
||||
from .markdown_parser import RAGMarkdownParser as MarkdownParser
|
||||
from .pdf_parser import PlainParser
|
||||
from .pdf_parser import RAGPdfParser as PdfParser
|
||||
from .ppt_parser import RAGPptParser as PptParser
|
||||
from .txt_parser import RAGTxtParser as TxtParser
|
||||
|
||||
__all__ = [
|
||||
"PdfParser",
|
||||
"PlainParser",
|
||||
"DocxParser",
|
||||
"ExcelParser",
|
||||
"PptParser",
|
||||
"HtmlParser",
|
||||
"JsonParser",
|
||||
"MarkdownParser",
|
||||
"TxtParser",
|
||||
"MarkdownElementExtractor",
|
||||
]
|
||||
|
||||
123
api/app/core/rag/deepdoc/parser/docx_parser.py
Normal file
123
api/app/core/rag/deepdoc/parser/docx_parser.py
Normal file
@@ -0,0 +1,123 @@
|
||||
from docx import Document
|
||||
import re
|
||||
import pandas as pd
|
||||
from collections import Counter
|
||||
from app.core.rag.nlp import rag_tokenizer
|
||||
from io import BytesIO
|
||||
|
||||
|
||||
class RAGDocxParser:
|
||||
|
||||
def __extract_table_content(self, tb):
|
||||
df = []
|
||||
for row in tb.rows:
|
||||
df.append([c.text for c in row.cells])
|
||||
return self.__compose_table_content(pd.DataFrame(df))
|
||||
|
||||
def __compose_table_content(self, df):
|
||||
|
||||
def blockType(b):
|
||||
pattern = [
|
||||
("^(20|19)[0-9]{2}[年/-][0-9]{1,2}[月/-][0-9]{1,2}日*$", "Dt"),
|
||||
(r"^(20|19)[0-9]{2}年$", "Dt"),
|
||||
(r"^(20|19)[0-9]{2}[年/-][0-9]{1,2}月*$", "Dt"),
|
||||
("^[0-9]{1,2}[月/-][0-9]{1,2}日*$", "Dt"),
|
||||
(r"^第*[一二三四1-4]季度$", "Dt"),
|
||||
(r"^(20|19)[0-9]{2}年*[一二三四1-4]季度$", "Dt"),
|
||||
(r"^(20|19)[0-9]{2}[ABCDE]$", "DT"),
|
||||
("^[0-9.,+%/ -]+$", "Nu"),
|
||||
(r"^[0-9A-Z/\._~-]+$", "Ca"),
|
||||
(r"^[A-Z]*[a-z' -]+$", "En"),
|
||||
(r"^[0-9.,+-]+[0-9A-Za-z/$¥%<>()()' -]+$", "NE"),
|
||||
(r"^.{1}$", "Sg")
|
||||
]
|
||||
for p, n in pattern:
|
||||
if re.search(p, b):
|
||||
return n
|
||||
tks = [t for t in rag_tokenizer.tokenize(b).split() if len(t) > 1]
|
||||
if len(tks) > 3:
|
||||
if len(tks) < 12:
|
||||
return "Tx"
|
||||
else:
|
||||
return "Lx"
|
||||
|
||||
if len(tks) == 1 and rag_tokenizer.tag(tks[0]) == "nr":
|
||||
return "Nr"
|
||||
|
||||
return "Ot"
|
||||
|
||||
if len(df) < 2:
|
||||
return []
|
||||
max_type = Counter([blockType(str(df.iloc[i, j])) for i in range(
|
||||
1, len(df)) for j in range(len(df.iloc[i, :]))])
|
||||
max_type = max(max_type.items(), key=lambda x: x[1])[0]
|
||||
|
||||
colnm = len(df.iloc[0, :])
|
||||
hdrows = [0] # header is not necessarily appear in the first line
|
||||
if max_type == "Nu":
|
||||
for r in range(1, len(df)):
|
||||
tys = Counter([blockType(str(df.iloc[r, j]))
|
||||
for j in range(len(df.iloc[r, :]))])
|
||||
tys = max(tys.items(), key=lambda x: x[1])[0]
|
||||
if tys != max_type:
|
||||
hdrows.append(r)
|
||||
|
||||
lines = []
|
||||
for i in range(1, len(df)):
|
||||
if i in hdrows:
|
||||
continue
|
||||
hr = [r - i for r in hdrows]
|
||||
hr = [r for r in hr if r < 0]
|
||||
t = len(hr) - 1
|
||||
while t > 0:
|
||||
if hr[t] - hr[t - 1] > 1:
|
||||
hr = hr[t:]
|
||||
break
|
||||
t -= 1
|
||||
headers = []
|
||||
for j in range(len(df.iloc[i, :])):
|
||||
t = []
|
||||
for h in hr:
|
||||
x = str(df.iloc[i + h, j]).strip()
|
||||
if x in t:
|
||||
continue
|
||||
t.append(x)
|
||||
t = ",".join(t)
|
||||
if t:
|
||||
t += ": "
|
||||
headers.append(t)
|
||||
cells = []
|
||||
for j in range(len(df.iloc[i, :])):
|
||||
if not str(df.iloc[i, j]):
|
||||
continue
|
||||
cells.append(headers[j] + str(df.iloc[i, j]))
|
||||
lines.append(";".join(cells))
|
||||
|
||||
if colnm > 3:
|
||||
return lines
|
||||
return ["\n".join(lines)]
|
||||
|
||||
def __call__(self, fnm, from_page=0, to_page=100000000):
|
||||
self.doc = Document(fnm) if isinstance(
|
||||
fnm, str) else Document(BytesIO(fnm))
|
||||
pn = 0 # parsed page
|
||||
secs = [] # parsed contents
|
||||
for p in self.doc.paragraphs:
|
||||
if pn > to_page:
|
||||
break
|
||||
|
||||
runs_within_single_paragraph = [] # save runs within the range of pages
|
||||
for run in p.runs:
|
||||
if pn > to_page:
|
||||
break
|
||||
if from_page <= pn < to_page and p.text.strip():
|
||||
runs_within_single_paragraph.append(run.text) # append run.text first
|
||||
|
||||
# wrap page break checker into a static method
|
||||
if 'lastRenderedPageBreak' in run._element.xml:
|
||||
pn += 1
|
||||
|
||||
secs.append(("".join(runs_within_single_paragraph), p.style.name if hasattr(p.style, 'name') else '')) # then concat run.text as part of the paragraph
|
||||
|
||||
tbls = [self.__extract_table_content(tb) for tb in self.doc.tables]
|
||||
return secs, tbls
|
||||
210
api/app/core/rag/deepdoc/parser/excel_parser.py
Normal file
210
api/app/core/rag/deepdoc/parser/excel_parser.py
Normal file
@@ -0,0 +1,210 @@
|
||||
import logging
|
||||
import re
|
||||
import sys
|
||||
from io import BytesIO
|
||||
|
||||
import pandas as pd
|
||||
from openpyxl import Workbook, load_workbook
|
||||
|
||||
from app.core.rag.nlp import find_codec
|
||||
|
||||
# copied from `/openpyxl/cell/cell.py`
|
||||
ILLEGAL_CHARACTERS_RE = re.compile(r"[\000-\010]|[\013-\014]|[\016-\037]")
|
||||
|
||||
|
||||
class RAGExcelParser:
|
||||
@staticmethod
|
||||
def _load_excel_to_workbook(file_like_object):
|
||||
if isinstance(file_like_object, bytes):
|
||||
file_like_object = BytesIO(file_like_object)
|
||||
|
||||
# Read first 4 bytes to determine file type
|
||||
file_like_object.seek(0)
|
||||
file_head = file_like_object.read(4)
|
||||
file_like_object.seek(0)
|
||||
|
||||
if not (file_head.startswith(b"PK\x03\x04") or file_head.startswith(b"\xd0\xcf\x11\xe0")):
|
||||
logging.info("Not an Excel file, converting CSV to Excel Workbook")
|
||||
|
||||
try:
|
||||
file_like_object.seek(0)
|
||||
df = pd.read_csv(file_like_object)
|
||||
return RAGExcelParser._dataframe_to_workbook(df)
|
||||
|
||||
except Exception as e_csv:
|
||||
raise Exception(f"Failed to parse CSV and convert to Excel Workbook: {e_csv}")
|
||||
|
||||
try:
|
||||
return load_workbook(file_like_object, data_only=True)
|
||||
except Exception as e:
|
||||
logging.info(f"openpyxl load error: {e}, try pandas instead")
|
||||
try:
|
||||
file_like_object.seek(0)
|
||||
try:
|
||||
dfs = pd.read_excel(file_like_object, sheet_name=None)
|
||||
return RAGExcelParser._dataframe_to_workbook(dfs)
|
||||
except Exception as ex:
|
||||
logging.info(f"pandas with default engine load error: {ex}, try calamine instead")
|
||||
file_like_object.seek(0)
|
||||
df = pd.read_excel(file_like_object, engine="calamine")
|
||||
return RAGExcelParser._dataframe_to_workbook(df)
|
||||
except Exception as e_pandas:
|
||||
raise Exception(f"pandas.read_excel error: {e_pandas}, original openpyxl error: {e}")
|
||||
|
||||
@staticmethod
|
||||
def _clean_dataframe(df: pd.DataFrame):
|
||||
def clean_string(s):
|
||||
if isinstance(s, str):
|
||||
return ILLEGAL_CHARACTERS_RE.sub(" ", s)
|
||||
return s
|
||||
|
||||
return df.apply(lambda col: col.map(clean_string))
|
||||
|
||||
@staticmethod
|
||||
def _dataframe_to_workbook(df):
|
||||
# if contains multiple sheets use _dataframes_to_workbook
|
||||
if isinstance(df, dict) and len(df) > 1:
|
||||
return RAGExcelParser._dataframes_to_workbook(df)
|
||||
|
||||
df = RAGExcelParser._clean_dataframe(df)
|
||||
wb = Workbook()
|
||||
ws = wb.active
|
||||
ws.title = "Data"
|
||||
|
||||
for col_num, column_name in enumerate(df.columns, 1):
|
||||
ws.cell(row=1, column=col_num, value=column_name)
|
||||
|
||||
for row_num, row in enumerate(df.values, 2):
|
||||
for col_num, value in enumerate(row, 1):
|
||||
ws.cell(row=row_num, column=col_num, value=value)
|
||||
|
||||
return wb
|
||||
|
||||
@staticmethod
|
||||
def _dataframes_to_workbook(dfs: dict):
|
||||
wb = Workbook()
|
||||
default_sheet = wb.active
|
||||
wb.remove(default_sheet)
|
||||
|
||||
for sheet_name, df in dfs.items():
|
||||
df = RAGExcelParser._clean_dataframe(df)
|
||||
ws = wb.create_sheet(title=sheet_name)
|
||||
for col_num, column_name in enumerate(df.columns, 1):
|
||||
ws.cell(row=1, column=col_num, value=column_name)
|
||||
for row_num, row in enumerate(df.values, 2):
|
||||
for col_num, value in enumerate(row, 1):
|
||||
ws.cell(row=row_num, column=col_num, value=value)
|
||||
return wb
|
||||
|
||||
def html(self, fnm, chunk_rows=256):
|
||||
from html import escape
|
||||
|
||||
file_like_object = BytesIO(fnm) if not isinstance(fnm, str) else fnm
|
||||
wb = RAGExcelParser._load_excel_to_workbook(file_like_object)
|
||||
tb_chunks = []
|
||||
|
||||
def _fmt(v):
|
||||
if v is None:
|
||||
return ""
|
||||
return str(v).strip()
|
||||
|
||||
for sheetname in wb.sheetnames:
|
||||
ws = wb[sheetname]
|
||||
try:
|
||||
rows = list(ws.rows)
|
||||
except Exception as e:
|
||||
logging.warning(f"Skip sheet '{sheetname}' due to rows access error: {e}")
|
||||
continue
|
||||
|
||||
if not rows:
|
||||
continue
|
||||
|
||||
tb_rows_0 = "<tr>"
|
||||
for t in list(rows[0]):
|
||||
tb_rows_0 += f"<th>{escape(_fmt(t.value))}</th>"
|
||||
tb_rows_0 += "</tr>"
|
||||
|
||||
for chunk_i in range((len(rows) - 1) // chunk_rows + 1):
|
||||
tb = ""
|
||||
tb += f"<table><caption>{sheetname}</caption>"
|
||||
tb += tb_rows_0
|
||||
for r in list(rows[1 + chunk_i * chunk_rows : min(1 + (chunk_i + 1) * chunk_rows, len(rows))]):
|
||||
tb += "<tr>"
|
||||
for i, c in enumerate(r):
|
||||
if c.value is None:
|
||||
tb += "<td></td>"
|
||||
else:
|
||||
tb += f"<td>{escape(_fmt(c.value))}</td>"
|
||||
tb += "</tr>"
|
||||
tb += "</table>\n"
|
||||
tb_chunks.append(tb)
|
||||
|
||||
return tb_chunks
|
||||
|
||||
def markdown(self, fnm):
|
||||
import pandas as pd
|
||||
|
||||
file_like_object = BytesIO(fnm) if not isinstance(fnm, str) else fnm
|
||||
try:
|
||||
file_like_object.seek(0)
|
||||
df = pd.read_excel(file_like_object)
|
||||
except Exception as e:
|
||||
logging.warning(f"Parse spreadsheet error: {e}, trying to interpret as CSV file")
|
||||
file_like_object.seek(0)
|
||||
df = pd.read_csv(file_like_object)
|
||||
df = df.replace(r"^\s*$", "", regex=True)
|
||||
return df.to_markdown(index=False)
|
||||
|
||||
def __call__(self, fnm):
|
||||
file_like_object = BytesIO(fnm) if not isinstance(fnm, str) else fnm
|
||||
wb = RAGExcelParser._load_excel_to_workbook(file_like_object)
|
||||
|
||||
res = []
|
||||
for sheetname in wb.sheetnames:
|
||||
ws = wb[sheetname]
|
||||
try:
|
||||
rows = list(ws.rows)
|
||||
except Exception as e:
|
||||
logging.warning(f"Skip sheet '{sheetname}' due to rows access error: {e}")
|
||||
continue
|
||||
if not rows:
|
||||
continue
|
||||
ti = list(rows[0])
|
||||
for r in list(rows[1:]):
|
||||
fields = []
|
||||
for i, c in enumerate(r):
|
||||
if not c.value:
|
||||
continue
|
||||
t = str(ti[i].value) if i < len(ti) else ""
|
||||
t += (":" if t else "") + str(c.value)
|
||||
fields.append(t)
|
||||
line = "; ".join(fields)
|
||||
if sheetname.lower().find("sheet") < 0:
|
||||
line += " ——" + sheetname
|
||||
res.append(line)
|
||||
return res
|
||||
|
||||
@staticmethod
|
||||
def row_number(fnm, binary):
|
||||
if fnm.split(".")[-1].lower().find("xls") >= 0:
|
||||
wb = RAGExcelParser._load_excel_to_workbook(BytesIO(binary))
|
||||
total = 0
|
||||
|
||||
for sheetname in wb.sheetnames:
|
||||
try:
|
||||
ws = wb[sheetname]
|
||||
total += len(list(ws.rows))
|
||||
except Exception as e:
|
||||
logging.warning(f"Skip sheet '{sheetname}' due to rows access error: {e}")
|
||||
continue
|
||||
return total
|
||||
|
||||
if fnm.split(".")[-1].lower() in ["csv", "txt"]:
|
||||
encoding = find_codec(binary)
|
||||
txt = binary.decode(encoding, errors="ignore")
|
||||
return len(txt.split("\n"))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
psr = RAGExcelParser()
|
||||
psr(sys.argv[1])
|
||||
118
api/app/core/rag/deepdoc/parser/figure_parser.py
Normal file
118
api/app/core/rag/deepdoc/parser/figure_parser.py
Normal file
@@ -0,0 +1,118 @@
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from app.core.rag.common.constants import LLMType
|
||||
from app.core.rag.common.connection_utils import timeout
|
||||
from app.core.rag.app.picture import vision_llm_chunk as picture_vision_llm_chunk
|
||||
from app.core.rag.prompts.generator import vision_llm_figure_describe_prompt
|
||||
|
||||
|
||||
def vision_figure_parser_figure_data_wrapper(figures_data_without_positions):
|
||||
return [
|
||||
(
|
||||
(figure_data[1], [figure_data[0]]),
|
||||
[(0, 0, 0, 0, 0)],
|
||||
)
|
||||
for figure_data in figures_data_without_positions
|
||||
if isinstance(figure_data[1], Image.Image)
|
||||
]
|
||||
|
||||
def vision_figure_parser_docx_wrapper(sections,tbls,callback=None,vision_model=None,**kwargs):
|
||||
if vision_model:
|
||||
figures_data = vision_figure_parser_figure_data_wrapper(sections)
|
||||
try:
|
||||
docx_vision_parser = VisionFigureParser(vision_model=vision_model, figures_data=figures_data, **kwargs)
|
||||
boosted_figures = docx_vision_parser(callback=callback)
|
||||
tbls.extend(boosted_figures)
|
||||
except Exception as e:
|
||||
callback(0.8, f"Visual model error: {e}. Skipping figure parsing enhancement.")
|
||||
return tbls
|
||||
|
||||
def vision_figure_parser_pdf_wrapper(tbls,callback=None,vision_model=None,**kwargs):
|
||||
if vision_model:
|
||||
def is_figure_item(item):
|
||||
return (
|
||||
isinstance(item[0][0], Image.Image) and
|
||||
isinstance(item[0][1], list)
|
||||
)
|
||||
figures_data = [item for item in tbls if is_figure_item(item)]
|
||||
try:
|
||||
docx_vision_parser = VisionFigureParser(vision_model=vision_model, figures_data=figures_data, **kwargs)
|
||||
boosted_figures = docx_vision_parser(callback=callback)
|
||||
tbls = [item for item in tbls if not is_figure_item(item)]
|
||||
tbls.extend(boosted_figures)
|
||||
except Exception as e:
|
||||
callback(0.8, f"Visual model error: {e}. Skipping figure parsing enhancement.")
|
||||
return tbls
|
||||
|
||||
shared_executor = ThreadPoolExecutor(max_workers=10)
|
||||
|
||||
|
||||
class VisionFigureParser:
|
||||
def __init__(self, vision_model, figures_data, *args, **kwargs):
|
||||
self.vision_model = vision_model
|
||||
self._extract_figures_info(figures_data)
|
||||
assert len(self.figures) == len(self.descriptions)
|
||||
assert not self.positions or (len(self.figures) == len(self.positions))
|
||||
|
||||
def _extract_figures_info(self, figures_data):
|
||||
self.figures = []
|
||||
self.descriptions = []
|
||||
self.positions = []
|
||||
|
||||
for item in figures_data:
|
||||
# position
|
||||
if len(item) == 2 and isinstance(item[0], tuple) and len(item[0]) == 2 and isinstance(item[1], list) and isinstance(item[1][0], tuple) and len(item[1][0]) == 5:
|
||||
img_desc = item[0]
|
||||
assert len(img_desc) == 2 and isinstance(img_desc[0], Image.Image) and isinstance(img_desc[1], list), "Should be (figure, [description])"
|
||||
self.figures.append(img_desc[0])
|
||||
self.descriptions.append(img_desc[1])
|
||||
self.positions.append(item[1])
|
||||
else:
|
||||
assert len(item) == 2 and isinstance(item[0], Image.Image) and isinstance(item[1], list), f"Unexpected form of figure data: get {len(item)=}, {item=}"
|
||||
self.figures.append(item[0])
|
||||
self.descriptions.append(item[1])
|
||||
|
||||
def _assemble(self):
|
||||
self.assembled = []
|
||||
self.has_positions = len(self.positions) != 0
|
||||
for i in range(len(self.figures)):
|
||||
figure = self.figures[i]
|
||||
desc = self.descriptions[i]
|
||||
pos = self.positions[i] if self.has_positions else None
|
||||
|
||||
figure_desc = (figure, desc)
|
||||
|
||||
if pos is not None:
|
||||
self.assembled.append((figure_desc, pos))
|
||||
else:
|
||||
self.assembled.append((figure_desc,))
|
||||
|
||||
return self.assembled
|
||||
|
||||
def __call__(self, **kwargs):
|
||||
callback = kwargs.get("callback", lambda prog, msg: None)
|
||||
|
||||
@timeout(30, 3)
|
||||
def process(figure_idx, figure_binary):
|
||||
description_text = picture_vision_llm_chunk(
|
||||
binary=figure_binary,
|
||||
vision_model=self.vision_model,
|
||||
prompt=vision_llm_figure_describe_prompt(),
|
||||
callback=callback,
|
||||
)
|
||||
return figure_idx, description_text
|
||||
|
||||
futures = []
|
||||
for idx, img_binary in enumerate(self.figures or []):
|
||||
futures.append(shared_executor.submit(process, idx, img_binary))
|
||||
|
||||
for future in as_completed(futures):
|
||||
figure_num, txt = future.result()
|
||||
if txt:
|
||||
self.descriptions[figure_num] = txt + "\n".join(self.descriptions[figure_num])
|
||||
|
||||
self._assemble()
|
||||
|
||||
return self.assembled
|
||||
197
api/app/core/rag/deepdoc/parser/html_parser.py
Normal file
197
api/app/core/rag/deepdoc/parser/html_parser.py
Normal file
@@ -0,0 +1,197 @@
|
||||
from app.core.rag.nlp import find_codec, rag_tokenizer
|
||||
import uuid
|
||||
import chardet
|
||||
from bs4 import BeautifulSoup, NavigableString, Tag, Comment
|
||||
import html
|
||||
|
||||
def get_encoding(file):
|
||||
with open(file,'rb') as f:
|
||||
tmp = chardet.detect(f.read())
|
||||
return tmp['encoding']
|
||||
|
||||
BLOCK_TAGS = [
|
||||
"h1", "h2", "h3", "h4", "h5", "h6",
|
||||
"p", "div", "article", "section", "aside",
|
||||
"ul", "ol", "li",
|
||||
"table", "pre", "code", "blockquote",
|
||||
"figure", "figcaption"
|
||||
]
|
||||
TITLE_TAGS = {"h1": "#", "h2": "##", "h3": "###", "h4": "#####", "h5": "#####", "h6": "######"}
|
||||
|
||||
|
||||
class RAGHtmlParser:
|
||||
def __call__(self, fnm, binary=None, chunk_token_num=512):
|
||||
if binary:
|
||||
encoding = find_codec(binary)
|
||||
txt = binary.decode(encoding, errors="ignore")
|
||||
else:
|
||||
with open(fnm, "r",encoding=get_encoding(fnm)) as f:
|
||||
txt = f.read()
|
||||
return self.parser_txt(txt, chunk_token_num)
|
||||
|
||||
@classmethod
|
||||
def parser_txt(cls, txt, chunk_token_num):
|
||||
if not isinstance(txt, str):
|
||||
raise TypeError("txt type should be string!")
|
||||
|
||||
temp_sections = []
|
||||
soup = BeautifulSoup(txt, "html5lib")
|
||||
# delete <style> tag
|
||||
for style_tag in soup.find_all(["style", "script"]):
|
||||
style_tag.decompose()
|
||||
# delete <script> tag in <div>
|
||||
for div_tag in soup.find_all("div"):
|
||||
for script_tag in div_tag.find_all("script"):
|
||||
script_tag.decompose()
|
||||
# delete inline style
|
||||
for tag in soup.find_all(True):
|
||||
if 'style' in tag.attrs:
|
||||
del tag.attrs['style']
|
||||
# delete HTML comment
|
||||
for comment in soup.find_all(string=lambda text: isinstance(text, Comment)):
|
||||
comment.extract()
|
||||
|
||||
cls.read_text_recursively(soup.body, temp_sections, chunk_token_num=chunk_token_num)
|
||||
block_txt_list, table_list = cls.merge_block_text(temp_sections)
|
||||
sections = cls.chunk_block(block_txt_list, chunk_token_num=chunk_token_num)
|
||||
for table in table_list:
|
||||
sections.append(table.get("content", ""))
|
||||
return sections
|
||||
|
||||
@classmethod
|
||||
def split_table(cls, html_table, chunk_token_num=512):
|
||||
soup = BeautifulSoup(html_table, "html.parser")
|
||||
rows = soup.find_all("tr")
|
||||
tables = []
|
||||
current_table = []
|
||||
current_count = 0
|
||||
table_str_list = []
|
||||
for row in rows:
|
||||
tks_str = rag_tokenizer.tokenize(str(row))
|
||||
token_count = len(tks_str.split(" ")) if tks_str else 0
|
||||
if current_count + token_count > chunk_token_num:
|
||||
tables.append(current_table)
|
||||
current_table = []
|
||||
current_count = 0
|
||||
current_table.append(row)
|
||||
current_count += token_count
|
||||
if current_table:
|
||||
tables.append(current_table)
|
||||
|
||||
for table_rows in tables:
|
||||
new_table = soup.new_tag("table")
|
||||
for row in table_rows:
|
||||
new_table.append(row)
|
||||
table_str_list.append(str(new_table))
|
||||
|
||||
return table_str_list
|
||||
|
||||
@classmethod
|
||||
def read_text_recursively(cls, element, parser_result, chunk_token_num=512, parent_name=None, block_id=None):
|
||||
if isinstance(element, NavigableString):
|
||||
content = element.strip()
|
||||
|
||||
def is_valid_html(content):
|
||||
try:
|
||||
soup = BeautifulSoup(content, "html.parser")
|
||||
return bool(soup.find())
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
return_info = []
|
||||
if content:
|
||||
if is_valid_html(content):
|
||||
soup = BeautifulSoup(content, "html.parser")
|
||||
child_info = cls.read_text_recursively(soup, parser_result, chunk_token_num, element.name, block_id)
|
||||
parser_result.extend(child_info)
|
||||
else:
|
||||
info = {"content": element.strip(), "tag_name": "inner_text", "metadata": {"block_id": block_id}}
|
||||
if parent_name:
|
||||
info["tag_name"] = parent_name
|
||||
return_info.append(info)
|
||||
return return_info
|
||||
elif isinstance(element, Tag):
|
||||
|
||||
if str.lower(element.name) == "table":
|
||||
table_info_list = []
|
||||
table_id = str(uuid.uuid1())
|
||||
table_list = [html.unescape(str(element))]
|
||||
for t in table_list:
|
||||
table_info_list.append({"content": t, "tag_name": "table",
|
||||
"metadata": {"table_id": table_id, "index": table_list.index(t)}})
|
||||
return table_info_list
|
||||
else:
|
||||
block_id = None
|
||||
if str.lower(element.name) in BLOCK_TAGS:
|
||||
block_id = str(uuid.uuid1())
|
||||
for child in element.children:
|
||||
child_info = cls.read_text_recursively(child, parser_result, chunk_token_num, element.name,
|
||||
block_id)
|
||||
parser_result.extend(child_info)
|
||||
return []
|
||||
|
||||
@classmethod
|
||||
def merge_block_text(cls, parser_result):
|
||||
block_content = []
|
||||
current_content = ""
|
||||
table_info_list = []
|
||||
lask_block_id = None
|
||||
for item in parser_result:
|
||||
content = item.get("content")
|
||||
tag_name = item.get("tag_name")
|
||||
title_flag = tag_name in TITLE_TAGS
|
||||
block_id = item.get("metadata", {}).get("block_id")
|
||||
if block_id:
|
||||
if title_flag:
|
||||
content = f"{TITLE_TAGS[tag_name]} {content}"
|
||||
if lask_block_id != block_id:
|
||||
if lask_block_id is not None:
|
||||
block_content.append(current_content)
|
||||
current_content = content
|
||||
lask_block_id = block_id
|
||||
else:
|
||||
current_content += (" " if current_content else "") + content
|
||||
else:
|
||||
if tag_name == "table":
|
||||
table_info_list.append(item)
|
||||
else:
|
||||
current_content += (" " if current_content else "" + content)
|
||||
if current_content:
|
||||
block_content.append(current_content)
|
||||
return block_content, table_info_list
|
||||
|
||||
@classmethod
|
||||
def chunk_block(cls, block_txt_list, chunk_token_num=512):
|
||||
chunks = []
|
||||
current_block = ""
|
||||
current_token_count = 0
|
||||
|
||||
for block in block_txt_list:
|
||||
tks_str = rag_tokenizer.tokenize(block)
|
||||
block_token_count = len(tks_str.split(" ")) if tks_str else 0
|
||||
if block_token_count > chunk_token_num:
|
||||
if current_block:
|
||||
chunks.append(current_block)
|
||||
start = 0
|
||||
tokens = tks_str.split(" ")
|
||||
while start < len(tokens):
|
||||
end = start + chunk_token_num
|
||||
split_tokens = tokens[start:end]
|
||||
chunks.append(" ".join(split_tokens))
|
||||
start = end
|
||||
current_block = ""
|
||||
current_token_count = 0
|
||||
else:
|
||||
if current_token_count + block_token_count <= chunk_token_num:
|
||||
current_block += ("\n" if current_block else "") + block
|
||||
current_token_count += block_token_count
|
||||
else:
|
||||
chunks.append(current_block)
|
||||
current_block = block
|
||||
current_token_count = block_token_count
|
||||
|
||||
if current_block:
|
||||
chunks.append(current_block)
|
||||
|
||||
return chunks
|
||||
|
||||
159
api/app/core/rag/deepdoc/parser/json_parser.py
Normal file
159
api/app/core/rag/deepdoc/parser/json_parser.py
Normal file
@@ -0,0 +1,159 @@
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from app.core.rag.nlp import find_codec
|
||||
|
||||
|
||||
class RAGJsonParser:
|
||||
def __init__(self, max_chunk_size: int = 2000, min_chunk_size: int | None = None):
|
||||
super().__init__()
|
||||
self.max_chunk_size = max_chunk_size * 2
|
||||
self.min_chunk_size = min_chunk_size if min_chunk_size is not None else max(max_chunk_size - 200, 50)
|
||||
|
||||
def __call__(self, filename):
|
||||
with open(filename, "r") as f:
|
||||
txt = f.read()
|
||||
|
||||
if self.is_jsonl_format(txt):
|
||||
sections = self._parse_jsonl(txt)
|
||||
else:
|
||||
sections = self._parse_json(txt)
|
||||
return sections
|
||||
|
||||
@staticmethod
|
||||
def _json_size(data: dict) -> int:
|
||||
"""Calculate the size of the serialized JSON object."""
|
||||
return len(json.dumps(data, ensure_ascii=False))
|
||||
|
||||
@staticmethod
|
||||
def _set_nested_dict(d: dict, path: list[str], value: Any) -> None:
|
||||
"""Set a value in a nested dictionary based on the given path."""
|
||||
for key in path[:-1]:
|
||||
d = d.setdefault(key, {})
|
||||
d[path[-1]] = value
|
||||
|
||||
def _list_to_dict_preprocessing(self, data: Any) -> Any:
|
||||
if isinstance(data, dict):
|
||||
# Process each key-value pair in the dictionary
|
||||
return {k: self._list_to_dict_preprocessing(v) for k, v in data.items()}
|
||||
elif isinstance(data, list):
|
||||
# Convert the list to a dictionary with index-based keys
|
||||
return {str(i): self._list_to_dict_preprocessing(item) for i, item in enumerate(data)}
|
||||
else:
|
||||
# Base case: the item is neither a dict nor a list, so return it unchanged
|
||||
return data
|
||||
|
||||
def _json_split(
|
||||
self,
|
||||
data,
|
||||
current_path: list[str] | None,
|
||||
chunks: list[dict] | None,
|
||||
) -> list[dict]:
|
||||
"""
|
||||
Split json into maximum size dictionaries while preserving structure.
|
||||
"""
|
||||
current_path = current_path or []
|
||||
chunks = chunks or [{}]
|
||||
if isinstance(data, dict):
|
||||
for key, value in data.items():
|
||||
new_path = current_path + [key]
|
||||
chunk_size = self._json_size(chunks[-1])
|
||||
size = self._json_size({key: value})
|
||||
remaining = self.max_chunk_size - chunk_size
|
||||
|
||||
if size < remaining:
|
||||
# Add item to current chunk
|
||||
self._set_nested_dict(chunks[-1], new_path, value)
|
||||
else:
|
||||
if chunk_size >= self.min_chunk_size:
|
||||
# Chunk is big enough, start a new chunk
|
||||
chunks.append({})
|
||||
|
||||
# Iterate
|
||||
self._json_split(value, new_path, chunks)
|
||||
else:
|
||||
# handle single item
|
||||
self._set_nested_dict(chunks[-1], current_path, data)
|
||||
return chunks
|
||||
|
||||
def split_json(
|
||||
self,
|
||||
json_data,
|
||||
convert_lists: bool = False,
|
||||
) -> list[dict]:
|
||||
"""Splits JSON into a list of JSON chunks"""
|
||||
|
||||
if convert_lists:
|
||||
preprocessed_data = self._list_to_dict_preprocessing(json_data)
|
||||
chunks = self._json_split(preprocessed_data, None, None)
|
||||
else:
|
||||
chunks = self._json_split(json_data, None, None)
|
||||
|
||||
# Remove the last chunk if it's empty
|
||||
if not chunks[-1]:
|
||||
chunks.pop()
|
||||
return chunks
|
||||
|
||||
def split_text(
|
||||
self,
|
||||
json_data: dict[str, Any],
|
||||
convert_lists: bool = False,
|
||||
ensure_ascii: bool = True,
|
||||
) -> list[str]:
|
||||
"""Splits JSON into a list of JSON formatted strings"""
|
||||
|
||||
chunks = self.split_json(json_data=json_data, convert_lists=convert_lists)
|
||||
|
||||
# Convert to string
|
||||
return [json.dumps(chunk, ensure_ascii=ensure_ascii) for chunk in chunks]
|
||||
|
||||
def _parse_json(self, content: str) -> list[str]:
|
||||
sections = []
|
||||
try:
|
||||
json_data = json.loads(content)
|
||||
chunks = self.split_json(json_data, True)
|
||||
sections = [json.dumps(line, ensure_ascii=False) for line in chunks if line]
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
return sections
|
||||
|
||||
def _parse_jsonl(self, content: str) -> list[str]:
|
||||
lines = content.strip().splitlines()
|
||||
all_chunks = []
|
||||
for line in lines:
|
||||
if not line.strip():
|
||||
continue
|
||||
try:
|
||||
data = json.loads(line)
|
||||
chunks = self.split_json(data, convert_lists=True)
|
||||
all_chunks.extend(json.dumps(chunk, ensure_ascii=False) for chunk in chunks if chunk)
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
return all_chunks
|
||||
|
||||
def is_jsonl_format(self, txt: str, sample_limit: int = 10, threshold: float = 0.8) -> bool:
|
||||
lines = [line.strip() for line in txt.strip().splitlines() if line.strip()]
|
||||
if not lines:
|
||||
return False
|
||||
|
||||
try:
|
||||
json.loads(txt)
|
||||
return False
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
sample_limit = min(len(lines), sample_limit)
|
||||
sample_lines = lines[:sample_limit]
|
||||
valid_lines = sum(1 for line in sample_lines if self._is_valid_json(line))
|
||||
|
||||
if not valid_lines:
|
||||
return False
|
||||
|
||||
return (valid_lines / len(sample_lines)) >= threshold
|
||||
|
||||
def _is_valid_json(self, line: str) -> bool:
|
||||
try:
|
||||
json.loads(line)
|
||||
return True
|
||||
except json.JSONDecodeError:
|
||||
return False
|
||||
277
api/app/core/rag/deepdoc/parser/markdown_parser.py
Normal file
277
api/app/core/rag/deepdoc/parser/markdown_parser.py
Normal file
@@ -0,0 +1,277 @@
|
||||
import re
|
||||
|
||||
from markdown import markdown
|
||||
|
||||
|
||||
class RAGMarkdownParser:
|
||||
def __init__(self, chunk_token_num=128):
|
||||
self.chunk_token_num = int(chunk_token_num)
|
||||
|
||||
def extract_tables_and_remainder(self, markdown_text, separate_tables=True):
|
||||
tables = []
|
||||
working_text = markdown_text
|
||||
|
||||
def replace_tables_with_rendered_html(pattern, table_list, render=True):
|
||||
new_text = ""
|
||||
last_end = 0
|
||||
for match in pattern.finditer(working_text):
|
||||
raw_table = match.group()
|
||||
table_list.append(raw_table)
|
||||
if separate_tables:
|
||||
# Skip this match (i.e., remove it)
|
||||
new_text += working_text[last_end : match.start()] + "\n\n"
|
||||
else:
|
||||
# Replace with rendered HTML
|
||||
html_table = markdown(raw_table, extensions=["markdown.extensions.tables"]) if render else raw_table
|
||||
new_text += working_text[last_end : match.start()] + html_table + "\n\n"
|
||||
last_end = match.end()
|
||||
new_text += working_text[last_end:]
|
||||
return new_text
|
||||
|
||||
if "|" in markdown_text: # for optimize performance
|
||||
# Standard Markdown table
|
||||
border_table_pattern = re.compile(
|
||||
r"""
|
||||
(?:\n|^)
|
||||
(?:\|.*?\|.*?\|.*?\n)
|
||||
(?:\|(?:\s*[:-]+[-| :]*\s*)\|.*?\n)
|
||||
(?:\|.*?\|.*?\|.*?\n)+
|
||||
""",
|
||||
re.VERBOSE,
|
||||
)
|
||||
working_text = replace_tables_with_rendered_html(border_table_pattern, tables)
|
||||
|
||||
# Borderless Markdown table
|
||||
no_border_table_pattern = re.compile(
|
||||
r"""
|
||||
(?:\n|^)
|
||||
(?:\S.*?\|.*?\n)
|
||||
(?:(?:\s*[:-]+[-| :]*\s*).*?\n)
|
||||
(?:\S.*?\|.*?\n)+
|
||||
""",
|
||||
re.VERBOSE,
|
||||
)
|
||||
working_text = replace_tables_with_rendered_html(no_border_table_pattern, tables)
|
||||
|
||||
# Replace any TAGS e.g. <table ...> to <table>
|
||||
TAGS = ["table", "td", "tr", "th", "tbody", "thead", "div"]
|
||||
table_with_attributes_pattern = re.compile(
|
||||
rf"<(?:{'|'.join(TAGS)})[^>]*>", re.IGNORECASE
|
||||
)
|
||||
def replace_tag(m):
|
||||
tag_name = re.match(r"<(\w+)", m.group()).group(1)
|
||||
return "<{}>".format(tag_name)
|
||||
|
||||
working_text = re.sub(table_with_attributes_pattern, replace_tag, working_text)
|
||||
|
||||
if "<table>" in working_text.lower(): # for optimize performance
|
||||
# HTML table extraction - handle possible html/body wrapper tags
|
||||
html_table_pattern = re.compile(
|
||||
r"""
|
||||
(?:\n|^)
|
||||
\s*
|
||||
(?:
|
||||
# case1: <html><body><table>...</table></body></html>
|
||||
(?:<html[^>]*>\s*<body[^>]*>\s*<table[^>]*>.*?</table>\s*</body>\s*</html>)
|
||||
|
|
||||
# case2: <body><table>...</table></body>
|
||||
(?:<body[^>]*>\s*<table[^>]*>.*?</table>\s*</body>)
|
||||
|
|
||||
# case3: only<table>...</table>
|
||||
(?:<table[^>]*>.*?</table>)
|
||||
)
|
||||
\s*
|
||||
(?=\n|$)
|
||||
""",
|
||||
re.VERBOSE | re.DOTALL | re.IGNORECASE,
|
||||
)
|
||||
|
||||
def replace_html_tables():
|
||||
nonlocal working_text
|
||||
new_text = ""
|
||||
last_end = 0
|
||||
for match in html_table_pattern.finditer(working_text):
|
||||
raw_table = match.group()
|
||||
tables.append(raw_table)
|
||||
if separate_tables:
|
||||
new_text += working_text[last_end : match.start()] + "\n\n"
|
||||
else:
|
||||
new_text += working_text[last_end : match.start()] + raw_table + "\n\n"
|
||||
last_end = match.end()
|
||||
new_text += working_text[last_end:]
|
||||
working_text = new_text
|
||||
|
||||
replace_html_tables()
|
||||
|
||||
return working_text, tables
|
||||
|
||||
|
||||
class MarkdownElementExtractor:
|
||||
def __init__(self, markdown_content):
|
||||
self.markdown_content = markdown_content
|
||||
self.lines = markdown_content.split("\n")
|
||||
|
||||
def get_delimiters(self,delimiters):
|
||||
toks = re.findall(r"`([^`]+)`", delimiters)
|
||||
toks = sorted(set(toks), key=lambda x: -len(x))
|
||||
return "|".join(re.escape(t) for t in toks if t)
|
||||
|
||||
def extract_elements(self,delimiter=None):
|
||||
"""Extract individual elements (headers, code blocks, lists, etc.)"""
|
||||
sections = []
|
||||
|
||||
i = 0
|
||||
dels=""
|
||||
if delimiter:
|
||||
dels = self.get_delimiters(delimiter)
|
||||
if len(dels) > 0:
|
||||
text = "\n".join(self.lines)
|
||||
parts = re.split(dels, text)
|
||||
sections = [p.strip() for p in parts if p and p.strip()]
|
||||
return sections
|
||||
while i < len(self.lines):
|
||||
line = self.lines[i]
|
||||
|
||||
if re.match(r"^#{1,6}\s+.*$", line):
|
||||
# header
|
||||
element = self._extract_header(i)
|
||||
sections.append(element["content"])
|
||||
i = element["end_line"] + 1
|
||||
elif line.strip().startswith("```"):
|
||||
# code block
|
||||
element = self._extract_code_block(i)
|
||||
sections.append(element["content"])
|
||||
i = element["end_line"] + 1
|
||||
elif re.match(r"^\s*[-*+]\s+.*$", line) or re.match(r"^\s*\d+\.\s+.*$", line):
|
||||
# list block
|
||||
element = self._extract_list_block(i)
|
||||
sections.append(element["content"])
|
||||
i = element["end_line"] + 1
|
||||
elif line.strip().startswith(">"):
|
||||
# blockquote
|
||||
element = self._extract_blockquote(i)
|
||||
sections.append(element["content"])
|
||||
i = element["end_line"] + 1
|
||||
elif line.strip():
|
||||
# text block (paragraphs and inline elements until next block element)
|
||||
element = self._extract_text_block(i)
|
||||
sections.append(element["content"])
|
||||
i = element["end_line"] + 1
|
||||
else:
|
||||
i += 1
|
||||
|
||||
sections = [section for section in sections if section.strip()]
|
||||
return sections
|
||||
|
||||
def _extract_header(self, start_pos):
|
||||
return {
|
||||
"type": "header",
|
||||
"content": self.lines[start_pos],
|
||||
"start_line": start_pos,
|
||||
"end_line": start_pos,
|
||||
}
|
||||
|
||||
def _extract_code_block(self, start_pos):
|
||||
end_pos = start_pos
|
||||
content_lines = [self.lines[start_pos]]
|
||||
|
||||
# Find the end of the code block
|
||||
for i in range(start_pos + 1, len(self.lines)):
|
||||
content_lines.append(self.lines[i])
|
||||
end_pos = i
|
||||
if self.lines[i].strip().startswith("```"):
|
||||
break
|
||||
|
||||
return {
|
||||
"type": "code_block",
|
||||
"content": "\n".join(content_lines),
|
||||
"start_line": start_pos,
|
||||
"end_line": end_pos,
|
||||
}
|
||||
|
||||
def _extract_list_block(self, start_pos):
|
||||
end_pos = start_pos
|
||||
content_lines = []
|
||||
|
||||
i = start_pos
|
||||
while i < len(self.lines):
|
||||
line = self.lines[i]
|
||||
# check if this line is a list item or continuation of a list
|
||||
if (
|
||||
re.match(r"^\s*[-*+]\s+.*$", line)
|
||||
or re.match(r"^\s*\d+\.\s+.*$", line)
|
||||
or (i > start_pos and not line.strip())
|
||||
or (i > start_pos and re.match(r"^\s{2,}[-*+]\s+.*$", line))
|
||||
or (i > start_pos and re.match(r"^\s{2,}\d+\.\s+.*$", line))
|
||||
or (i > start_pos and re.match(r"^\s+\w+.*$", line))
|
||||
):
|
||||
content_lines.append(line)
|
||||
end_pos = i
|
||||
i += 1
|
||||
else:
|
||||
break
|
||||
|
||||
return {
|
||||
"type": "list_block",
|
||||
"content": "\n".join(content_lines),
|
||||
"start_line": start_pos,
|
||||
"end_line": end_pos,
|
||||
}
|
||||
|
||||
def _extract_blockquote(self, start_pos):
|
||||
end_pos = start_pos
|
||||
content_lines = []
|
||||
|
||||
i = start_pos
|
||||
while i < len(self.lines):
|
||||
line = self.lines[i]
|
||||
if line.strip().startswith(">") or (i > start_pos and not line.strip()):
|
||||
content_lines.append(line)
|
||||
end_pos = i
|
||||
i += 1
|
||||
else:
|
||||
break
|
||||
|
||||
return {
|
||||
"type": "blockquote",
|
||||
"content": "\n".join(content_lines),
|
||||
"start_line": start_pos,
|
||||
"end_line": end_pos,
|
||||
}
|
||||
|
||||
def _extract_text_block(self, start_pos):
|
||||
"""Extract a text block (paragraphs, inline elements) until next block element"""
|
||||
end_pos = start_pos
|
||||
content_lines = [self.lines[start_pos]]
|
||||
|
||||
i = start_pos + 1
|
||||
while i < len(self.lines):
|
||||
line = self.lines[i]
|
||||
# stop if we encounter a block element
|
||||
if re.match(r"^#{1,6}\s+.*$", line) or line.strip().startswith("```") or re.match(r"^\s*[-*+]\s+.*$", line) or re.match(r"^\s*\d+\.\s+.*$", line) or line.strip().startswith(">"):
|
||||
break
|
||||
elif not line.strip():
|
||||
# check if the next line is a block element
|
||||
if i + 1 < len(self.lines) and (
|
||||
re.match(r"^#{1,6}\s+.*$", self.lines[i + 1])
|
||||
or self.lines[i + 1].strip().startswith("```")
|
||||
or re.match(r"^\s*[-*+]\s+.*$", self.lines[i + 1])
|
||||
or re.match(r"^\s*\d+\.\s+.*$", self.lines[i + 1])
|
||||
or self.lines[i + 1].strip().startswith(">")
|
||||
):
|
||||
break
|
||||
else:
|
||||
content_lines.append(line)
|
||||
end_pos = i
|
||||
i += 1
|
||||
else:
|
||||
content_lines.append(line)
|
||||
end_pos = i
|
||||
i += 1
|
||||
|
||||
return {
|
||||
"type": "text_block",
|
||||
"content": "\n".join(content_lines),
|
||||
"start_line": start_pos,
|
||||
"end_line": end_pos,
|
||||
}
|
||||
524
api/app/core/rag/deepdoc/parser/mineru_parser.py
Normal file
524
api/app/core/rag/deepdoc/parser/mineru_parser.py
Normal file
@@ -0,0 +1,524 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import platform
|
||||
import re
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
import threading
|
||||
import time
|
||||
import zipfile
|
||||
from io import BytesIO
|
||||
from os import PathLike
|
||||
from pathlib import Path
|
||||
from queue import Empty, Queue
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import numpy as np
|
||||
import pdfplumber
|
||||
import requests
|
||||
from PIL import Image
|
||||
from strenum import StrEnum
|
||||
|
||||
from .pdf_parser import RAGPdfParser
|
||||
|
||||
LOCK_KEY_pdfplumber = "global_shared_lock_pdfplumber"
|
||||
if LOCK_KEY_pdfplumber not in sys.modules:
|
||||
sys.modules[LOCK_KEY_pdfplumber] = threading.Lock()
|
||||
|
||||
|
||||
class MinerUContentType(StrEnum):
|
||||
IMAGE = "image"
|
||||
TABLE = "table"
|
||||
TEXT = "text"
|
||||
EQUATION = "equation"
|
||||
CODE = "code"
|
||||
LIST = "list"
|
||||
DISCARDED = "discarded"
|
||||
|
||||
|
||||
class MinerUParser(RAGPdfParser):
|
||||
def __init__(self, mineru_path: str = "mineru", mineru_api: str = "http://host.docker.internal:9987", mineru_server_url: str = ""):
|
||||
self.mineru_path = Path(mineru_path)
|
||||
self.mineru_api = mineru_api.rstrip("/")
|
||||
self.mineru_server_url = mineru_server_url.rstrip("/")
|
||||
self.using_api = False
|
||||
self.logger = logging.getLogger(self.__class__.__name__)
|
||||
|
||||
def _extract_zip_no_root(self, zip_path, extract_to, root_dir):
|
||||
with zipfile.ZipFile(zip_path, "r") as zip_ref:
|
||||
if not root_dir:
|
||||
files = zip_ref.namelist()
|
||||
if files and files[0].endswith("/"):
|
||||
root_dir = files[0]
|
||||
else:
|
||||
root_dir = None
|
||||
|
||||
if not root_dir or not root_dir.endswith("/"):
|
||||
self.logger.info(f"[MinerU] No root directory found, extracting all...fff{root_dir}")
|
||||
zip_ref.extractall(extract_to)
|
||||
return
|
||||
|
||||
root_len = len(root_dir)
|
||||
for member in zip_ref.infolist():
|
||||
filename = member.filename
|
||||
if filename == root_dir:
|
||||
self.logger.info("[MinerU] Ignore root folder...")
|
||||
continue
|
||||
|
||||
path = filename
|
||||
if path.startswith(root_dir):
|
||||
path = path[root_len:]
|
||||
|
||||
full_path = os.path.join(extract_to, path)
|
||||
if member.is_dir():
|
||||
os.makedirs(full_path, exist_ok=True)
|
||||
else:
|
||||
os.makedirs(os.path.dirname(full_path), exist_ok=True)
|
||||
with open(full_path, "wb") as f:
|
||||
f.write(zip_ref.read(filename))
|
||||
|
||||
def _is_http_endpoint_valid(self, url, timeout=5):
|
||||
try:
|
||||
response = requests.head(url, timeout=timeout, allow_redirects=True)
|
||||
return response.status_code in [200, 301, 302, 307, 308]
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def check_installation(self, backend: str = "pipeline", server_url: Optional[str] = None) -> tuple[bool, str]:
|
||||
reason = ""
|
||||
|
||||
valid_backends = ["pipeline", "vlm-http-client", "vlm-transformers", "vlm-vllm-engine"]
|
||||
if backend not in valid_backends:
|
||||
reason = "[MinerU] Invalid backend '{backend}'. Valid backends are: {valid_backends}"
|
||||
logging.warning(reason)
|
||||
return False, reason
|
||||
|
||||
subprocess_kwargs = {
|
||||
"capture_output": True,
|
||||
"text": True,
|
||||
"check": True,
|
||||
"encoding": "utf-8",
|
||||
"errors": "ignore",
|
||||
}
|
||||
|
||||
if platform.system() == "Windows":
|
||||
subprocess_kwargs["creationflags"] = getattr(subprocess, "CREATE_NO_WINDOW", 0)
|
||||
|
||||
if server_url is None:
|
||||
server_url = self.mineru_server_url
|
||||
|
||||
if backend == "vlm-http-client" and server_url:
|
||||
try:
|
||||
server_accessible = self._is_http_endpoint_valid(server_url + "/openapi.json")
|
||||
logging.info(f"[MinerU] vlm-http-client server check: {server_accessible}")
|
||||
if server_accessible:
|
||||
self.using_api = False # We are using http client, not API
|
||||
return True, reason
|
||||
else:
|
||||
reason = f"[MinerU] vlm-http-client server not accessible: {server_url}"
|
||||
logging.warning(f"[MinerU] vlm-http-client server not accessible: {server_url}")
|
||||
return False, reason
|
||||
except Exception as e:
|
||||
logging.warning(f"[MinerU] vlm-http-client server check failed: {e}")
|
||||
try:
|
||||
response = requests.get(server_url, timeout=5)
|
||||
logging.info(f"[MinerU] vlm-http-client server connection check: success with status {response.status_code}")
|
||||
self.using_api = False
|
||||
return True, reason
|
||||
except Exception as e:
|
||||
reason = f"[MinerU] vlm-http-client server connection check failed: {server_url}: {e}"
|
||||
logging.warning(f"[MinerU] vlm-http-client server connection check failed: {server_url}: {e}")
|
||||
return False, reason
|
||||
|
||||
try:
|
||||
result = subprocess.run([str(self.mineru_path), "--version"], **subprocess_kwargs)
|
||||
version_info = result.stdout.strip()
|
||||
if version_info:
|
||||
logging.info(f"[MinerU] Detected version: {version_info}")
|
||||
else:
|
||||
logging.info("[MinerU] Detected MinerU, but version info is empty.")
|
||||
return True, reason
|
||||
except subprocess.CalledProcessError as e:
|
||||
logging.warning(f"[MinerU] Execution failed (exit code {e.returncode}).")
|
||||
except FileNotFoundError:
|
||||
logging.warning("[MinerU] MinerU not found. Please install it via: pip install -U 'mineru[core]'")
|
||||
except Exception as e:
|
||||
logging.error(f"[MinerU] Unexpected error during installation check: {e}")
|
||||
|
||||
# If executable check fails, try API check
|
||||
try:
|
||||
if self.mineru_api:
|
||||
# check openapi.json
|
||||
openapi_exists = self._is_http_endpoint_valid(self.mineru_api + "/openapi.json")
|
||||
if not openapi_exists:
|
||||
reason = "[MinerU] Failed to detect vaild MinerU API server"
|
||||
return openapi_exists, reason
|
||||
logging.info(f"[MinerU] Detected {self.mineru_api}/openapi.json: {openapi_exists}")
|
||||
self.using_api = openapi_exists
|
||||
return openapi_exists, reason
|
||||
else:
|
||||
logging.info("[MinerU] api not exists.")
|
||||
except Exception as e:
|
||||
reason = f"[MinerU] Unexpected error during api check: {e}"
|
||||
logging.error(f"[MinerU] Unexpected error during api check: {e}")
|
||||
return False, reason
|
||||
|
||||
def _run_mineru(
|
||||
self, input_path: Path, output_dir: Path, method: str = "auto", backend: str = "pipeline", lang: Optional[str] = None, server_url: Optional[str] = None, callback: Optional[Callable] = None
|
||||
):
|
||||
if self.using_api:
|
||||
self._run_mineru_api(input_path, output_dir, method, backend, lang, callback)
|
||||
else:
|
||||
self._run_mineru_executable(input_path, output_dir, method, backend, lang, server_url, callback)
|
||||
|
||||
def _run_mineru_api(self, input_path: Path, output_dir: Path, method: str = "auto", backend: str = "pipeline", lang: Optional[str] = None, callback: Optional[Callable] = None):
|
||||
OUTPUT_ZIP_PATH = os.path.join(str(output_dir), "output.zip")
|
||||
|
||||
pdf_file_path = str(input_path)
|
||||
|
||||
if not os.path.exists(pdf_file_path):
|
||||
raise RuntimeError(f"[MinerU] PDF file not exists: {pdf_file_path}")
|
||||
|
||||
pdf_file_name = Path(pdf_file_path).stem.strip()
|
||||
output_path = os.path.join(str(output_dir), pdf_file_name, method)
|
||||
os.makedirs(output_path, exist_ok=True)
|
||||
|
||||
files = {"files": (pdf_file_name + ".pdf", open(pdf_file_path, "rb"), "application/pdf")}
|
||||
|
||||
data = {
|
||||
"output_dir": "./output",
|
||||
"lang_list": lang,
|
||||
"backend": backend,
|
||||
"parse_method": method,
|
||||
"formula_enable": True,
|
||||
"table_enable": True,
|
||||
"server_url": None,
|
||||
"return_md": True,
|
||||
"return_middle_json": True,
|
||||
"return_model_output": True,
|
||||
"return_content_list": True,
|
||||
"return_images": True,
|
||||
"response_format_zip": True,
|
||||
"start_page_id": 0,
|
||||
"end_page_id": 99999,
|
||||
}
|
||||
|
||||
headers = {"Accept": "application/json"}
|
||||
try:
|
||||
self.logger.info(f"[MinerU] invoke api: {self.mineru_api}/file_parse")
|
||||
if callback:
|
||||
callback(0.20, f"[MinerU] invoke api: {self.mineru_api}/file_parse")
|
||||
response = requests.post(url=f"{self.mineru_api}/file_parse", files=files, data=data, headers=headers, timeout=1800)
|
||||
|
||||
response.raise_for_status()
|
||||
if response.headers.get("Content-Type") == "application/zip":
|
||||
self.logger.info(f"[MinerU] zip file returned, saving to {OUTPUT_ZIP_PATH}...")
|
||||
|
||||
if callback:
|
||||
callback(0.30, f"[MinerU] zip file returned, saving to {OUTPUT_ZIP_PATH}...")
|
||||
|
||||
with open(OUTPUT_ZIP_PATH, "wb") as f:
|
||||
f.write(response.content)
|
||||
|
||||
self.logger.info(f"[MinerU] Unzip to {output_path}...")
|
||||
self._extract_zip_no_root(OUTPUT_ZIP_PATH, output_path, pdf_file_name + "/")
|
||||
|
||||
if callback:
|
||||
callback(0.40, f"[MinerU] Unzip to {output_path}...")
|
||||
else:
|
||||
self.logger.warning("[MinerU] not zip returned from api:%s " % response.headers.get("Content-Type"))
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"[MinerU] api failed with exception {e}")
|
||||
self.logger.info("[MinerU] Api completed successfully.")
|
||||
|
||||
def _run_mineru_executable(
|
||||
self, input_path: Path, output_dir: Path, method: str = "auto", backend: str = "pipeline", lang: Optional[str] = None, server_url: Optional[str] = None, callback: Optional[Callable] = None
|
||||
):
|
||||
cmd = [str(self.mineru_path), "-p", str(input_path), "-o", str(output_dir), "-m", method]
|
||||
if backend:
|
||||
cmd.extend(["-b", backend])
|
||||
if lang:
|
||||
cmd.extend(["-l", lang])
|
||||
if server_url and backend == "vlm-http-client":
|
||||
cmd.extend(["-u", server_url])
|
||||
|
||||
self.logger.info(f"[MinerU] Running command: {' '.join(cmd)}")
|
||||
|
||||
subprocess_kwargs = {
|
||||
"stdout": subprocess.PIPE,
|
||||
"stderr": subprocess.PIPE,
|
||||
"text": True,
|
||||
"encoding": "utf-8",
|
||||
"errors": "ignore",
|
||||
"bufsize": 1,
|
||||
}
|
||||
|
||||
if platform.system() == "Windows":
|
||||
subprocess_kwargs["creationflags"] = getattr(subprocess, "CREATE_NO_WINDOW", 0)
|
||||
|
||||
process = subprocess.Popen(cmd, **subprocess_kwargs)
|
||||
stdout_queue, stderr_queue = Queue(), Queue()
|
||||
|
||||
def enqueue_output(pipe, queue, prefix):
|
||||
for line in iter(pipe.readline, ""):
|
||||
if line.strip():
|
||||
queue.put((prefix, line.strip()))
|
||||
pipe.close()
|
||||
|
||||
threading.Thread(target=enqueue_output, args=(process.stdout, stdout_queue, "STDOUT"), daemon=True).start()
|
||||
threading.Thread(target=enqueue_output, args=(process.stderr, stderr_queue, "STDERR"), daemon=True).start()
|
||||
|
||||
while process.poll() is None:
|
||||
for q in (stdout_queue, stderr_queue):
|
||||
try:
|
||||
while True:
|
||||
prefix, line = q.get_nowait()
|
||||
if prefix == "STDOUT":
|
||||
self.logger.info(f"[MinerU] {line}")
|
||||
else:
|
||||
self.logger.warning(f"[MinerU] {line}")
|
||||
except Empty:
|
||||
pass
|
||||
time.sleep(0.1)
|
||||
|
||||
return_code = process.wait()
|
||||
if return_code != 0:
|
||||
raise RuntimeError(f"[MinerU] Process failed with exit code {return_code}")
|
||||
self.logger.info("[MinerU] Command completed successfully.")
|
||||
|
||||
def __images__(self, fnm, zoomin: int = 1, page_from=0, page_to=600, callback=None):
|
||||
self.page_from = page_from
|
||||
self.page_to = page_to
|
||||
try:
|
||||
with pdfplumber.open(fnm) if isinstance(fnm, (str, PathLike)) else pdfplumber.open(BytesIO(fnm)) as pdf:
|
||||
self.pdf = pdf
|
||||
self.page_images = [p.to_image(resolution=72 * zoomin, antialias=True).original for _, p in enumerate(self.pdf.pages[page_from:page_to])]
|
||||
except Exception as e:
|
||||
self.page_images = None
|
||||
self.total_page = 0
|
||||
logging.exception(e)
|
||||
|
||||
def _line_tag(self, bx):
|
||||
pn = [bx["page_idx"] + 1]
|
||||
positions = bx["bbox"]
|
||||
x0, top, x1, bott = positions
|
||||
|
||||
if hasattr(self, "page_images") and self.page_images and len(self.page_images) > bx["page_idx"]:
|
||||
page_width, page_height = self.page_images[bx["page_idx"]].size
|
||||
x0 = (x0 / 1000.0) * page_width
|
||||
x1 = (x1 / 1000.0) * page_width
|
||||
top = (top / 1000.0) * page_height
|
||||
bott = (bott / 1000.0) * page_height
|
||||
|
||||
return "@@{}\t{:.1f}\t{:.1f}\t{:.1f}\t{:.1f}##".format("-".join([str(p) for p in pn]), x0, x1, top, bott)
|
||||
|
||||
def crop(self, text, ZM=1, need_position=False):
|
||||
imgs = []
|
||||
poss = self.extract_positions(text)
|
||||
if not poss:
|
||||
if need_position:
|
||||
return None, None
|
||||
return
|
||||
|
||||
max_width = max(np.max([right - left for (_, left, right, _, _) in poss]), 6)
|
||||
GAP = 6
|
||||
pos = poss[0]
|
||||
poss.insert(0, ([pos[0][0]], pos[1], pos[2], max(0, pos[3] - 120), max(pos[3] - GAP, 0)))
|
||||
pos = poss[-1]
|
||||
poss.append(([pos[0][-1]], pos[1], pos[2], min(self.page_images[pos[0][-1]].size[1], pos[4] + GAP), min(self.page_images[pos[0][-1]].size[1], pos[4] + 120)))
|
||||
|
||||
positions = []
|
||||
for ii, (pns, left, right, top, bottom) in enumerate(poss):
|
||||
right = left + max_width
|
||||
|
||||
if bottom <= top:
|
||||
bottom = top + 2
|
||||
|
||||
for pn in pns[1:]:
|
||||
bottom += self.page_images[pn - 1].size[1]
|
||||
|
||||
img0 = self.page_images[pns[0]]
|
||||
x0, y0, x1, y1 = int(left), int(top), int(right), int(min(bottom, img0.size[1]))
|
||||
crop0 = img0.crop((x0, y0, x1, y1))
|
||||
imgs.append(crop0)
|
||||
if 0 < ii < len(poss) - 1:
|
||||
positions.append((pns[0] + self.page_from, x0, x1, y0, y1))
|
||||
|
||||
bottom -= img0.size[1]
|
||||
for pn in pns[1:]:
|
||||
page = self.page_images[pn]
|
||||
x0, y0, x1, y1 = int(left), 0, int(right), int(min(bottom, page.size[1]))
|
||||
cimgp = page.crop((x0, y0, x1, y1))
|
||||
imgs.append(cimgp)
|
||||
if 0 < ii < len(poss) - 1:
|
||||
positions.append((pn + self.page_from, x0, x1, y0, y1))
|
||||
bottom -= page.size[1]
|
||||
|
||||
if not imgs:
|
||||
if need_position:
|
||||
return None, None
|
||||
return
|
||||
|
||||
height = 0
|
||||
for img in imgs:
|
||||
height += img.size[1] + GAP
|
||||
height = int(height)
|
||||
width = int(np.max([i.size[0] for i in imgs]))
|
||||
pic = Image.new("RGB", (width, height), (245, 245, 245))
|
||||
height = 0
|
||||
for ii, img in enumerate(imgs):
|
||||
if ii == 0 or ii + 1 == len(imgs):
|
||||
img = img.convert("RGBA")
|
||||
overlay = Image.new("RGBA", img.size, (0, 0, 0, 0))
|
||||
overlay.putalpha(128)
|
||||
img = Image.alpha_composite(img, overlay).convert("RGB")
|
||||
pic.paste(img, (0, int(height)))
|
||||
height += img.size[1] + GAP
|
||||
|
||||
if need_position:
|
||||
return pic, positions
|
||||
return pic
|
||||
|
||||
@staticmethod
|
||||
def extract_positions(txt: str):
|
||||
poss = []
|
||||
for tag in re.findall(r"@@[0-9-]+\t[0-9.\t]+##", txt):
|
||||
pn, left, right, top, bottom = tag.strip("#").strip("@").split("\t")
|
||||
left, right, top, bottom = float(left), float(right), float(top), float(bottom)
|
||||
poss.append(([int(p) - 1 for p in pn.split("-")], left, right, top, bottom))
|
||||
return poss
|
||||
|
||||
def _read_output(self, output_dir: Path, file_stem: str, method: str = "auto", backend: str = "pipeline") -> list[dict[str, Any]]:
|
||||
subdir = output_dir / file_stem / method
|
||||
if backend.startswith("vlm-"):
|
||||
subdir = output_dir / file_stem / "vlm"
|
||||
json_file = subdir / f"{file_stem}_content_list.json"
|
||||
|
||||
if not json_file.exists():
|
||||
raise FileNotFoundError(f"[MinerU] Missing output file: {json_file}")
|
||||
|
||||
with open(json_file, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
|
||||
for item in data:
|
||||
for key in ("img_path", "table_img_path", "equation_img_path"):
|
||||
if key in item and item[key]:
|
||||
item[key] = str((subdir / item[key]).resolve())
|
||||
return data
|
||||
|
||||
def _transfer_to_sections(self, outputs: list[dict[str, Any]]):
|
||||
sections = []
|
||||
for output in outputs:
|
||||
match output["type"]:
|
||||
case MinerUContentType.TEXT:
|
||||
section = output["text"]
|
||||
case MinerUContentType.TABLE:
|
||||
section = output.get("table_body", "") + "\n".join(output.get("table_caption", [])) + "\n".join(output.get("table_footnote", []))
|
||||
if not section.strip():
|
||||
section = "FAILED TO PARSE TABLE"
|
||||
case MinerUContentType.IMAGE:
|
||||
section = "".join(output["image_caption"]) + "\n" + "".join(output["image_footnote"])
|
||||
case MinerUContentType.EQUATION:
|
||||
section = output["text"]
|
||||
case MinerUContentType.CODE:
|
||||
section = output["code_body"] + "\n".join(output.get("code_caption", []))
|
||||
case MinerUContentType.LIST:
|
||||
section = "\n".join(output.get("list_items", []))
|
||||
case MinerUContentType.DISCARDED:
|
||||
pass
|
||||
|
||||
if section:
|
||||
sections.append((section, self._line_tag(output)))
|
||||
return sections
|
||||
|
||||
def _transfer_to_tables(self, outputs: list[dict[str, Any]]):
|
||||
return []
|
||||
|
||||
def parse_pdf(
|
||||
self,
|
||||
filepath: str | PathLike[str],
|
||||
binary: BytesIO | bytes,
|
||||
callback: Optional[Callable] = None,
|
||||
*,
|
||||
output_dir: Optional[str] = None,
|
||||
backend: str = "pipeline",
|
||||
lang: Optional[str] = None,
|
||||
method: str = "auto",
|
||||
server_url: Optional[str] = None,
|
||||
delete_output: bool = True,
|
||||
) -> tuple:
|
||||
import shutil
|
||||
|
||||
temp_pdf = None
|
||||
created_tmp_dir = False
|
||||
|
||||
# remove spaces, or mineru crash, and _read_output fail too
|
||||
file_path = Path(filepath)
|
||||
pdf_file_name = file_path.stem.replace(" ", "") + ".pdf"
|
||||
pdf_file_path_valid = os.path.join(file_path.parent, pdf_file_name)
|
||||
|
||||
if binary:
|
||||
temp_dir = Path(tempfile.mkdtemp(prefix="mineru_bin_pdf_"))
|
||||
temp_pdf = temp_dir / pdf_file_name
|
||||
with open(temp_pdf, "wb") as f:
|
||||
f.write(binary)
|
||||
pdf = temp_pdf
|
||||
self.logger.info(f"[MinerU] Received binary PDF -> {temp_pdf}")
|
||||
if callback:
|
||||
callback(0.15, f"[MinerU] Received binary PDF -> {temp_pdf}")
|
||||
else:
|
||||
if pdf_file_path_valid != filepath:
|
||||
self.logger.info(f"[MinerU] Remove all space in file name: {pdf_file_path_valid}")
|
||||
shutil.move(filepath, pdf_file_path_valid)
|
||||
pdf = Path(pdf_file_path_valid)
|
||||
if not pdf.exists():
|
||||
if callback:
|
||||
callback(-1, f"[MinerU] PDF not found: {pdf}")
|
||||
raise FileNotFoundError(f"[MinerU] PDF not found: {pdf}")
|
||||
|
||||
if output_dir:
|
||||
out_dir = Path(output_dir)
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
else:
|
||||
out_dir = Path(tempfile.mkdtemp(prefix="mineru_pdf_"))
|
||||
created_tmp_dir = True
|
||||
|
||||
self.logger.info(f"[MinerU] Output directory: {out_dir}")
|
||||
if callback:
|
||||
callback(0.15, f"[MinerU] Output directory: {out_dir}")
|
||||
|
||||
self.__images__(pdf, zoomin=1)
|
||||
|
||||
try:
|
||||
self._run_mineru(pdf, out_dir, method=method, backend=backend, lang=lang, server_url=server_url, callback=callback)
|
||||
outputs = self._read_output(out_dir, pdf.stem, method=method, backend=backend)
|
||||
self.logger.info(f"[MinerU] Parsed {len(outputs)} blocks from PDF.")
|
||||
if callback:
|
||||
callback(0.75, f"[MinerU] Parsed {len(outputs)} blocks from PDF.")
|
||||
return self._transfer_to_sections(outputs), self._transfer_to_tables(outputs)
|
||||
finally:
|
||||
if temp_pdf and temp_pdf.exists():
|
||||
try:
|
||||
temp_pdf.unlink()
|
||||
temp_pdf.parent.rmdir()
|
||||
except Exception:
|
||||
pass
|
||||
if delete_output and created_tmp_dir and out_dir.exists():
|
||||
try:
|
||||
shutil.rmtree(out_dir)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = MinerUParser("mineru")
|
||||
ok, reason = parser.check_installation()
|
||||
print("MinerU available:", ok)
|
||||
|
||||
filepath = ""
|
||||
with open(filepath, "rb") as file:
|
||||
outputs = parser.parse_pdf(filepath=filepath, binary=file.read())
|
||||
for output in outputs:
|
||||
print(output)
|
||||
1387
api/app/core/rag/deepdoc/parser/pdf_parser.py
Normal file
1387
api/app/core/rag/deepdoc/parser/pdf_parser.py
Normal file
File diff suppressed because it is too large
Load Diff
83
api/app/core/rag/deepdoc/parser/ppt_parser.py
Normal file
83
api/app/core/rag/deepdoc/parser/ppt_parser.py
Normal file
@@ -0,0 +1,83 @@
|
||||
import logging
|
||||
from io import BytesIO
|
||||
from pptx import Presentation
|
||||
|
||||
|
||||
class RAGPptParser:
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def __get_bulleted_text(self, paragraph):
|
||||
is_bulleted = bool(paragraph._p.xpath("./a:pPr/a:buChar")) or bool(paragraph._p.xpath("./a:pPr/a:buAutoNum")) or bool(paragraph._p.xpath("./a:pPr/a:buBlip"))
|
||||
if is_bulleted:
|
||||
return f"{' '* paragraph.level}.{paragraph.text}"
|
||||
else:
|
||||
return paragraph.text
|
||||
|
||||
def __extract(self, shape):
|
||||
try:
|
||||
# First try to get text content
|
||||
if hasattr(shape, 'has_text_frame') and shape.has_text_frame:
|
||||
text_frame = shape.text_frame
|
||||
texts = []
|
||||
for paragraph in text_frame.paragraphs:
|
||||
if paragraph.text.strip():
|
||||
texts.append(self.__get_bulleted_text(paragraph))
|
||||
return "\n".join(texts)
|
||||
|
||||
# Safely get shape_type
|
||||
try:
|
||||
shape_type = shape.shape_type
|
||||
except NotImplementedError:
|
||||
# If shape_type is not available, try to get text content
|
||||
if hasattr(shape, 'text'):
|
||||
return shape.text.strip()
|
||||
return ""
|
||||
|
||||
# Handle table
|
||||
if shape_type == 19:
|
||||
tb = shape.table
|
||||
rows = []
|
||||
for i in range(1, len(tb.rows)):
|
||||
rows.append("; ".join([tb.cell(
|
||||
0, j).text + ": " + tb.cell(i, j).text for j in range(len(tb.columns)) if tb.cell(i, j)]))
|
||||
return "\n".join(rows)
|
||||
|
||||
# Handle group shape
|
||||
if shape_type == 6:
|
||||
texts = []
|
||||
for p in sorted(shape.shapes, key=lambda x: (x.top // 10, x.left)):
|
||||
t = self.__extract(p)
|
||||
if t:
|
||||
texts.append(t)
|
||||
return "\n".join(texts)
|
||||
|
||||
return ""
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Error processing shape: {str(e)}")
|
||||
return ""
|
||||
|
||||
def __call__(self, fnm, from_page, to_page, callback=None):
|
||||
ppt = Presentation(fnm) if isinstance(
|
||||
fnm, str) else Presentation(
|
||||
BytesIO(fnm))
|
||||
txts = []
|
||||
self.total_page = len(ppt.slides)
|
||||
for i, slide in enumerate(ppt.slides):
|
||||
if i < from_page:
|
||||
continue
|
||||
if i >= to_page:
|
||||
break
|
||||
texts = []
|
||||
for shape in sorted(
|
||||
slide.shapes, key=lambda x: ((x.top if x.top is not None else 0) // 10, x.left if x.left is not None else 0)):
|
||||
try:
|
||||
txt = self.__extract(shape)
|
||||
if txt:
|
||||
texts.append(txt)
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
txts.append("\n".join(texts))
|
||||
|
||||
return txts
|
||||
48
api/app/core/rag/deepdoc/parser/txt_parser.py
Normal file
48
api/app/core/rag/deepdoc/parser/txt_parser.py
Normal file
@@ -0,0 +1,48 @@
|
||||
import re
|
||||
|
||||
from .utils import get_text
|
||||
from app.core.rag.common.token_utils import num_tokens_from_string
|
||||
|
||||
|
||||
class RAGTxtParser:
|
||||
def __call__(self, fnm, binary=None, chunk_token_num=128, delimiter="\n!?;。;!?"):
|
||||
txt = get_text(fnm, binary)
|
||||
return self.parser_txt(txt, chunk_token_num, delimiter)
|
||||
|
||||
@classmethod
|
||||
def parser_txt(cls, txt, chunk_token_num=128, delimiter="\n!?;。;!?"):
|
||||
if not isinstance(txt, str):
|
||||
raise TypeError("txt type should be str!")
|
||||
cks = [""]
|
||||
tk_nums = [0]
|
||||
delimiter = delimiter.encode('utf-8').decode('unicode_escape').encode('latin1').decode('utf-8')
|
||||
|
||||
def add_chunk(t):
|
||||
nonlocal cks, tk_nums, delimiter
|
||||
tnum = num_tokens_from_string(t)
|
||||
if tk_nums[-1] > chunk_token_num:
|
||||
cks.append(t)
|
||||
tk_nums.append(tnum)
|
||||
else:
|
||||
cks[-1] += t
|
||||
tk_nums[-1] += tnum
|
||||
|
||||
dels = []
|
||||
s = 0
|
||||
for m in re.finditer(r"`([^`]+)`", delimiter, re.I):
|
||||
f, t = m.span()
|
||||
dels.append(m.group(1))
|
||||
dels.extend(list(delimiter[s: f]))
|
||||
s = t
|
||||
if s < len(delimiter):
|
||||
dels.extend(list(delimiter[s:]))
|
||||
dels = [re.escape(d) for d in dels if d]
|
||||
dels = [d for d in dels if d]
|
||||
dels = "|".join(dels)
|
||||
secs = re.split(r"(%s)" % dels, txt)
|
||||
for sec in secs:
|
||||
if re.match(f"^{dels}$", sec):
|
||||
continue
|
||||
add_chunk(sec)
|
||||
|
||||
return [[c, ""] for c in cks]
|
||||
16
api/app/core/rag/deepdoc/parser/utils.py
Normal file
16
api/app/core/rag/deepdoc/parser/utils.py
Normal file
@@ -0,0 +1,16 @@
|
||||
from app.core.rag.nlp import find_codec
|
||||
|
||||
|
||||
def get_text(fnm: str, binary=None) -> str:
|
||||
txt = ""
|
||||
if binary:
|
||||
encoding = find_codec(binary)
|
||||
txt = binary.decode(encoding, errors="ignore")
|
||||
else:
|
||||
with open(fnm, "r") as f:
|
||||
while True:
|
||||
line = f.readline()
|
||||
if not line:
|
||||
break
|
||||
txt += line
|
||||
return txt
|
||||
75
api/app/core/rag/deepdoc/vision/__init__.py
Normal file
75
api/app/core/rag/deepdoc/vision/__init__.py
Normal file
@@ -0,0 +1,75 @@
|
||||
import io
|
||||
import sys
|
||||
import threading
|
||||
|
||||
import pdfplumber
|
||||
|
||||
from .ocr import OCR
|
||||
from .recognizer import Recognizer
|
||||
from .layout_recognizer import AscendLayoutRecognizer
|
||||
from .layout_recognizer import LayoutRecognizer4YOLOv10 as LayoutRecognizer
|
||||
from .table_structure_recognizer import TableStructureRecognizer
|
||||
|
||||
LOCK_KEY_pdfplumber = "global_shared_lock_pdfplumber"
|
||||
if LOCK_KEY_pdfplumber not in sys.modules:
|
||||
sys.modules[LOCK_KEY_pdfplumber] = threading.Lock()
|
||||
|
||||
|
||||
def init_in_out(args):
|
||||
import os
|
||||
import traceback
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from app.core.rag.common.file_utils import traversal_files
|
||||
|
||||
images = []
|
||||
outputs = []
|
||||
|
||||
if not os.path.exists(args.output_dir):
|
||||
os.mkdir(args.output_dir)
|
||||
|
||||
def pdf_pages(fnm, zoomin=3):
|
||||
nonlocal outputs, images
|
||||
with sys.modules[LOCK_KEY_pdfplumber]:
|
||||
pdf = pdfplumber.open(fnm)
|
||||
images = [p.to_image(resolution=72 * zoomin).annotated for i, p in enumerate(pdf.pages)]
|
||||
|
||||
for i, page in enumerate(images):
|
||||
outputs.append(os.path.split(fnm)[-1] + f"_{i}.jpg")
|
||||
pdf.close()
|
||||
|
||||
def images_and_outputs(fnm):
|
||||
nonlocal outputs, images
|
||||
if fnm.split(".")[-1].lower() == "pdf":
|
||||
pdf_pages(fnm)
|
||||
return
|
||||
try:
|
||||
fp = open(fnm, "rb")
|
||||
binary = fp.read()
|
||||
fp.close()
|
||||
images.append(Image.open(io.BytesIO(binary)).convert("RGB"))
|
||||
outputs.append(os.path.split(fnm)[-1])
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
|
||||
if os.path.isdir(args.inputs):
|
||||
for fnm in traversal_files(args.inputs):
|
||||
images_and_outputs(fnm)
|
||||
else:
|
||||
images_and_outputs(args.inputs)
|
||||
|
||||
for i in range(len(outputs)):
|
||||
outputs[i] = os.path.join(args.output_dir, outputs[i])
|
||||
|
||||
return images, outputs
|
||||
|
||||
|
||||
__all__ = [
|
||||
"OCR",
|
||||
"Recognizer",
|
||||
"LayoutRecognizer",
|
||||
"AscendLayoutRecognizer",
|
||||
"TableStructureRecognizer",
|
||||
"init_in_out",
|
||||
]
|
||||
440
api/app/core/rag/deepdoc/vision/layout_recognizer.py
Normal file
440
api/app/core/rag/deepdoc/vision/layout_recognizer.py
Normal file
@@ -0,0 +1,440 @@
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import re
|
||||
from collections import Counter
|
||||
from copy import deepcopy
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
from app.core.rag.common.file_utils import get_project_base_directory
|
||||
from . import Recognizer
|
||||
from .operators import nms
|
||||
|
||||
|
||||
class LayoutRecognizer(Recognizer):
|
||||
labels = [
|
||||
"_background_",
|
||||
"Text",
|
||||
"Title",
|
||||
"Figure",
|
||||
"Figure caption",
|
||||
"Table",
|
||||
"Table caption",
|
||||
"Header",
|
||||
"Footer",
|
||||
"Reference",
|
||||
"Equation",
|
||||
]
|
||||
|
||||
def __init__(self, domain):
|
||||
try:
|
||||
model_dir = os.path.join(get_project_base_directory(), "res/deepdoc")
|
||||
super().__init__(self.labels, domain, model_dir)
|
||||
except Exception:
|
||||
model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc", local_dir=os.path.join(get_project_base_directory(), "res/deepdoc"), local_dir_use_symlinks=False)
|
||||
super().__init__(self.labels, domain, model_dir)
|
||||
|
||||
self.garbage_layouts = ["footer", "header", "reference"]
|
||||
self.client = None
|
||||
if os.environ.get("TENSORRT_DLA_SVR"):
|
||||
from deepdoc.vision.dla_cli import DLAClient
|
||||
|
||||
self.client = DLAClient(os.environ["TENSORRT_DLA_SVR"])
|
||||
|
||||
def __call__(self, image_list, ocr_res, scale_factor=3, thr=0.2, batch_size=16, drop=True):
|
||||
def __is_garbage(b):
|
||||
patt = [r"^•+$", "^[0-9]{1,2} / ?[0-9]{1,2}$", r"^[0-9]{1,2} of [0-9]{1,2}$", "^http://[^ ]{12,}", "\\(cid *: *[0-9]+ *\\)"]
|
||||
return any([re.search(p, b["text"]) for p in patt])
|
||||
|
||||
if self.client:
|
||||
layouts = self.client.predict(image_list)
|
||||
else:
|
||||
layouts = super().__call__(image_list, thr, batch_size)
|
||||
# save_results(image_list, layouts, self.labels, output_dir='output/', threshold=0.7)
|
||||
assert len(image_list) == len(ocr_res)
|
||||
# Tag layout type
|
||||
boxes = []
|
||||
assert len(image_list) == len(layouts)
|
||||
garbages = {}
|
||||
page_layout = []
|
||||
for pn, lts in enumerate(layouts):
|
||||
bxs = ocr_res[pn]
|
||||
lts = [
|
||||
{
|
||||
"type": b["type"],
|
||||
"score": float(b["score"]),
|
||||
"x0": b["bbox"][0] / scale_factor,
|
||||
"x1": b["bbox"][2] / scale_factor,
|
||||
"top": b["bbox"][1] / scale_factor,
|
||||
"bottom": b["bbox"][-1] / scale_factor,
|
||||
"page_number": pn,
|
||||
}
|
||||
for b in lts
|
||||
if float(b["score"]) >= 0.4 or b["type"] not in self.garbage_layouts
|
||||
]
|
||||
lts = self.sort_Y_firstly(lts, np.mean([lt["bottom"] - lt["top"] for lt in lts]) / 2)
|
||||
lts = self.layouts_cleanup(bxs, lts)
|
||||
page_layout.append(lts)
|
||||
|
||||
def findLayout(ty):
|
||||
nonlocal bxs, lts, self
|
||||
lts_ = [lt for lt in lts if lt["type"] == ty]
|
||||
i = 0
|
||||
while i < len(bxs):
|
||||
if bxs[i].get("layout_type"):
|
||||
i += 1
|
||||
continue
|
||||
if __is_garbage(bxs[i]):
|
||||
bxs.pop(i)
|
||||
continue
|
||||
|
||||
ii = self.find_overlapped_with_threshold(bxs[i], lts_, thr=0.4)
|
||||
if ii is None:
|
||||
bxs[i]["layout_type"] = ""
|
||||
i += 1
|
||||
continue
|
||||
lts_[ii]["visited"] = True
|
||||
keep_feats = [
|
||||
lts_[ii]["type"] == "footer" and bxs[i]["bottom"] < image_list[pn].size[1] * 0.9 / scale_factor,
|
||||
lts_[ii]["type"] == "header" and bxs[i]["top"] > image_list[pn].size[1] * 0.1 / scale_factor,
|
||||
]
|
||||
if drop and lts_[ii]["type"] in self.garbage_layouts and not any(keep_feats):
|
||||
if lts_[ii]["type"] not in garbages:
|
||||
garbages[lts_[ii]["type"]] = []
|
||||
garbages[lts_[ii]["type"]].append(bxs[i]["text"])
|
||||
bxs.pop(i)
|
||||
continue
|
||||
|
||||
bxs[i]["layoutno"] = f"{ty}-{ii}"
|
||||
bxs[i]["layout_type"] = lts_[ii]["type"] if lts_[ii]["type"] != "equation" else "figure"
|
||||
i += 1
|
||||
|
||||
for lt in ["footer", "header", "reference", "figure caption", "table caption", "title", "table", "text", "figure", "equation"]:
|
||||
findLayout(lt)
|
||||
|
||||
# add box to figure layouts which has not text box
|
||||
for i, lt in enumerate([lt for lt in lts if lt["type"] in ["figure", "equation"]]):
|
||||
if lt.get("visited"):
|
||||
continue
|
||||
lt = deepcopy(lt)
|
||||
del lt["type"]
|
||||
lt["text"] = ""
|
||||
lt["layout_type"] = "figure"
|
||||
lt["layoutno"] = f"figure-{i}"
|
||||
bxs.append(lt)
|
||||
|
||||
boxes.extend(bxs)
|
||||
|
||||
ocr_res = boxes
|
||||
|
||||
garbag_set = set()
|
||||
for k in garbages.keys():
|
||||
garbages[k] = Counter(garbages[k])
|
||||
for g, c in garbages[k].items():
|
||||
if c > 1:
|
||||
garbag_set.add(g)
|
||||
|
||||
ocr_res = [b for b in ocr_res if b["text"].strip() not in garbag_set]
|
||||
return ocr_res, page_layout
|
||||
|
||||
def forward(self, image_list, thr=0.7, batch_size=16):
|
||||
return super().__call__(image_list, thr, batch_size)
|
||||
|
||||
|
||||
class LayoutRecognizer4YOLOv10(LayoutRecognizer):
|
||||
labels = [
|
||||
"title",
|
||||
"Text",
|
||||
"Reference",
|
||||
"Figure",
|
||||
"Figure caption",
|
||||
"Table",
|
||||
"Table caption",
|
||||
"Table caption",
|
||||
"Equation",
|
||||
"Figure caption",
|
||||
]
|
||||
|
||||
def __init__(self, domain):
|
||||
domain = "layout"
|
||||
super().__init__(domain)
|
||||
self.auto = False
|
||||
self.scaleFill = False
|
||||
self.scaleup = True
|
||||
self.stride = 32
|
||||
self.center = True
|
||||
|
||||
def preprocess(self, image_list):
|
||||
inputs = []
|
||||
new_shape = self.input_shape # height, width
|
||||
for img in image_list:
|
||||
shape = img.shape[:2] # current shape [height, width]
|
||||
# Scale ratio (new / old)
|
||||
r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
|
||||
# Compute padding
|
||||
new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
|
||||
dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] # wh padding
|
||||
dw /= 2 # divide padding into 2 sides
|
||||
dh /= 2
|
||||
ww, hh = new_unpad
|
||||
img = np.array(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)).astype(np.float32)
|
||||
img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
|
||||
top, bottom = int(round(dh - 0.1)) if self.center else 0, int(round(dh + 0.1))
|
||||
left, right = int(round(dw - 0.1)) if self.center else 0, int(round(dw + 0.1))
|
||||
img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(114, 114, 114)) # add border
|
||||
img /= 255.0
|
||||
img = img.transpose(2, 0, 1)
|
||||
img = img[np.newaxis, :, :, :].astype(np.float32)
|
||||
inputs.append({self.input_names[0]: img, "scale_factor": [shape[1] / ww, shape[0] / hh, dw, dh]})
|
||||
|
||||
return inputs
|
||||
|
||||
def postprocess(self, boxes, inputs, thr):
|
||||
thr = 0.08
|
||||
boxes = np.squeeze(boxes)
|
||||
scores = boxes[:, 4]
|
||||
boxes = boxes[scores > thr, :]
|
||||
scores = scores[scores > thr]
|
||||
if len(boxes) == 0:
|
||||
return []
|
||||
class_ids = boxes[:, -1].astype(int)
|
||||
boxes = boxes[:, :4]
|
||||
boxes[:, 0] -= inputs["scale_factor"][2]
|
||||
boxes[:, 2] -= inputs["scale_factor"][2]
|
||||
boxes[:, 1] -= inputs["scale_factor"][3]
|
||||
boxes[:, 3] -= inputs["scale_factor"][3]
|
||||
input_shape = np.array([inputs["scale_factor"][0], inputs["scale_factor"][1], inputs["scale_factor"][0], inputs["scale_factor"][1]])
|
||||
boxes = np.multiply(boxes, input_shape, dtype=np.float32)
|
||||
|
||||
unique_class_ids = np.unique(class_ids)
|
||||
indices = []
|
||||
for class_id in unique_class_ids:
|
||||
class_indices = np.where(class_ids == class_id)[0]
|
||||
class_boxes = boxes[class_indices, :]
|
||||
class_scores = scores[class_indices]
|
||||
class_keep_boxes = nms(class_boxes, class_scores, 0.45)
|
||||
indices.extend(class_indices[class_keep_boxes])
|
||||
|
||||
return [{"type": self.label_list[class_ids[i]].lower(), "bbox": [float(t) for t in boxes[i].tolist()], "score": float(scores[i])} for i in indices]
|
||||
|
||||
|
||||
class AscendLayoutRecognizer(Recognizer):
|
||||
labels = [
|
||||
"title",
|
||||
"Text",
|
||||
"Reference",
|
||||
"Figure",
|
||||
"Figure caption",
|
||||
"Table",
|
||||
"Table caption",
|
||||
"Table caption",
|
||||
"Equation",
|
||||
"Figure caption",
|
||||
]
|
||||
|
||||
def __init__(self, domain):
|
||||
from ais_bench.infer.interface import InferSession
|
||||
|
||||
model_dir = os.path.join(get_project_base_directory(), "res/deepdoc")
|
||||
model_file_path = os.path.join(model_dir, domain + ".om")
|
||||
|
||||
if not os.path.exists(model_file_path):
|
||||
raise ValueError(f"Model file not found: {model_file_path}")
|
||||
|
||||
device_id = int(os.getenv("ASCEND_LAYOUT_RECOGNIZER_DEVICE_ID", 0))
|
||||
self.session = InferSession(device_id=device_id, model_path=model_file_path)
|
||||
self.input_shape = self.session.get_inputs()[0].shape[2:4] # H,W
|
||||
self.garbage_layouts = ["footer", "header", "reference"]
|
||||
|
||||
def preprocess(self, image_list):
|
||||
inputs = []
|
||||
H, W = self.input_shape
|
||||
for img in image_list:
|
||||
h, w = img.shape[:2]
|
||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB).astype(np.float32)
|
||||
|
||||
r = min(H / h, W / w)
|
||||
new_unpad = (int(round(w * r)), int(round(h * r)))
|
||||
dw, dh = (W - new_unpad[0]) / 2.0, (H - new_unpad[1]) / 2.0
|
||||
|
||||
img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
|
||||
top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
|
||||
left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
|
||||
img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(114, 114, 114))
|
||||
|
||||
img /= 255.0
|
||||
img = img.transpose(2, 0, 1)[np.newaxis, :, :, :].astype(np.float32)
|
||||
|
||||
inputs.append(
|
||||
{
|
||||
"image": img,
|
||||
"scale_factor": [w / new_unpad[0], h / new_unpad[1]],
|
||||
"pad": [dw, dh],
|
||||
"orig_shape": [h, w],
|
||||
}
|
||||
)
|
||||
return inputs
|
||||
|
||||
def postprocess(self, boxes, inputs, thr=0.25):
|
||||
arr = np.squeeze(boxes)
|
||||
if arr.ndim == 1:
|
||||
arr = arr.reshape(1, -1)
|
||||
|
||||
results = []
|
||||
if arr.shape[1] == 6:
|
||||
# [x1,y1,x2,y2,score,cls]
|
||||
m = arr[:, 4] >= thr
|
||||
arr = arr[m]
|
||||
if arr.size == 0:
|
||||
return []
|
||||
xyxy = arr[:, :4].astype(np.float32)
|
||||
scores = arr[:, 4].astype(np.float32)
|
||||
cls_ids = arr[:, 5].astype(np.int32)
|
||||
|
||||
if "pad" in inputs:
|
||||
dw, dh = inputs["pad"]
|
||||
sx, sy = inputs["scale_factor"]
|
||||
xyxy[:, [0, 2]] -= dw
|
||||
xyxy[:, [1, 3]] -= dh
|
||||
xyxy *= np.array([sx, sy, sx, sy], dtype=np.float32)
|
||||
else:
|
||||
# backup
|
||||
sx, sy = inputs["scale_factor"]
|
||||
xyxy *= np.array([sx, sy, sx, sy], dtype=np.float32)
|
||||
|
||||
keep_indices = []
|
||||
for c in np.unique(cls_ids):
|
||||
idx = np.where(cls_ids == c)[0]
|
||||
k = nms(xyxy[idx], scores[idx], 0.45)
|
||||
keep_indices.extend(idx[k])
|
||||
|
||||
for i in keep_indices:
|
||||
cid = int(cls_ids[i])
|
||||
if 0 <= cid < len(self.labels):
|
||||
results.append({"type": self.labels[cid].lower(), "bbox": [float(t) for t in xyxy[i].tolist()], "score": float(scores[i])})
|
||||
return results
|
||||
|
||||
raise ValueError(f"Unexpected output shape: {arr.shape}")
|
||||
|
||||
def __call__(self, image_list, ocr_res, scale_factor=3, thr=0.2, batch_size=16, drop=True):
|
||||
import re
|
||||
from collections import Counter
|
||||
|
||||
assert len(image_list) == len(ocr_res)
|
||||
|
||||
images = [np.array(im) if not isinstance(im, np.ndarray) else im for im in image_list]
|
||||
layouts_all_pages = [] # list of list[{"type","score","bbox":[x1,y1,x2,y2]}]
|
||||
|
||||
conf_thr = max(thr, 0.08)
|
||||
|
||||
batch_loop_cnt = math.ceil(float(len(images)) / batch_size)
|
||||
for bi in range(batch_loop_cnt):
|
||||
s = bi * batch_size
|
||||
e = min((bi + 1) * batch_size, len(images))
|
||||
batch_images = images[s:e]
|
||||
|
||||
inputs_list = self.preprocess(batch_images)
|
||||
logging.debug("preprocess done")
|
||||
|
||||
for ins in inputs_list:
|
||||
feeds = [ins["image"]]
|
||||
out_list = self.session.infer(feeds=feeds, mode="static")
|
||||
|
||||
for out in out_list:
|
||||
lts = self.postprocess(out, ins, conf_thr)
|
||||
|
||||
page_lts = []
|
||||
for b in lts:
|
||||
if float(b["score"]) >= 0.4 or b["type"] not in self.garbage_layouts:
|
||||
x0, y0, x1, y1 = b["bbox"]
|
||||
page_lts.append(
|
||||
{
|
||||
"type": b["type"],
|
||||
"score": float(b["score"]),
|
||||
"x0": float(x0) / scale_factor,
|
||||
"x1": float(x1) / scale_factor,
|
||||
"top": float(y0) / scale_factor,
|
||||
"bottom": float(y1) / scale_factor,
|
||||
"page_number": len(layouts_all_pages),
|
||||
}
|
||||
)
|
||||
layouts_all_pages.append(page_lts)
|
||||
|
||||
def _is_garbage_text(box):
|
||||
patt = [r"^•+$", r"^[0-9]{1,2} / ?[0-9]{1,2}$", r"^[0-9]{1,2} of [0-9]{1,2}$", r"^http://[^ ]{12,}", r"\(cid *: *[0-9]+ *\)"]
|
||||
return any(re.search(p, box.get("text", "")) for p in patt)
|
||||
|
||||
boxes_out = []
|
||||
page_layout = []
|
||||
garbages = {}
|
||||
|
||||
for pn, lts in enumerate(layouts_all_pages):
|
||||
if lts:
|
||||
avg_h = np.mean([lt["bottom"] - lt["top"] for lt in lts])
|
||||
lts = self.sort_Y_firstly(lts, avg_h / 2 if avg_h > 0 else 0)
|
||||
|
||||
bxs = ocr_res[pn]
|
||||
lts = self.layouts_cleanup(bxs, lts)
|
||||
page_layout.append(lts)
|
||||
|
||||
def _tag_layout(ty):
|
||||
nonlocal bxs, lts
|
||||
lts_of_ty = [lt for lt in lts if lt["type"] == ty]
|
||||
i = 0
|
||||
while i < len(bxs):
|
||||
if bxs[i].get("layout_type"):
|
||||
i += 1
|
||||
continue
|
||||
if _is_garbage_text(bxs[i]):
|
||||
bxs.pop(i)
|
||||
continue
|
||||
|
||||
ii = self.find_overlapped_with_threshold(bxs[i], lts_of_ty, thr=0.4)
|
||||
if ii is None:
|
||||
bxs[i]["layout_type"] = ""
|
||||
i += 1
|
||||
continue
|
||||
|
||||
lts_of_ty[ii]["visited"] = True
|
||||
|
||||
keep_feats = [
|
||||
lts_of_ty[ii]["type"] == "footer" and bxs[i]["bottom"] < image_list[pn].shape[0] * 0.9 / scale_factor,
|
||||
lts_of_ty[ii]["type"] == "header" and bxs[i]["top"] > image_list[pn].shape[0] * 0.1 / scale_factor,
|
||||
]
|
||||
if drop and lts_of_ty[ii]["type"] in self.garbage_layouts and not any(keep_feats):
|
||||
garbages.setdefault(lts_of_ty[ii]["type"], []).append(bxs[i].get("text", ""))
|
||||
bxs.pop(i)
|
||||
continue
|
||||
|
||||
bxs[i]["layoutno"] = f"{ty}-{ii}"
|
||||
bxs[i]["layout_type"] = lts_of_ty[ii]["type"] if lts_of_ty[ii]["type"] != "equation" else "figure"
|
||||
i += 1
|
||||
|
||||
for ty in ["footer", "header", "reference", "figure caption", "table caption", "title", "table", "text", "figure", "equation"]:
|
||||
_tag_layout(ty)
|
||||
|
||||
figs = [lt for lt in lts if lt["type"] in ["figure", "equation"]]
|
||||
for i, lt in enumerate(figs):
|
||||
if lt.get("visited"):
|
||||
continue
|
||||
lt = deepcopy(lt)
|
||||
lt.pop("type", None)
|
||||
lt["text"] = ""
|
||||
lt["layout_type"] = "figure"
|
||||
lt["layoutno"] = f"figure-{i}"
|
||||
bxs.append(lt)
|
||||
|
||||
boxes_out.extend(bxs)
|
||||
|
||||
garbag_set = set()
|
||||
for k, lst in garbages.items():
|
||||
cnt = Counter(lst)
|
||||
for g, c in cnt.items():
|
||||
if c > 1:
|
||||
garbag_set.add(g)
|
||||
|
||||
ocr_res_new = [b for b in boxes_out if b["text"].strip() not in garbag_set]
|
||||
return ocr_res_new, page_layout
|
||||
737
api/app/core/rag/deepdoc/vision/ocr.py
Normal file
737
api/app/core/rag/deepdoc/vision/ocr.py
Normal file
@@ -0,0 +1,737 @@
|
||||
import gc
|
||||
import logging
|
||||
import copy
|
||||
import time
|
||||
import os
|
||||
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
from app.core.rag.common.file_utils import get_project_base_directory
|
||||
from app.core.rag.common.misc_utils import pip_install_torch
|
||||
from app.core.rag.common import settings
|
||||
from .operators import * # noqa: F403
|
||||
from . import operators
|
||||
import math
|
||||
import numpy as np
|
||||
import cv2
|
||||
import onnxruntime as ort
|
||||
|
||||
from .postprocess import build_post_process
|
||||
|
||||
loaded_models = {}
|
||||
|
||||
def transform(data, ops=None):
|
||||
""" transform """
|
||||
if ops is None:
|
||||
ops = []
|
||||
for op in ops:
|
||||
data = op(data)
|
||||
if data is None:
|
||||
return None
|
||||
return data
|
||||
|
||||
|
||||
def create_operators(op_param_list, global_config=None):
|
||||
"""
|
||||
create operators based on the config
|
||||
|
||||
Args:
|
||||
params(list): a dict list, used to create some operators
|
||||
"""
|
||||
assert isinstance(
|
||||
op_param_list, list), ('operator config should be a list')
|
||||
ops = []
|
||||
for operator in op_param_list:
|
||||
assert isinstance(operator,
|
||||
dict) and len(operator) == 1, "yaml format error"
|
||||
op_name = list(operator)[0]
|
||||
param = {} if operator[op_name] is None else operator[op_name]
|
||||
if global_config is not None:
|
||||
param.update(global_config)
|
||||
op = getattr(operators, op_name)(**param)
|
||||
ops.append(op)
|
||||
return ops
|
||||
|
||||
|
||||
def load_model(model_dir, nm, device_id: int | None = None):
|
||||
model_file_path = os.path.join(model_dir, nm + ".onnx")
|
||||
model_cached_tag = model_file_path + str(device_id) if device_id is not None else model_file_path
|
||||
|
||||
global loaded_models
|
||||
loaded_model = loaded_models.get(model_cached_tag)
|
||||
if loaded_model:
|
||||
logging.info(f"load_model {model_file_path} reuses cached model")
|
||||
return loaded_model
|
||||
|
||||
if not os.path.exists(model_file_path):
|
||||
raise ValueError("not find model file path {}".format(
|
||||
model_file_path))
|
||||
|
||||
def cuda_is_available():
|
||||
try:
|
||||
pip_install_torch()
|
||||
import torch
|
||||
target_id = 0 if device_id is None else device_id
|
||||
if torch.cuda.is_available() and torch.cuda.device_count() > target_id:
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
return False
|
||||
|
||||
options = ort.SessionOptions()
|
||||
options.enable_cpu_mem_arena = False
|
||||
options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
|
||||
options.intra_op_num_threads = 2
|
||||
options.inter_op_num_threads = 2
|
||||
|
||||
# https://github.com/microsoft/onnxruntime/issues/9509#issuecomment-951546580
|
||||
# Shrink GPU memory after execution
|
||||
run_options = ort.RunOptions()
|
||||
if cuda_is_available():
|
||||
gpu_mem_limit_mb = int(os.environ.get("OCR_GPU_MEM_LIMIT_MB", "2048"))
|
||||
arena_strategy = os.environ.get("OCR_ARENA_EXTEND_STRATEGY", "kNextPowerOfTwo")
|
||||
provider_device_id = 0 if device_id is None else device_id
|
||||
cuda_provider_options = {
|
||||
"device_id": provider_device_id, # Use specific GPU
|
||||
"gpu_mem_limit": max(gpu_mem_limit_mb, 0) * 1024 * 1024,
|
||||
"arena_extend_strategy": arena_strategy, # gpu memory allocation strategy
|
||||
}
|
||||
sess = ort.InferenceSession(
|
||||
model_file_path,
|
||||
options=options,
|
||||
providers=['CUDAExecutionProvider'],
|
||||
provider_options=[cuda_provider_options]
|
||||
)
|
||||
run_options.add_run_config_entry("memory.enable_memory_arena_shrinkage", "gpu:" + str(provider_device_id))
|
||||
logging.info(f"load_model {model_file_path} uses GPU (device {provider_device_id}, gpu_mem_limit={cuda_provider_options['gpu_mem_limit']}, arena_strategy={arena_strategy})")
|
||||
else:
|
||||
sess = ort.InferenceSession(
|
||||
model_file_path,
|
||||
options=options,
|
||||
providers=['CPUExecutionProvider'])
|
||||
run_options.add_run_config_entry("memory.enable_memory_arena_shrinkage", "cpu")
|
||||
logging.info(f"load_model {model_file_path} uses CPU")
|
||||
loaded_model = (sess, run_options)
|
||||
loaded_models[model_cached_tag] = loaded_model
|
||||
return loaded_model
|
||||
|
||||
|
||||
class TextRecognizer:
|
||||
def __init__(self, model_dir, device_id: int | None = None):
|
||||
self.rec_image_shape = [int(v) for v in "3, 48, 320".split(",")]
|
||||
self.rec_batch_num = 16
|
||||
postprocess_params = {
|
||||
'name': 'CTCLabelDecode',
|
||||
"character_dict_path": os.path.join(model_dir, "ocr.res"),
|
||||
"use_space_char": True
|
||||
}
|
||||
self.postprocess_op = build_post_process(postprocess_params)
|
||||
self.predictor, self.run_options = load_model(model_dir, 'rec', device_id)
|
||||
self.input_tensor = self.predictor.get_inputs()[0]
|
||||
|
||||
def resize_norm_img(self, img, max_wh_ratio):
|
||||
imgC, imgH, imgW = self.rec_image_shape
|
||||
|
||||
assert imgC == img.shape[2]
|
||||
imgW = int((imgH * max_wh_ratio))
|
||||
w = self.input_tensor.shape[3:][0]
|
||||
if isinstance(w, str):
|
||||
pass
|
||||
elif w is not None and w > 0:
|
||||
imgW = w
|
||||
h, w = img.shape[:2]
|
||||
ratio = w / float(h)
|
||||
if math.ceil(imgH * ratio) > imgW:
|
||||
resized_w = imgW
|
||||
else:
|
||||
resized_w = int(math.ceil(imgH * ratio))
|
||||
|
||||
resized_image = cv2.resize(img, (resized_w, imgH))
|
||||
resized_image = resized_image.astype('float32')
|
||||
resized_image = resized_image.transpose((2, 0, 1)) / 255
|
||||
resized_image -= 0.5
|
||||
resized_image /= 0.5
|
||||
padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
|
||||
padding_im[:, :, 0:resized_w] = resized_image
|
||||
return padding_im
|
||||
|
||||
def resize_norm_img_vl(self, img, image_shape):
|
||||
|
||||
imgC, imgH, imgW = image_shape
|
||||
img = img[:, :, ::-1] # bgr2rgb
|
||||
resized_image = cv2.resize(
|
||||
img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
|
||||
resized_image = resized_image.astype('float32')
|
||||
resized_image = resized_image.transpose((2, 0, 1)) / 255
|
||||
return resized_image
|
||||
|
||||
def resize_norm_img_srn(self, img, image_shape):
|
||||
imgC, imgH, imgW = image_shape
|
||||
|
||||
img_black = np.zeros((imgH, imgW))
|
||||
im_hei = img.shape[0]
|
||||
im_wid = img.shape[1]
|
||||
|
||||
if im_wid <= im_hei * 1:
|
||||
img_new = cv2.resize(img, (imgH * 1, imgH))
|
||||
elif im_wid <= im_hei * 2:
|
||||
img_new = cv2.resize(img, (imgH * 2, imgH))
|
||||
elif im_wid <= im_hei * 3:
|
||||
img_new = cv2.resize(img, (imgH * 3, imgH))
|
||||
else:
|
||||
img_new = cv2.resize(img, (imgW, imgH))
|
||||
|
||||
img_np = np.asarray(img_new)
|
||||
img_np = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY)
|
||||
img_black[:, 0:img_np.shape[1]] = img_np
|
||||
img_black = img_black[:, :, np.newaxis]
|
||||
|
||||
row, col, c = img_black.shape
|
||||
c = 1
|
||||
|
||||
return np.reshape(img_black, (c, row, col)).astype(np.float32)
|
||||
|
||||
def srn_other_inputs(self, image_shape, num_heads, max_text_length):
|
||||
|
||||
imgC, imgH, imgW = image_shape
|
||||
feature_dim = int((imgH / 8) * (imgW / 8))
|
||||
|
||||
encoder_word_pos = np.array(range(0, feature_dim)).reshape(
|
||||
(feature_dim, 1)).astype('int64')
|
||||
gsrm_word_pos = np.array(range(0, max_text_length)).reshape(
|
||||
(max_text_length, 1)).astype('int64')
|
||||
|
||||
gsrm_attn_bias_data = np.ones((1, max_text_length, max_text_length))
|
||||
gsrm_slf_attn_bias1 = np.triu(gsrm_attn_bias_data, 1).reshape(
|
||||
[-1, 1, max_text_length, max_text_length])
|
||||
gsrm_slf_attn_bias1 = np.tile(
|
||||
gsrm_slf_attn_bias1,
|
||||
[1, num_heads, 1, 1]).astype('float32') * [-1e9]
|
||||
|
||||
gsrm_slf_attn_bias2 = np.tril(gsrm_attn_bias_data, -1).reshape(
|
||||
[-1, 1, max_text_length, max_text_length])
|
||||
gsrm_slf_attn_bias2 = np.tile(
|
||||
gsrm_slf_attn_bias2,
|
||||
[1, num_heads, 1, 1]).astype('float32') * [-1e9]
|
||||
|
||||
encoder_word_pos = encoder_word_pos[np.newaxis, :]
|
||||
gsrm_word_pos = gsrm_word_pos[np.newaxis, :]
|
||||
|
||||
return [
|
||||
encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
|
||||
gsrm_slf_attn_bias2
|
||||
]
|
||||
|
||||
def process_image_srn(self, img, image_shape, num_heads, max_text_length):
|
||||
norm_img = self.resize_norm_img_srn(img, image_shape)
|
||||
norm_img = norm_img[np.newaxis, :]
|
||||
|
||||
[encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2] = \
|
||||
self.srn_other_inputs(image_shape, num_heads, max_text_length)
|
||||
|
||||
gsrm_slf_attn_bias1 = gsrm_slf_attn_bias1.astype(np.float32)
|
||||
gsrm_slf_attn_bias2 = gsrm_slf_attn_bias2.astype(np.float32)
|
||||
encoder_word_pos = encoder_word_pos.astype(np.int64)
|
||||
gsrm_word_pos = gsrm_word_pos.astype(np.int64)
|
||||
|
||||
return (norm_img, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
|
||||
gsrm_slf_attn_bias2)
|
||||
|
||||
def resize_norm_img_sar(self, img, image_shape,
|
||||
width_downsample_ratio=0.25):
|
||||
imgC, imgH, imgW_min, imgW_max = image_shape
|
||||
h = img.shape[0]
|
||||
w = img.shape[1]
|
||||
valid_ratio = 1.0
|
||||
# make sure new_width is an integral multiple of width_divisor.
|
||||
width_divisor = int(1 / width_downsample_ratio)
|
||||
# resize
|
||||
ratio = w / float(h)
|
||||
resize_w = math.ceil(imgH * ratio)
|
||||
if resize_w % width_divisor != 0:
|
||||
resize_w = round(resize_w / width_divisor) * width_divisor
|
||||
if imgW_min is not None:
|
||||
resize_w = max(imgW_min, resize_w)
|
||||
if imgW_max is not None:
|
||||
valid_ratio = min(1.0, 1.0 * resize_w / imgW_max)
|
||||
resize_w = min(imgW_max, resize_w)
|
||||
resized_image = cv2.resize(img, (resize_w, imgH))
|
||||
resized_image = resized_image.astype('float32')
|
||||
# norm
|
||||
if image_shape[0] == 1:
|
||||
resized_image = resized_image / 255
|
||||
resized_image = resized_image[np.newaxis, :]
|
||||
else:
|
||||
resized_image = resized_image.transpose((2, 0, 1)) / 255
|
||||
resized_image -= 0.5
|
||||
resized_image /= 0.5
|
||||
resize_shape = resized_image.shape
|
||||
padding_im = -1.0 * np.ones((imgC, imgH, imgW_max), dtype=np.float32)
|
||||
padding_im[:, :, 0:resize_w] = resized_image
|
||||
pad_shape = padding_im.shape
|
||||
|
||||
return padding_im, resize_shape, pad_shape, valid_ratio
|
||||
|
||||
def resize_norm_img_spin(self, img):
|
||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
|
||||
# return padding_im
|
||||
img = cv2.resize(img, tuple([100, 32]), cv2.INTER_CUBIC)
|
||||
img = np.array(img, np.float32)
|
||||
img = np.expand_dims(img, -1)
|
||||
img = img.transpose((2, 0, 1))
|
||||
mean = [127.5]
|
||||
std = [127.5]
|
||||
mean = np.array(mean, dtype=np.float32)
|
||||
std = np.array(std, dtype=np.float32)
|
||||
mean = np.float32(mean.reshape(1, -1))
|
||||
stdinv = 1 / np.float32(std.reshape(1, -1))
|
||||
img -= mean
|
||||
img *= stdinv
|
||||
return img
|
||||
|
||||
def resize_norm_img_svtr(self, img, image_shape):
|
||||
|
||||
imgC, imgH, imgW = image_shape
|
||||
resized_image = cv2.resize(
|
||||
img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
|
||||
resized_image = resized_image.astype('float32')
|
||||
resized_image = resized_image.transpose((2, 0, 1)) / 255
|
||||
resized_image -= 0.5
|
||||
resized_image /= 0.5
|
||||
return resized_image
|
||||
|
||||
def resize_norm_img_abinet(self, img, image_shape):
|
||||
|
||||
imgC, imgH, imgW = image_shape
|
||||
|
||||
resized_image = cv2.resize(
|
||||
img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
|
||||
resized_image = resized_image.astype('float32')
|
||||
resized_image = resized_image / 255.
|
||||
|
||||
mean = np.array([0.485, 0.456, 0.406])
|
||||
std = np.array([0.229, 0.224, 0.225])
|
||||
resized_image = (
|
||||
resized_image - mean[None, None, ...]) / std[None, None, ...]
|
||||
resized_image = resized_image.transpose((2, 0, 1))
|
||||
resized_image = resized_image.astype('float32')
|
||||
|
||||
return resized_image
|
||||
|
||||
def norm_img_can(self, img, image_shape):
|
||||
|
||||
img = cv2.cvtColor(
|
||||
img, cv2.COLOR_BGR2GRAY) # CAN only predict gray scale image
|
||||
|
||||
if self.rec_image_shape[0] == 1:
|
||||
h, w = img.shape
|
||||
_, imgH, imgW = self.rec_image_shape
|
||||
if h < imgH or w < imgW:
|
||||
padding_h = max(imgH - h, 0)
|
||||
padding_w = max(imgW - w, 0)
|
||||
img_padded = np.pad(img, ((0, padding_h), (0, padding_w)),
|
||||
'constant',
|
||||
constant_values=(255))
|
||||
img = img_padded
|
||||
|
||||
img = np.expand_dims(img, 0) / 255.0 # h,w,c -> c,h,w
|
||||
img = img.astype('float32')
|
||||
|
||||
return img
|
||||
|
||||
def close(self):
|
||||
# close session and release manually
|
||||
logging.info('Close text recognizer.')
|
||||
if hasattr(self, "predictor"):
|
||||
del self.predictor
|
||||
gc.collect()
|
||||
|
||||
def __call__(self, img_list):
|
||||
img_num = len(img_list)
|
||||
# Calculate the aspect ratio of all text bars
|
||||
width_list = []
|
||||
for img in img_list:
|
||||
width_list.append(img.shape[1] / float(img.shape[0]))
|
||||
# Sorting can speed up the recognition process
|
||||
indices = np.argsort(np.array(width_list))
|
||||
rec_res = [['', 0.0]] * img_num
|
||||
batch_num = self.rec_batch_num
|
||||
st = time.time()
|
||||
|
||||
for beg_img_no in range(0, img_num, batch_num):
|
||||
end_img_no = min(img_num, beg_img_no + batch_num)
|
||||
norm_img_batch = []
|
||||
imgC, imgH, imgW = self.rec_image_shape[:3]
|
||||
max_wh_ratio = imgW / imgH
|
||||
# max_wh_ratio = 0
|
||||
for ino in range(beg_img_no, end_img_no):
|
||||
h, w = img_list[indices[ino]].shape[0:2]
|
||||
wh_ratio = w * 1.0 / h
|
||||
max_wh_ratio = max(max_wh_ratio, wh_ratio)
|
||||
for ino in range(beg_img_no, end_img_no):
|
||||
norm_img = self.resize_norm_img(img_list[indices[ino]],
|
||||
max_wh_ratio)
|
||||
norm_img = norm_img[np.newaxis, :]
|
||||
norm_img_batch.append(norm_img)
|
||||
norm_img_batch = np.concatenate(norm_img_batch)
|
||||
norm_img_batch = norm_img_batch.copy()
|
||||
|
||||
input_dict = {}
|
||||
input_dict[self.input_tensor.name] = norm_img_batch
|
||||
for i in range(100000):
|
||||
try:
|
||||
outputs = self.predictor.run(None, input_dict, self.run_options)
|
||||
break
|
||||
except Exception as e:
|
||||
if i >= 3:
|
||||
raise e
|
||||
time.sleep(5)
|
||||
preds = outputs[0]
|
||||
rec_result = self.postprocess_op(preds)
|
||||
for rno in range(len(rec_result)):
|
||||
rec_res[indices[beg_img_no + rno]] = rec_result[rno]
|
||||
|
||||
return rec_res, time.time() - st
|
||||
|
||||
def __del__(self):
|
||||
self.close()
|
||||
|
||||
|
||||
class TextDetector:
|
||||
def __init__(self, model_dir, device_id: int | None = None):
|
||||
pre_process_list = [{
|
||||
'DetResizeForTest': {
|
||||
'limit_side_len': 960,
|
||||
'limit_type': "max",
|
||||
}
|
||||
}, {
|
||||
'NormalizeImage': {
|
||||
'std': [0.229, 0.224, 0.225],
|
||||
'mean': [0.485, 0.456, 0.406],
|
||||
'scale': '1./255.',
|
||||
'order': 'hwc'
|
||||
}
|
||||
}, {
|
||||
'ToCHWImage': None
|
||||
}, {
|
||||
'KeepKeys': {
|
||||
'keep_keys': ['image', 'shape']
|
||||
}
|
||||
}]
|
||||
postprocess_params = {"name": "DBPostProcess", "thresh": 0.3, "box_thresh": 0.5, "max_candidates": 1000,
|
||||
"unclip_ratio": 1.5, "use_dilation": False, "score_mode": "fast", "box_type": "quad"}
|
||||
|
||||
self.postprocess_op = build_post_process(postprocess_params)
|
||||
self.predictor, self.run_options = load_model(model_dir, 'det', device_id)
|
||||
self.input_tensor = self.predictor.get_inputs()[0]
|
||||
|
||||
img_h, img_w = self.input_tensor.shape[2:]
|
||||
if isinstance(img_h, str) or isinstance(img_w, str):
|
||||
pass
|
||||
elif img_h is not None and img_w is not None and img_h > 0 and img_w > 0:
|
||||
pre_process_list[0] = {
|
||||
'DetResizeForTest': {
|
||||
'image_shape': [img_h, img_w]
|
||||
}
|
||||
}
|
||||
self.preprocess_op = create_operators(pre_process_list)
|
||||
|
||||
def order_points_clockwise(self, pts):
|
||||
rect = np.zeros((4, 2), dtype="float32")
|
||||
s = pts.sum(axis=1)
|
||||
rect[0] = pts[np.argmin(s)]
|
||||
rect[2] = pts[np.argmax(s)]
|
||||
tmp = np.delete(pts, (np.argmin(s), np.argmax(s)), axis=0)
|
||||
diff = np.diff(np.array(tmp), axis=1)
|
||||
rect[1] = tmp[np.argmin(diff)]
|
||||
rect[3] = tmp[np.argmax(diff)]
|
||||
return rect
|
||||
|
||||
def clip_det_res(self, points, img_height, img_width):
|
||||
for pno in range(points.shape[0]):
|
||||
points[pno, 0] = int(min(max(points[pno, 0], 0), img_width - 1))
|
||||
points[pno, 1] = int(min(max(points[pno, 1], 0), img_height - 1))
|
||||
return points
|
||||
|
||||
def filter_tag_det_res(self, dt_boxes, image_shape):
|
||||
img_height, img_width = image_shape[0:2]
|
||||
dt_boxes_new = []
|
||||
for box in dt_boxes:
|
||||
if isinstance(box, list):
|
||||
box = np.array(box)
|
||||
box = self.order_points_clockwise(box)
|
||||
box = self.clip_det_res(box, img_height, img_width)
|
||||
rect_width = int(np.linalg.norm(box[0] - box[1]))
|
||||
rect_height = int(np.linalg.norm(box[0] - box[3]))
|
||||
if rect_width <= 3 or rect_height <= 3:
|
||||
continue
|
||||
dt_boxes_new.append(box)
|
||||
dt_boxes = np.array(dt_boxes_new)
|
||||
return dt_boxes
|
||||
|
||||
def filter_tag_det_res_only_clip(self, dt_boxes, image_shape):
|
||||
img_height, img_width = image_shape[0:2]
|
||||
dt_boxes_new = []
|
||||
for box in dt_boxes:
|
||||
if isinstance(box, list):
|
||||
box = np.array(box)
|
||||
box = self.clip_det_res(box, img_height, img_width)
|
||||
dt_boxes_new.append(box)
|
||||
dt_boxes = np.array(dt_boxes_new)
|
||||
return dt_boxes
|
||||
|
||||
def close(self):
|
||||
logging.info("Close text detector.")
|
||||
if hasattr(self, "predictor"):
|
||||
del self.predictor
|
||||
gc.collect()
|
||||
|
||||
def __call__(self, img):
|
||||
ori_im = img.copy()
|
||||
data = {'image': img}
|
||||
|
||||
st = time.time()
|
||||
data = transform(data, self.preprocess_op)
|
||||
img, shape_list = data
|
||||
if img is None:
|
||||
return None, 0
|
||||
img = np.expand_dims(img, axis=0)
|
||||
shape_list = np.expand_dims(shape_list, axis=0)
|
||||
img = img.copy()
|
||||
input_dict = {}
|
||||
input_dict[self.input_tensor.name] = img
|
||||
for i in range(100000):
|
||||
try:
|
||||
outputs = self.predictor.run(None, input_dict, self.run_options)
|
||||
break
|
||||
except Exception as e:
|
||||
if i >= 3:
|
||||
raise e
|
||||
time.sleep(5)
|
||||
|
||||
post_result = self.postprocess_op({"maps": outputs[0]}, shape_list)
|
||||
dt_boxes = post_result[0]['points']
|
||||
dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im.shape)
|
||||
|
||||
return dt_boxes, time.time() - st
|
||||
|
||||
def __del__(self):
|
||||
self.close()
|
||||
|
||||
|
||||
class OCR:
|
||||
def __init__(self, model_dir=None):
|
||||
"""
|
||||
If you have trouble downloading HuggingFace models, -_^ this might help!!
|
||||
|
||||
For Linux:
|
||||
export HF_ENDPOINT=https://hf-mirror.com
|
||||
|
||||
For Windows:
|
||||
Good luck
|
||||
^_-
|
||||
|
||||
"""
|
||||
if not model_dir:
|
||||
try:
|
||||
model_dir = os.path.join(
|
||||
get_project_base_directory(),
|
||||
"res/deepdoc")
|
||||
|
||||
# Append muti-gpus task to the list
|
||||
if settings.PARALLEL_DEVICES > 0:
|
||||
self.text_detector = []
|
||||
self.text_recognizer = []
|
||||
for device_id in range(settings.PARALLEL_DEVICES):
|
||||
self.text_detector.append(TextDetector(model_dir, device_id))
|
||||
self.text_recognizer.append(TextRecognizer(model_dir, device_id))
|
||||
else:
|
||||
self.text_detector = [TextDetector(model_dir)]
|
||||
self.text_recognizer = [TextRecognizer(model_dir)]
|
||||
|
||||
except Exception:
|
||||
model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc",
|
||||
local_dir=os.path.join(get_project_base_directory(), "res/deepdoc"),
|
||||
local_dir_use_symlinks=False)
|
||||
|
||||
if settings.PARALLEL_DEVICES > 0:
|
||||
self.text_detector = []
|
||||
self.text_recognizer = []
|
||||
for device_id in range(settings.PARALLEL_DEVICES):
|
||||
self.text_detector.append(TextDetector(model_dir, device_id))
|
||||
self.text_recognizer.append(TextRecognizer(model_dir, device_id))
|
||||
else:
|
||||
self.text_detector = [TextDetector(model_dir)]
|
||||
self.text_recognizer = [TextRecognizer(model_dir)]
|
||||
|
||||
self.drop_score = 0.5
|
||||
self.crop_image_res_index = 0
|
||||
|
||||
def get_rotate_crop_image(self, img, points):
|
||||
'''
|
||||
img_height, img_width = img.shape[0:2]
|
||||
left = int(np.min(points[:, 0]))
|
||||
right = int(np.max(points[:, 0]))
|
||||
top = int(np.min(points[:, 1]))
|
||||
bottom = int(np.max(points[:, 1]))
|
||||
img_crop = img[top:bottom, left:right, :].copy()
|
||||
points[:, 0] = points[:, 0] - left
|
||||
points[:, 1] = points[:, 1] - top
|
||||
'''
|
||||
assert len(points) == 4, "shape of points must be 4*2"
|
||||
img_crop_width = int(
|
||||
max(
|
||||
np.linalg.norm(points[0] - points[1]),
|
||||
np.linalg.norm(points[2] - points[3])))
|
||||
img_crop_height = int(
|
||||
max(
|
||||
np.linalg.norm(points[0] - points[3]),
|
||||
np.linalg.norm(points[1] - points[2])))
|
||||
pts_std = np.float32([[0, 0], [img_crop_width, 0],
|
||||
[img_crop_width, img_crop_height],
|
||||
[0, img_crop_height]])
|
||||
M = cv2.getPerspectiveTransform(points, pts_std)
|
||||
dst_img = cv2.warpPerspective(
|
||||
img,
|
||||
M, (img_crop_width, img_crop_height),
|
||||
borderMode=cv2.BORDER_REPLICATE,
|
||||
flags=cv2.INTER_CUBIC)
|
||||
dst_img_height, dst_img_width = dst_img.shape[0:2]
|
||||
if dst_img_height * 1.0 / dst_img_width >= 1.5:
|
||||
# Try original orientation
|
||||
rec_result = self.text_recognizer[0]([dst_img])
|
||||
text, score = rec_result[0][0]
|
||||
best_score = score
|
||||
best_img = dst_img
|
||||
|
||||
# Try clockwise 90° rotation
|
||||
rotated_cw = np.rot90(dst_img, k=3)
|
||||
rec_result = self.text_recognizer[0]([rotated_cw])
|
||||
rotated_cw_text, rotated_cw_score = rec_result[0][0]
|
||||
if rotated_cw_score > best_score:
|
||||
best_score = rotated_cw_score
|
||||
best_img = rotated_cw
|
||||
|
||||
# Try counter-clockwise 90° rotation
|
||||
rotated_ccw = np.rot90(dst_img, k=1)
|
||||
rec_result = self.text_recognizer[0]([rotated_ccw])
|
||||
rotated_ccw_text, rotated_ccw_score = rec_result[0][0]
|
||||
if rotated_ccw_score > best_score:
|
||||
best_img = rotated_ccw
|
||||
|
||||
# Use the best image
|
||||
dst_img = best_img
|
||||
return dst_img
|
||||
|
||||
def sorted_boxes(self, dt_boxes):
|
||||
"""
|
||||
Sort text boxes in order from top to bottom, left to right
|
||||
args:
|
||||
dt_boxes(array):detected text boxes with shape [4, 2]
|
||||
return:
|
||||
sorted boxes(array) with shape [4, 2]
|
||||
"""
|
||||
num_boxes = dt_boxes.shape[0]
|
||||
sorted_boxes = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0]))
|
||||
_boxes = list(sorted_boxes)
|
||||
|
||||
for i in range(num_boxes - 1):
|
||||
for j in range(i, -1, -1):
|
||||
if abs(_boxes[j + 1][0][1] - _boxes[j][0][1]) < 10 and \
|
||||
(_boxes[j + 1][0][0] < _boxes[j][0][0]):
|
||||
tmp = _boxes[j]
|
||||
_boxes[j] = _boxes[j + 1]
|
||||
_boxes[j + 1] = tmp
|
||||
else:
|
||||
break
|
||||
return _boxes
|
||||
|
||||
def detect(self, img, device_id: int | None = None):
|
||||
if device_id is None:
|
||||
device_id = 0
|
||||
|
||||
time_dict = {'det': 0, 'rec': 0, 'cls': 0, 'all': 0}
|
||||
|
||||
if img is None:
|
||||
return None, None, time_dict
|
||||
|
||||
start = time.time()
|
||||
dt_boxes, elapse = self.text_detector[device_id](img)
|
||||
time_dict['det'] = elapse
|
||||
|
||||
if dt_boxes is None:
|
||||
end = time.time()
|
||||
time_dict['all'] = end - start
|
||||
return None, None, time_dict
|
||||
|
||||
return zip(self.sorted_boxes(dt_boxes), [
|
||||
("", 0) for _ in range(len(dt_boxes))])
|
||||
|
||||
def recognize(self, ori_im, box, device_id: int | None = None):
|
||||
if device_id is None:
|
||||
device_id = 0
|
||||
|
||||
img_crop = self.get_rotate_crop_image(ori_im, box)
|
||||
|
||||
rec_res, elapse = self.text_recognizer[device_id]([img_crop])
|
||||
text, score = rec_res[0]
|
||||
if score < self.drop_score:
|
||||
return ""
|
||||
return text
|
||||
|
||||
def recognize_batch(self, img_list, device_id: int | None = None):
|
||||
if device_id is None:
|
||||
device_id = 0
|
||||
rec_res, elapse = self.text_recognizer[device_id](img_list)
|
||||
texts = []
|
||||
for i in range(len(rec_res)):
|
||||
text, score = rec_res[i]
|
||||
if score < self.drop_score:
|
||||
text = ""
|
||||
texts.append(text)
|
||||
return texts
|
||||
|
||||
def __call__(self, img, device_id = 0, cls=True):
|
||||
time_dict = {'det': 0, 'rec': 0, 'cls': 0, 'all': 0}
|
||||
if device_id is None:
|
||||
device_id = 0
|
||||
|
||||
if img is None:
|
||||
return None, None, time_dict
|
||||
|
||||
start = time.time()
|
||||
ori_im = img.copy()
|
||||
dt_boxes, elapse = self.text_detector[device_id](img)
|
||||
time_dict['det'] = elapse
|
||||
|
||||
if dt_boxes is None:
|
||||
end = time.time()
|
||||
time_dict['all'] = end - start
|
||||
return None, None, time_dict
|
||||
|
||||
img_crop_list = []
|
||||
|
||||
dt_boxes = self.sorted_boxes(dt_boxes)
|
||||
|
||||
for bno in range(len(dt_boxes)):
|
||||
tmp_box = copy.deepcopy(dt_boxes[bno])
|
||||
img_crop = self.get_rotate_crop_image(ori_im, tmp_box)
|
||||
img_crop_list.append(img_crop)
|
||||
|
||||
rec_res, elapse = self.text_recognizer[device_id](img_crop_list)
|
||||
|
||||
time_dict['rec'] = elapse
|
||||
|
||||
filter_boxes, filter_rec_res = [], []
|
||||
for box, rec_result in zip(dt_boxes, rec_res):
|
||||
text, score = rec_result
|
||||
if score >= self.drop_score:
|
||||
filter_boxes.append(box)
|
||||
filter_rec_res.append(rec_result)
|
||||
end = time.time()
|
||||
time_dict['all'] = end - start
|
||||
|
||||
# for bno in range(len(img_crop_list)):
|
||||
# print(f"{bno}, {rec_res[bno]}")
|
||||
|
||||
return list(zip([a.tolist() for a in filter_boxes], filter_rec_res))
|
||||
709
api/app/core/rag/deepdoc/vision/operators.py
Normal file
709
api/app/core/rag/deepdoc/vision/operators.py
Normal file
@@ -0,0 +1,709 @@
|
||||
import logging
|
||||
import sys
|
||||
import six
|
||||
import cv2
|
||||
import numpy as np
|
||||
import math
|
||||
from PIL import Image
|
||||
|
||||
|
||||
class DecodeImage:
|
||||
""" decode image """
|
||||
|
||||
def __init__(self,
|
||||
img_mode='RGB',
|
||||
channel_first=False,
|
||||
ignore_orientation=False,
|
||||
**kwargs):
|
||||
self.img_mode = img_mode
|
||||
self.channel_first = channel_first
|
||||
self.ignore_orientation = ignore_orientation
|
||||
|
||||
def __call__(self, data):
|
||||
img = data['image']
|
||||
if six.PY2:
|
||||
assert isinstance(img, str) and len(
|
||||
img) > 0, "invalid input 'img' in DecodeImage"
|
||||
else:
|
||||
assert isinstance(img, bytes) and len(
|
||||
img) > 0, "invalid input 'img' in DecodeImage"
|
||||
img = np.frombuffer(img, dtype='uint8')
|
||||
if self.ignore_orientation:
|
||||
img = cv2.imdecode(img, cv2.IMREAD_IGNORE_ORIENTATION |
|
||||
cv2.IMREAD_COLOR)
|
||||
else:
|
||||
img = cv2.imdecode(img, 1)
|
||||
if img is None:
|
||||
return None
|
||||
if self.img_mode == 'GRAY':
|
||||
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
|
||||
elif self.img_mode == 'RGB':
|
||||
assert img.shape[2] == 3, 'invalid shape of image[%s]' % (
|
||||
img.shape)
|
||||
img = img[:, :, ::-1]
|
||||
|
||||
if self.channel_first:
|
||||
img = img.transpose((2, 0, 1))
|
||||
|
||||
data['image'] = img
|
||||
return data
|
||||
|
||||
|
||||
class StandardizeImag:
|
||||
"""normalize image
|
||||
Args:
|
||||
mean (list): im - mean
|
||||
std (list): im / std
|
||||
is_scale (bool): whether need im / 255
|
||||
norm_type (str): type in ['mean_std', 'none']
|
||||
"""
|
||||
|
||||
def __init__(self, mean, std, is_scale=True, norm_type='mean_std'):
|
||||
self.mean = mean
|
||||
self.std = std
|
||||
self.is_scale = is_scale
|
||||
self.norm_type = norm_type
|
||||
|
||||
def __call__(self, im, im_info):
|
||||
"""
|
||||
Args:
|
||||
im (np.ndarray): image (np.ndarray)
|
||||
im_info (dict): info of image
|
||||
Returns:
|
||||
im (np.ndarray): processed image (np.ndarray)
|
||||
im_info (dict): info of processed image
|
||||
"""
|
||||
im = im.astype(np.float32, copy=False)
|
||||
if self.is_scale:
|
||||
scale = 1.0 / 255.0
|
||||
im *= scale
|
||||
|
||||
if self.norm_type == 'mean_std':
|
||||
mean = np.array(self.mean)[np.newaxis, np.newaxis, :]
|
||||
std = np.array(self.std)[np.newaxis, np.newaxis, :]
|
||||
im -= mean
|
||||
im /= std
|
||||
return im, im_info
|
||||
|
||||
|
||||
class NormalizeImage:
|
||||
""" normalize image such as subtract mean, divide std
|
||||
"""
|
||||
|
||||
def __init__(self, scale=None, mean=None, std=None, order='chw', **kwargs):
|
||||
if isinstance(scale, str):
|
||||
scale = eval(scale)
|
||||
self.scale = np.float32(scale if scale is not None else 1.0 / 255.0)
|
||||
mean = mean if mean is not None else [0.485, 0.456, 0.406]
|
||||
std = std if std is not None else [0.229, 0.224, 0.225]
|
||||
|
||||
shape = (3, 1, 1) if order == 'chw' else (1, 1, 3)
|
||||
self.mean = np.array(mean).reshape(shape).astype('float32')
|
||||
self.std = np.array(std).reshape(shape).astype('float32')
|
||||
|
||||
def __call__(self, data):
|
||||
img = data['image']
|
||||
from PIL import Image
|
||||
if isinstance(img, Image.Image):
|
||||
img = np.array(img)
|
||||
assert isinstance(img,
|
||||
np.ndarray), "invalid input 'img' in NormalizeImage"
|
||||
data['image'] = (
|
||||
img.astype('float32') * self.scale - self.mean) / self.std
|
||||
return data
|
||||
|
||||
|
||||
class ToCHWImage:
|
||||
""" convert hwc image to chw image
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
pass
|
||||
|
||||
def __call__(self, data):
|
||||
img = data['image']
|
||||
from PIL import Image
|
||||
if isinstance(img, Image.Image):
|
||||
img = np.array(img)
|
||||
data['image'] = img.transpose((2, 0, 1))
|
||||
return data
|
||||
|
||||
|
||||
class KeepKeys:
|
||||
def __init__(self, keep_keys, **kwargs):
|
||||
self.keep_keys = keep_keys
|
||||
|
||||
def __call__(self, data):
|
||||
data_list = []
|
||||
for key in self.keep_keys:
|
||||
data_list.append(data[key])
|
||||
return data_list
|
||||
|
||||
|
||||
class Pad:
|
||||
def __init__(self, size=None, size_div=32, **kwargs):
|
||||
if size is not None and not isinstance(size, (int, list, tuple)):
|
||||
raise TypeError("Type of target_size is invalid. Now is {}".format(
|
||||
type(size)))
|
||||
if isinstance(size, int):
|
||||
size = [size, size]
|
||||
self.size = size
|
||||
self.size_div = size_div
|
||||
|
||||
def __call__(self, data):
|
||||
|
||||
img = data['image']
|
||||
img_h, img_w = img.shape[0], img.shape[1]
|
||||
if self.size:
|
||||
resize_h2, resize_w2 = self.size
|
||||
assert (
|
||||
img_h < resize_h2 and img_w < resize_w2
|
||||
), '(h, w) of target size should be greater than (img_h, img_w)'
|
||||
else:
|
||||
resize_h2 = max(
|
||||
int(math.ceil(img.shape[0] / self.size_div) * self.size_div),
|
||||
self.size_div)
|
||||
resize_w2 = max(
|
||||
int(math.ceil(img.shape[1] / self.size_div) * self.size_div),
|
||||
self.size_div)
|
||||
img = cv2.copyMakeBorder(
|
||||
img,
|
||||
0,
|
||||
resize_h2 - img_h,
|
||||
0,
|
||||
resize_w2 - img_w,
|
||||
cv2.BORDER_CONSTANT,
|
||||
value=0)
|
||||
data['image'] = img
|
||||
return data
|
||||
|
||||
|
||||
class LinearResize:
|
||||
"""resize image by target_size and max_size
|
||||
Args:
|
||||
target_size (int): the target size of image
|
||||
keep_ratio (bool): whether keep_ratio or not, default true
|
||||
interp (int): method of resize
|
||||
"""
|
||||
|
||||
def __init__(self, target_size, keep_ratio=True, interp=cv2.INTER_LINEAR):
|
||||
if isinstance(target_size, int):
|
||||
target_size = [target_size, target_size]
|
||||
self.target_size = target_size
|
||||
self.keep_ratio = keep_ratio
|
||||
self.interp = interp
|
||||
|
||||
def __call__(self, im, im_info):
|
||||
"""
|
||||
Args:
|
||||
im (np.ndarray): image (np.ndarray)
|
||||
im_info (dict): info of image
|
||||
Returns:
|
||||
im (np.ndarray): processed image (np.ndarray)
|
||||
im_info (dict): info of processed image
|
||||
"""
|
||||
assert len(self.target_size) == 2
|
||||
assert self.target_size[0] > 0 and self.target_size[1] > 0
|
||||
_im_channel = im.shape[2]
|
||||
im_scale_y, im_scale_x = self.generate_scale(im)
|
||||
im = cv2.resize(
|
||||
im,
|
||||
None,
|
||||
None,
|
||||
fx=im_scale_x,
|
||||
fy=im_scale_y,
|
||||
interpolation=self.interp)
|
||||
im_info['im_shape'] = np.array(im.shape[:2]).astype('float32')
|
||||
im_info['scale_factor'] = np.array(
|
||||
[im_scale_y, im_scale_x]).astype('float32')
|
||||
return im, im_info
|
||||
|
||||
def generate_scale(self, im):
|
||||
"""
|
||||
Args:
|
||||
im (np.ndarray): image (np.ndarray)
|
||||
Returns:
|
||||
im_scale_x: the resize ratio of X
|
||||
im_scale_y: the resize ratio of Y
|
||||
"""
|
||||
origin_shape = im.shape[:2]
|
||||
_im_c = im.shape[2]
|
||||
if self.keep_ratio:
|
||||
im_size_min = np.min(origin_shape)
|
||||
im_size_max = np.max(origin_shape)
|
||||
target_size_min = np.min(self.target_size)
|
||||
target_size_max = np.max(self.target_size)
|
||||
im_scale = float(target_size_min) / float(im_size_min)
|
||||
if np.round(im_scale * im_size_max) > target_size_max:
|
||||
im_scale = float(target_size_max) / float(im_size_max)
|
||||
im_scale_x = im_scale
|
||||
im_scale_y = im_scale
|
||||
else:
|
||||
resize_h, resize_w = self.target_size
|
||||
im_scale_y = resize_h / float(origin_shape[0])
|
||||
im_scale_x = resize_w / float(origin_shape[1])
|
||||
return im_scale_y, im_scale_x
|
||||
|
||||
|
||||
class Resize:
|
||||
def __init__(self, size=(640, 640), **kwargs):
|
||||
self.size = size
|
||||
|
||||
def resize_image(self, img):
|
||||
resize_h, resize_w = self.size
|
||||
ori_h, ori_w = img.shape[:2] # (h, w, c)
|
||||
ratio_h = float(resize_h) / ori_h
|
||||
ratio_w = float(resize_w) / ori_w
|
||||
img = cv2.resize(img, (int(resize_w), int(resize_h)))
|
||||
return img, [ratio_h, ratio_w]
|
||||
|
||||
def __call__(self, data):
|
||||
img = data['image']
|
||||
if 'polys' in data:
|
||||
text_polys = data['polys']
|
||||
|
||||
img_resize, [ratio_h, ratio_w] = self.resize_image(img)
|
||||
if 'polys' in data:
|
||||
new_boxes = []
|
||||
for box in text_polys:
|
||||
new_box = []
|
||||
for cord in box:
|
||||
new_box.append([cord[0] * ratio_w, cord[1] * ratio_h])
|
||||
new_boxes.append(new_box)
|
||||
data['polys'] = np.array(new_boxes, dtype=np.float32)
|
||||
data['image'] = img_resize
|
||||
return data
|
||||
|
||||
|
||||
class DetResizeForTest:
|
||||
def __init__(self, **kwargs):
|
||||
super(DetResizeForTest, self).__init__()
|
||||
self.resize_type = 0
|
||||
self.keep_ratio = False
|
||||
if 'image_shape' in kwargs:
|
||||
self.image_shape = kwargs['image_shape']
|
||||
self.resize_type = 1
|
||||
if 'keep_ratio' in kwargs:
|
||||
self.keep_ratio = kwargs['keep_ratio']
|
||||
elif 'limit_side_len' in kwargs:
|
||||
self.limit_side_len = kwargs['limit_side_len']
|
||||
self.limit_type = kwargs.get('limit_type', 'min')
|
||||
elif 'resize_long' in kwargs:
|
||||
self.resize_type = 2
|
||||
self.resize_long = kwargs.get('resize_long', 960)
|
||||
else:
|
||||
self.limit_side_len = 736
|
||||
self.limit_type = 'min'
|
||||
|
||||
def __call__(self, data):
|
||||
img = data['image']
|
||||
src_h, src_w, _ = img.shape
|
||||
if sum([src_h, src_w]) < 64:
|
||||
img = self.image_padding(img)
|
||||
|
||||
if self.resize_type == 0:
|
||||
# img, shape = self.resize_image_type0(img)
|
||||
img, [ratio_h, ratio_w] = self.resize_image_type0(img)
|
||||
elif self.resize_type == 2:
|
||||
img, [ratio_h, ratio_w] = self.resize_image_type2(img)
|
||||
else:
|
||||
# img, shape = self.resize_image_type1(img)
|
||||
img, [ratio_h, ratio_w] = self.resize_image_type1(img)
|
||||
data['image'] = img
|
||||
data['shape'] = np.array([src_h, src_w, ratio_h, ratio_w])
|
||||
return data
|
||||
|
||||
def image_padding(self, im, value=0):
|
||||
h, w, c = im.shape
|
||||
im_pad = np.zeros((max(32, h), max(32, w), c), np.uint8) + value
|
||||
im_pad[:h, :w, :] = im
|
||||
return im_pad
|
||||
|
||||
def resize_image_type1(self, img):
|
||||
resize_h, resize_w = self.image_shape
|
||||
ori_h, ori_w = img.shape[:2] # (h, w, c)
|
||||
if self.keep_ratio is True:
|
||||
resize_w = ori_w * resize_h / ori_h
|
||||
N = math.ceil(resize_w / 32)
|
||||
resize_w = N * 32
|
||||
ratio_h = float(resize_h) / ori_h
|
||||
ratio_w = float(resize_w) / ori_w
|
||||
img = cv2.resize(img, (int(resize_w), int(resize_h)))
|
||||
# return img, np.array([ori_h, ori_w])
|
||||
return img, [ratio_h, ratio_w]
|
||||
|
||||
def resize_image_type0(self, img):
|
||||
"""
|
||||
resize image to a size multiple of 32 which is required by the network
|
||||
args:
|
||||
img(array): array with shape [h, w, c]
|
||||
return(tuple):
|
||||
img, (ratio_h, ratio_w)
|
||||
"""
|
||||
limit_side_len = self.limit_side_len
|
||||
h, w, c = img.shape
|
||||
|
||||
# limit the max side
|
||||
if self.limit_type == 'max':
|
||||
if max(h, w) > limit_side_len:
|
||||
if h > w:
|
||||
ratio = float(limit_side_len) / h
|
||||
else:
|
||||
ratio = float(limit_side_len) / w
|
||||
else:
|
||||
ratio = 1.
|
||||
elif self.limit_type == 'min':
|
||||
if min(h, w) < limit_side_len:
|
||||
if h < w:
|
||||
ratio = float(limit_side_len) / h
|
||||
else:
|
||||
ratio = float(limit_side_len) / w
|
||||
else:
|
||||
ratio = 1.
|
||||
elif self.limit_type == 'resize_long':
|
||||
ratio = float(limit_side_len) / max(h, w)
|
||||
else:
|
||||
raise Exception('not support limit type, image ')
|
||||
resize_h = int(h * ratio)
|
||||
resize_w = int(w * ratio)
|
||||
|
||||
resize_h = max(int(round(resize_h / 32) * 32), 32)
|
||||
resize_w = max(int(round(resize_w / 32) * 32), 32)
|
||||
|
||||
try:
|
||||
if int(resize_w) <= 0 or int(resize_h) <= 0:
|
||||
return None, (None, None)
|
||||
img = cv2.resize(img, (int(resize_w), int(resize_h)))
|
||||
except BaseException:
|
||||
logging.exception("{} {} {}".format(img.shape, resize_w, resize_h))
|
||||
sys.exit(0)
|
||||
ratio_h = resize_h / float(h)
|
||||
ratio_w = resize_w / float(w)
|
||||
return img, [ratio_h, ratio_w]
|
||||
|
||||
def resize_image_type2(self, img):
|
||||
h, w, _ = img.shape
|
||||
|
||||
resize_w = w
|
||||
resize_h = h
|
||||
|
||||
if resize_h > resize_w:
|
||||
ratio = float(self.resize_long) / resize_h
|
||||
else:
|
||||
ratio = float(self.resize_long) / resize_w
|
||||
|
||||
resize_h = int(resize_h * ratio)
|
||||
resize_w = int(resize_w * ratio)
|
||||
|
||||
max_stride = 128
|
||||
resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
|
||||
resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
|
||||
img = cv2.resize(img, (int(resize_w), int(resize_h)))
|
||||
ratio_h = resize_h / float(h)
|
||||
ratio_w = resize_w / float(w)
|
||||
|
||||
return img, [ratio_h, ratio_w]
|
||||
|
||||
|
||||
class E2EResizeForTest:
|
||||
def __init__(self, **kwargs):
|
||||
super(E2EResizeForTest, self).__init__()
|
||||
self.max_side_len = kwargs['max_side_len']
|
||||
self.valid_set = kwargs['valid_set']
|
||||
|
||||
def __call__(self, data):
|
||||
img = data['image']
|
||||
src_h, src_w, _ = img.shape
|
||||
if self.valid_set == 'totaltext':
|
||||
im_resized, [ratio_h, ratio_w] = self.resize_image_for_totaltext(
|
||||
img, max_side_len=self.max_side_len)
|
||||
else:
|
||||
im_resized, (ratio_h, ratio_w) = self.resize_image(
|
||||
img, max_side_len=self.max_side_len)
|
||||
data['image'] = im_resized
|
||||
data['shape'] = np.array([src_h, src_w, ratio_h, ratio_w])
|
||||
return data
|
||||
|
||||
def resize_image_for_totaltext(self, im, max_side_len=512):
|
||||
h, w, _ = im.shape
|
||||
resize_w = w
|
||||
resize_h = h
|
||||
ratio = 1.25
|
||||
if h * ratio > max_side_len:
|
||||
ratio = float(max_side_len) / resize_h
|
||||
resize_h = int(resize_h * ratio)
|
||||
resize_w = int(resize_w * ratio)
|
||||
|
||||
max_stride = 128
|
||||
resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
|
||||
resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
|
||||
im = cv2.resize(im, (int(resize_w), int(resize_h)))
|
||||
ratio_h = resize_h / float(h)
|
||||
ratio_w = resize_w / float(w)
|
||||
return im, (ratio_h, ratio_w)
|
||||
|
||||
def resize_image(self, im, max_side_len=512):
|
||||
"""
|
||||
resize image to a size multiple of max_stride which is required by the network
|
||||
:param im: the resized image
|
||||
:param max_side_len: limit of max image size to avoid out of memory in gpu
|
||||
:return: the resized image and the resize ratio
|
||||
"""
|
||||
h, w, _ = im.shape
|
||||
|
||||
resize_w = w
|
||||
resize_h = h
|
||||
|
||||
# Fix the longer side
|
||||
if resize_h > resize_w:
|
||||
ratio = float(max_side_len) / resize_h
|
||||
else:
|
||||
ratio = float(max_side_len) / resize_w
|
||||
|
||||
resize_h = int(resize_h * ratio)
|
||||
resize_w = int(resize_w * ratio)
|
||||
|
||||
max_stride = 128
|
||||
resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
|
||||
resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
|
||||
im = cv2.resize(im, (int(resize_w), int(resize_h)))
|
||||
ratio_h = resize_h / float(h)
|
||||
ratio_w = resize_w / float(w)
|
||||
|
||||
return im, (ratio_h, ratio_w)
|
||||
|
||||
|
||||
class KieResize:
|
||||
def __init__(self, **kwargs):
|
||||
super(KieResize, self).__init__()
|
||||
self.max_side, self.min_side = kwargs['img_scale'][0], kwargs[
|
||||
'img_scale'][1]
|
||||
|
||||
def __call__(self, data):
|
||||
img = data['image']
|
||||
points = data['points']
|
||||
src_h, src_w, _ = img.shape
|
||||
im_resized, scale_factor, [ratio_h, ratio_w
|
||||
], [new_h, new_w] = self.resize_image(img)
|
||||
resize_points = self.resize_boxes(img, points, scale_factor)
|
||||
data['ori_image'] = img
|
||||
data['ori_boxes'] = points
|
||||
data['points'] = resize_points
|
||||
data['image'] = im_resized
|
||||
data['shape'] = np.array([new_h, new_w])
|
||||
return data
|
||||
|
||||
def resize_image(self, img):
|
||||
norm_img = np.zeros([1024, 1024, 3], dtype='float32')
|
||||
scale = [512, 1024]
|
||||
h, w = img.shape[:2]
|
||||
max_long_edge = max(scale)
|
||||
max_short_edge = min(scale)
|
||||
scale_factor = min(max_long_edge / max(h, w),
|
||||
max_short_edge / min(h, w))
|
||||
resize_w, resize_h = int(w * float(scale_factor) + 0.5), int(h * float(
|
||||
scale_factor) + 0.5)
|
||||
max_stride = 32
|
||||
resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
|
||||
resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
|
||||
im = cv2.resize(img, (resize_w, resize_h))
|
||||
new_h, new_w = im.shape[:2]
|
||||
w_scale = new_w / w
|
||||
h_scale = new_h / h
|
||||
scale_factor = np.array(
|
||||
[w_scale, h_scale, w_scale, h_scale], dtype=np.float32)
|
||||
norm_img[:new_h, :new_w, :] = im
|
||||
return norm_img, scale_factor, [h_scale, w_scale], [new_h, new_w]
|
||||
|
||||
def resize_boxes(self, im, points, scale_factor):
|
||||
points = points * scale_factor
|
||||
img_shape = im.shape[:2]
|
||||
points[:, 0::2] = np.clip(points[:, 0::2], 0, img_shape[1])
|
||||
points[:, 1::2] = np.clip(points[:, 1::2], 0, img_shape[0])
|
||||
return points
|
||||
|
||||
|
||||
class SRResize:
|
||||
def __init__(self,
|
||||
imgH=32,
|
||||
imgW=128,
|
||||
down_sample_scale=4,
|
||||
keep_ratio=False,
|
||||
min_ratio=1,
|
||||
mask=False,
|
||||
infer_mode=False,
|
||||
**kwargs):
|
||||
self.imgH = imgH
|
||||
self.imgW = imgW
|
||||
self.keep_ratio = keep_ratio
|
||||
self.min_ratio = min_ratio
|
||||
self.down_sample_scale = down_sample_scale
|
||||
self.mask = mask
|
||||
self.infer_mode = infer_mode
|
||||
|
||||
def __call__(self, data):
|
||||
imgH = self.imgH
|
||||
imgW = self.imgW
|
||||
images_lr = data["image_lr"]
|
||||
transform2 = ResizeNormalize(
|
||||
(imgW // self.down_sample_scale, imgH // self.down_sample_scale))
|
||||
images_lr = transform2(images_lr)
|
||||
data["img_lr"] = images_lr
|
||||
if self.infer_mode:
|
||||
return data
|
||||
|
||||
images_HR = data["image_hr"]
|
||||
_label_strs = data["label"]
|
||||
transform = ResizeNormalize((imgW, imgH))
|
||||
images_HR = transform(images_HR)
|
||||
data["img_hr"] = images_HR
|
||||
return data
|
||||
|
||||
|
||||
class ResizeNormalize:
|
||||
def __init__(self, size, interpolation=Image.BICUBIC):
|
||||
self.size = size
|
||||
self.interpolation = interpolation
|
||||
|
||||
def __call__(self, img):
|
||||
img = img.resize(self.size, self.interpolation)
|
||||
img_numpy = np.array(img).astype("float32")
|
||||
img_numpy = img_numpy.transpose((2, 0, 1)) / 255
|
||||
return img_numpy
|
||||
|
||||
|
||||
class GrayImageChannelFormat:
|
||||
"""
|
||||
format gray scale image's channel: (3,h,w) -> (1,h,w)
|
||||
Args:
|
||||
inverse: inverse gray image
|
||||
"""
|
||||
|
||||
def __init__(self, inverse=False, **kwargs):
|
||||
self.inverse = inverse
|
||||
|
||||
def __call__(self, data):
|
||||
img = data['image']
|
||||
img_single_channel = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
|
||||
img_expanded = np.expand_dims(img_single_channel, 0)
|
||||
|
||||
if self.inverse:
|
||||
data['image'] = np.abs(img_expanded - 1)
|
||||
else:
|
||||
data['image'] = img_expanded
|
||||
|
||||
data['src_image'] = img
|
||||
return data
|
||||
|
||||
|
||||
class Permute:
|
||||
"""permute image
|
||||
Args:
|
||||
to_bgr (bool): whether convert RGB to BGR
|
||||
channel_first (bool): whether convert HWC to CHW
|
||||
"""
|
||||
|
||||
def __init__(self, ):
|
||||
super(Permute, self).__init__()
|
||||
|
||||
def __call__(self, im, im_info):
|
||||
"""
|
||||
Args:
|
||||
im (np.ndarray): image (np.ndarray)
|
||||
im_info (dict): info of image
|
||||
Returns:
|
||||
im (np.ndarray): processed image (np.ndarray)
|
||||
im_info (dict): info of processed image
|
||||
"""
|
||||
im = im.transpose((2, 0, 1)).copy()
|
||||
return im, im_info
|
||||
|
||||
|
||||
class PadStride:
|
||||
""" padding image for model with FPN, instead PadBatch(pad_to_stride) in original config
|
||||
Args:
|
||||
stride (bool): model with FPN need image shape % stride == 0
|
||||
"""
|
||||
|
||||
def __init__(self, stride=0):
|
||||
self.coarsest_stride = stride
|
||||
|
||||
def __call__(self, im, im_info):
|
||||
"""
|
||||
Args:
|
||||
im (np.ndarray): image (np.ndarray)
|
||||
im_info (dict): info of image
|
||||
Returns:
|
||||
im (np.ndarray): processed image (np.ndarray)
|
||||
im_info (dict): info of processed image
|
||||
"""
|
||||
coarsest_stride = self.coarsest_stride
|
||||
if coarsest_stride <= 0:
|
||||
return im, im_info
|
||||
im_c, im_h, im_w = im.shape
|
||||
pad_h = int(np.ceil(float(im_h) / coarsest_stride) * coarsest_stride)
|
||||
pad_w = int(np.ceil(float(im_w) / coarsest_stride) * coarsest_stride)
|
||||
padding_im = np.zeros((im_c, pad_h, pad_w), dtype=np.float32)
|
||||
padding_im[:, :im_h, :im_w] = im
|
||||
return padding_im, im_info
|
||||
|
||||
|
||||
def decode_image(im_file, im_info):
|
||||
"""read rgb image
|
||||
Args:
|
||||
im_file (str|np.ndarray): input can be image path or np.ndarray
|
||||
im_info (dict): info of image
|
||||
Returns:
|
||||
im (np.ndarray): processed image (np.ndarray)
|
||||
im_info (dict): info of processed image
|
||||
"""
|
||||
if isinstance(im_file, str):
|
||||
with open(im_file, 'rb') as f:
|
||||
im_read = f.read()
|
||||
data = np.frombuffer(im_read, dtype='uint8')
|
||||
im = cv2.imdecode(data, 1) # BGR mode, but need RGB mode
|
||||
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
|
||||
else:
|
||||
im = im_file
|
||||
im_info['im_shape'] = np.array(im.shape[:2], dtype=np.float32)
|
||||
im_info['scale_factor'] = np.array([1., 1.], dtype=np.float32)
|
||||
return im, im_info
|
||||
|
||||
|
||||
def preprocess(im, preprocess_ops):
|
||||
# process image by preprocess_ops
|
||||
im_info = {
|
||||
'scale_factor': np.array(
|
||||
[1., 1.], dtype=np.float32),
|
||||
'im_shape': None,
|
||||
}
|
||||
im, im_info = decode_image(im, im_info)
|
||||
for operator in preprocess_ops:
|
||||
im, im_info = operator(im, im_info)
|
||||
return im, im_info
|
||||
|
||||
|
||||
def nms(bboxes, scores, iou_thresh):
|
||||
import numpy as np
|
||||
x1 = bboxes[:, 0]
|
||||
y1 = bboxes[:, 1]
|
||||
x2 = bboxes[:, 2]
|
||||
y2 = bboxes[:, 3]
|
||||
areas = (y2 - y1) * (x2 - x1)
|
||||
|
||||
indices = []
|
||||
index = scores.argsort()[::-1]
|
||||
while index.size > 0:
|
||||
i = index[0]
|
||||
indices.append(i)
|
||||
x11 = np.maximum(x1[i], x1[index[1:]])
|
||||
y11 = np.maximum(y1[i], y1[index[1:]])
|
||||
x22 = np.minimum(x2[i], x2[index[1:]])
|
||||
y22 = np.minimum(y2[i], y2[index[1:]])
|
||||
w = np.maximum(0, x22 - x11 + 1)
|
||||
h = np.maximum(0, y22 - y11 + 1)
|
||||
overlaps = w * h
|
||||
ious = overlaps / (areas[i] + areas[index[1:]] - overlaps)
|
||||
idx = np.where(ious <= iou_thresh)[0]
|
||||
index = index[idx + 1]
|
||||
return indices
|
||||
354
api/app/core/rag/deepdoc/vision/postprocess.py
Normal file
354
api/app/core/rag/deepdoc/vision/postprocess.py
Normal file
@@ -0,0 +1,354 @@
|
||||
import copy
|
||||
import re
|
||||
import numpy as np
|
||||
import cv2
|
||||
from shapely.geometry import Polygon
|
||||
import pyclipper
|
||||
|
||||
|
||||
def build_post_process(config, global_config=None):
|
||||
support_dict = {'DBPostProcess': DBPostProcess, 'CTCLabelDecode': CTCLabelDecode}
|
||||
|
||||
config = copy.deepcopy(config)
|
||||
module_name = config.pop('name')
|
||||
if module_name == "None":
|
||||
return
|
||||
if global_config is not None:
|
||||
config.update(global_config)
|
||||
module_class = support_dict.get(module_name)
|
||||
if module_class is None:
|
||||
raise ValueError(
|
||||
'post process only support {}'.format(list(support_dict)))
|
||||
return module_class(**config)
|
||||
|
||||
|
||||
class DBPostProcess:
|
||||
"""
|
||||
The post process for Differentiable Binarization (DB).
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
thresh=0.3,
|
||||
box_thresh=0.7,
|
||||
max_candidates=1000,
|
||||
unclip_ratio=2.0,
|
||||
use_dilation=False,
|
||||
score_mode="fast",
|
||||
box_type='quad',
|
||||
**kwargs):
|
||||
self.thresh = thresh
|
||||
self.box_thresh = box_thresh
|
||||
self.max_candidates = max_candidates
|
||||
self.unclip_ratio = unclip_ratio
|
||||
self.min_size = 3
|
||||
self.score_mode = score_mode
|
||||
self.box_type = box_type
|
||||
assert score_mode in [
|
||||
"slow", "fast"
|
||||
], "Score mode must be in [slow, fast] but got: {}".format(score_mode)
|
||||
|
||||
self.dilation_kernel = None if not use_dilation else np.array(
|
||||
[[1, 1], [1, 1]])
|
||||
|
||||
def polygons_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
|
||||
'''
|
||||
_bitmap: single map with shape (1, H, W),
|
||||
whose values are binarized as {0, 1}
|
||||
'''
|
||||
|
||||
bitmap = _bitmap
|
||||
height, width = bitmap.shape
|
||||
|
||||
boxes = []
|
||||
scores = []
|
||||
|
||||
contours, _ = cv2.findContours((bitmap * 255).astype(np.uint8),
|
||||
cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
|
||||
|
||||
for contour in contours[:self.max_candidates]:
|
||||
epsilon = 0.002 * cv2.arcLength(contour, True)
|
||||
approx = cv2.approxPolyDP(contour, epsilon, True)
|
||||
points = approx.reshape((-1, 2))
|
||||
if points.shape[0] < 4:
|
||||
continue
|
||||
|
||||
score = self.box_score_fast(pred, points.reshape(-1, 2))
|
||||
if self.box_thresh > score:
|
||||
continue
|
||||
|
||||
if points.shape[0] > 2:
|
||||
box = self.unclip(points, self.unclip_ratio)
|
||||
if len(box) > 1:
|
||||
continue
|
||||
else:
|
||||
continue
|
||||
box = box.reshape(-1, 2)
|
||||
|
||||
_, sside = self.get_mini_boxes(box.reshape((-1, 1, 2)))
|
||||
if sside < self.min_size + 2:
|
||||
continue
|
||||
|
||||
box = np.array(box)
|
||||
box[:, 0] = np.clip(
|
||||
np.round(box[:, 0] / width * dest_width), 0, dest_width)
|
||||
box[:, 1] = np.clip(
|
||||
np.round(box[:, 1] / height * dest_height), 0, dest_height)
|
||||
boxes.append(box.tolist())
|
||||
scores.append(score)
|
||||
return boxes, scores
|
||||
|
||||
def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
|
||||
'''
|
||||
_bitmap: single map with shape (1, H, W),
|
||||
whose values are binarized as {0, 1}
|
||||
'''
|
||||
|
||||
bitmap = _bitmap
|
||||
height, width = bitmap.shape
|
||||
|
||||
outs = cv2.findContours((bitmap * 255).astype(np.uint8), cv2.RETR_LIST,
|
||||
cv2.CHAIN_APPROX_SIMPLE)
|
||||
if len(outs) == 3:
|
||||
_img, contours, _ = outs[0], outs[1], outs[2]
|
||||
elif len(outs) == 2:
|
||||
contours, _ = outs[0], outs[1]
|
||||
|
||||
num_contours = min(len(contours), self.max_candidates)
|
||||
|
||||
boxes = []
|
||||
scores = []
|
||||
for index in range(num_contours):
|
||||
contour = contours[index]
|
||||
points, sside = self.get_mini_boxes(contour)
|
||||
if sside < self.min_size:
|
||||
continue
|
||||
points = np.array(points)
|
||||
if self.score_mode == "fast":
|
||||
score = self.box_score_fast(pred, points.reshape(-1, 2))
|
||||
else:
|
||||
score = self.box_score_slow(pred, contour)
|
||||
if self.box_thresh > score:
|
||||
continue
|
||||
|
||||
box = self.unclip(points, self.unclip_ratio).reshape(-1, 1, 2)
|
||||
box, sside = self.get_mini_boxes(box)
|
||||
if sside < self.min_size + 2:
|
||||
continue
|
||||
box = np.array(box)
|
||||
|
||||
box[:, 0] = np.clip(
|
||||
np.round(box[:, 0] / width * dest_width), 0, dest_width)
|
||||
box[:, 1] = np.clip(
|
||||
np.round(box[:, 1] / height * dest_height), 0, dest_height)
|
||||
boxes.append(box.astype("int32"))
|
||||
scores.append(score)
|
||||
return np.array(boxes, dtype="int32"), scores
|
||||
|
||||
def unclip(self, box, unclip_ratio):
|
||||
poly = Polygon(box)
|
||||
distance = poly.area * unclip_ratio / poly.length
|
||||
offset = pyclipper.PyclipperOffset()
|
||||
offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
|
||||
expanded = np.array(offset.Execute(distance))
|
||||
return expanded
|
||||
|
||||
def get_mini_boxes(self, contour):
|
||||
bounding_box = cv2.minAreaRect(contour)
|
||||
points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])
|
||||
|
||||
index_1, index_2, index_3, index_4 = 0, 1, 2, 3
|
||||
if points[1][1] > points[0][1]:
|
||||
index_1 = 0
|
||||
index_4 = 1
|
||||
else:
|
||||
index_1 = 1
|
||||
index_4 = 0
|
||||
if points[3][1] > points[2][1]:
|
||||
index_2 = 2
|
||||
index_3 = 3
|
||||
else:
|
||||
index_2 = 3
|
||||
index_3 = 2
|
||||
|
||||
box = [
|
||||
points[index_1], points[index_2], points[index_3], points[index_4]
|
||||
]
|
||||
return box, min(bounding_box[1])
|
||||
|
||||
def box_score_fast(self, bitmap, _box):
|
||||
'''
|
||||
box_score_fast: use bbox mean score as the mean score
|
||||
'''
|
||||
h, w = bitmap.shape[:2]
|
||||
box = _box.copy()
|
||||
xmin = np.clip(np.floor(box[:, 0].min()).astype("int32"), 0, w - 1)
|
||||
xmax = np.clip(np.ceil(box[:, 0].max()).astype("int32"), 0, w - 1)
|
||||
ymin = np.clip(np.floor(box[:, 1].min()).astype("int32"), 0, h - 1)
|
||||
ymax = np.clip(np.ceil(box[:, 1].max()).astype("int32"), 0, h - 1)
|
||||
|
||||
mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
|
||||
box[:, 0] = box[:, 0] - xmin
|
||||
box[:, 1] = box[:, 1] - ymin
|
||||
cv2.fillPoly(mask, box.reshape(1, -1, 2).astype("int32"), 1)
|
||||
return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]
|
||||
|
||||
def box_score_slow(self, bitmap, contour):
|
||||
'''
|
||||
box_score_slow: use polyon mean score as the mean score
|
||||
'''
|
||||
h, w = bitmap.shape[:2]
|
||||
contour = contour.copy()
|
||||
contour = np.reshape(contour, (-1, 2))
|
||||
|
||||
xmin = np.clip(np.min(contour[:, 0]), 0, w - 1)
|
||||
xmax = np.clip(np.max(contour[:, 0]), 0, w - 1)
|
||||
ymin = np.clip(np.min(contour[:, 1]), 0, h - 1)
|
||||
ymax = np.clip(np.max(contour[:, 1]), 0, h - 1)
|
||||
|
||||
mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
|
||||
|
||||
contour[:, 0] = contour[:, 0] - xmin
|
||||
contour[:, 1] = contour[:, 1] - ymin
|
||||
|
||||
cv2.fillPoly(mask, contour.reshape(1, -1, 2).astype("int32"), 1)
|
||||
return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]
|
||||
|
||||
def __call__(self, outs_dict, shape_list):
|
||||
pred = outs_dict['maps']
|
||||
if not isinstance(pred, np.ndarray):
|
||||
pred = pred.numpy()
|
||||
pred = pred[:, 0, :, :]
|
||||
segmentation = pred > self.thresh
|
||||
|
||||
boxes_batch = []
|
||||
for batch_index in range(pred.shape[0]):
|
||||
src_h, src_w, ratio_h, ratio_w = shape_list[batch_index]
|
||||
if self.dilation_kernel is not None:
|
||||
mask = cv2.dilate(
|
||||
np.array(segmentation[batch_index]).astype(np.uint8),
|
||||
self.dilation_kernel)
|
||||
else:
|
||||
mask = segmentation[batch_index]
|
||||
if self.box_type == 'poly':
|
||||
boxes, scores = self.polygons_from_bitmap(pred[batch_index],
|
||||
mask, src_w, src_h)
|
||||
elif self.box_type == 'quad':
|
||||
boxes, scores = self.boxes_from_bitmap(pred[batch_index], mask,
|
||||
src_w, src_h)
|
||||
else:
|
||||
raise ValueError(
|
||||
"box_type can only be one of ['quad', 'poly']")
|
||||
|
||||
boxes_batch.append({'points': boxes})
|
||||
return boxes_batch
|
||||
|
||||
|
||||
class BaseRecLabelDecode:
|
||||
""" Convert between text-label and text-index """
|
||||
|
||||
def __init__(self, character_dict_path=None, use_space_char=False):
|
||||
self.beg_str = "sos"
|
||||
self.end_str = "eos"
|
||||
self.reverse = False
|
||||
self.character_str = []
|
||||
|
||||
if character_dict_path is None:
|
||||
self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
|
||||
dict_character = list(self.character_str)
|
||||
else:
|
||||
with open(character_dict_path, "rb") as fin:
|
||||
lines = fin.readlines()
|
||||
for line in lines:
|
||||
line = line.decode('utf-8').strip("\n").strip("\r\n")
|
||||
self.character_str.append(line)
|
||||
if use_space_char:
|
||||
self.character_str.append(" ")
|
||||
dict_character = list(self.character_str)
|
||||
if 'arabic' in character_dict_path:
|
||||
self.reverse = True
|
||||
|
||||
dict_character = self.add_special_char(dict_character)
|
||||
self.dict = {}
|
||||
for i, char in enumerate(dict_character):
|
||||
self.dict[char] = i
|
||||
self.character = dict_character
|
||||
|
||||
def pred_reverse(self, pred):
|
||||
pred_re = []
|
||||
c_current = ''
|
||||
for c in pred:
|
||||
if not bool(re.search('[a-zA-Z0-9 :*./%+-]', c)):
|
||||
if c_current != '':
|
||||
pred_re.append(c_current)
|
||||
pred_re.append(c)
|
||||
c_current = ''
|
||||
else:
|
||||
c_current += c
|
||||
if c_current != '':
|
||||
pred_re.append(c_current)
|
||||
|
||||
return ''.join(pred_re[::-1])
|
||||
|
||||
def add_special_char(self, dict_character):
|
||||
return dict_character
|
||||
|
||||
def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
|
||||
""" convert text-index into text-label. """
|
||||
result_list = []
|
||||
ignored_tokens = self.get_ignored_tokens()
|
||||
batch_size = len(text_index)
|
||||
for batch_idx in range(batch_size):
|
||||
selection = np.ones(len(text_index[batch_idx]), dtype=bool)
|
||||
if is_remove_duplicate:
|
||||
selection[1:] = text_index[batch_idx][1:] != text_index[
|
||||
batch_idx][:-1]
|
||||
for ignored_token in ignored_tokens:
|
||||
selection &= text_index[batch_idx] != ignored_token
|
||||
|
||||
char_list = [
|
||||
self.character[text_id]
|
||||
for text_id in text_index[batch_idx][selection]
|
||||
]
|
||||
if text_prob is not None:
|
||||
conf_list = text_prob[batch_idx][selection]
|
||||
else:
|
||||
conf_list = [1] * len(selection)
|
||||
if len(conf_list) == 0:
|
||||
conf_list = [0]
|
||||
|
||||
text = ''.join(char_list)
|
||||
|
||||
if self.reverse: # for arabic rec
|
||||
text = self.pred_reverse(text)
|
||||
|
||||
result_list.append((text, np.mean(conf_list).tolist()))
|
||||
return result_list
|
||||
|
||||
def get_ignored_tokens(self):
|
||||
return [0] # for ctc blank
|
||||
|
||||
|
||||
class CTCLabelDecode(BaseRecLabelDecode):
|
||||
""" Convert between text-label and text-index """
|
||||
|
||||
def __init__(self, character_dict_path=None, use_space_char=False,
|
||||
**kwargs):
|
||||
super(CTCLabelDecode, self).__init__(character_dict_path,
|
||||
use_space_char)
|
||||
|
||||
def __call__(self, preds, label=None, *args, **kwargs):
|
||||
if isinstance(preds, tuple) or isinstance(preds, list):
|
||||
preds = preds[-1]
|
||||
if not isinstance(preds, np.ndarray):
|
||||
preds = preds.numpy()
|
||||
preds_idx = preds.argmax(axis=2)
|
||||
preds_prob = preds.max(axis=2)
|
||||
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True)
|
||||
if label is None:
|
||||
return text
|
||||
label = self.decode(label)
|
||||
return text, label
|
||||
|
||||
def add_special_char(self, dict_character):
|
||||
dict_character = ['blank'] + dict_character
|
||||
return dict_character
|
||||
427
api/app/core/rag/deepdoc/vision/recognizer.py
Normal file
427
api/app/core/rag/deepdoc/vision/recognizer.py
Normal file
@@ -0,0 +1,427 @@
|
||||
import gc
|
||||
import logging
|
||||
import os
|
||||
import math
|
||||
import numpy as np
|
||||
import cv2
|
||||
from functools import cmp_to_key
|
||||
|
||||
|
||||
from app.core.rag.common.file_utils import get_project_base_directory
|
||||
from .operators import * # noqa: F403
|
||||
from .operators import preprocess
|
||||
from . import operators
|
||||
from .ocr import load_model
|
||||
|
||||
class Recognizer:
|
||||
def __init__(self, label_list, task_name, model_dir=None):
|
||||
"""
|
||||
If you have trouble downloading HuggingFace models, -_^ this might help!!
|
||||
|
||||
For Linux:
|
||||
export HF_ENDPOINT=https://hf-mirror.com
|
||||
|
||||
For Windows:
|
||||
Good luck
|
||||
^_-
|
||||
|
||||
"""
|
||||
if not model_dir:
|
||||
model_dir = os.path.join(
|
||||
get_project_base_directory(),
|
||||
"res/deepdoc")
|
||||
self.ort_sess, self.run_options = load_model(model_dir, task_name)
|
||||
self.input_names = [node.name for node in self.ort_sess.get_inputs()]
|
||||
self.output_names = [node.name for node in self.ort_sess.get_outputs()]
|
||||
self.input_shape = self.ort_sess.get_inputs()[0].shape[2:4]
|
||||
self.label_list = label_list
|
||||
|
||||
@staticmethod
|
||||
def sort_Y_firstly(arr, threshold):
|
||||
def cmp(c1, c2):
|
||||
diff = c1["top"] - c2["top"]
|
||||
if abs(diff) < threshold:
|
||||
diff = c1["x0"] - c2["x0"]
|
||||
return diff
|
||||
arr = sorted(arr, key=cmp_to_key(cmp))
|
||||
return arr
|
||||
|
||||
@staticmethod
|
||||
def sort_X_firstly(arr, threshold):
|
||||
def cmp(c1, c2):
|
||||
diff = c1["x0"] - c2["x0"]
|
||||
if abs(diff) < threshold:
|
||||
diff = c1["top"] - c2["top"]
|
||||
return diff
|
||||
arr = sorted(arr, key=cmp_to_key(cmp))
|
||||
return arr
|
||||
|
||||
@staticmethod
|
||||
def sort_C_firstly(arr, thr=0):
|
||||
# sort using y1 first and then x1
|
||||
# sorted(arr, key=lambda r: (r["x0"], r["top"]))
|
||||
arr = Recognizer.sort_X_firstly(arr, thr)
|
||||
for i in range(len(arr) - 1):
|
||||
for j in range(i, -1, -1):
|
||||
# restore the order using th
|
||||
if "C" not in arr[j] or "C" not in arr[j + 1]:
|
||||
continue
|
||||
if arr[j + 1]["C"] < arr[j]["C"] \
|
||||
or (
|
||||
arr[j + 1]["C"] == arr[j]["C"]
|
||||
and arr[j + 1]["top"] < arr[j]["top"]
|
||||
):
|
||||
tmp = arr[j]
|
||||
arr[j] = arr[j + 1]
|
||||
arr[j + 1] = tmp
|
||||
return arr
|
||||
|
||||
@staticmethod
|
||||
def sort_R_firstly(arr, thr=0):
|
||||
# sort using y1 first and then x1
|
||||
# sorted(arr, key=lambda r: (r["top"], r["x0"]))
|
||||
arr = Recognizer.sort_Y_firstly(arr, thr)
|
||||
for i in range(len(arr) - 1):
|
||||
for j in range(i, -1, -1):
|
||||
if "R" not in arr[j] or "R" not in arr[j + 1]:
|
||||
continue
|
||||
if arr[j + 1]["R"] < arr[j]["R"] \
|
||||
or (
|
||||
arr[j + 1]["R"] == arr[j]["R"]
|
||||
and arr[j + 1]["x0"] < arr[j]["x0"]
|
||||
):
|
||||
tmp = arr[j]
|
||||
arr[j] = arr[j + 1]
|
||||
arr[j + 1] = tmp
|
||||
return arr
|
||||
|
||||
@staticmethod
|
||||
def overlapped_area(a, b, ratio=True):
|
||||
tp, btm, x0, x1 = a["top"], a["bottom"], a["x0"], a["x1"]
|
||||
if b["x0"] > x1 or b["x1"] < x0:
|
||||
return 0
|
||||
if b["bottom"] < tp or b["top"] > btm:
|
||||
return 0
|
||||
x0_ = max(b["x0"], x0)
|
||||
x1_ = min(b["x1"], x1)
|
||||
assert x0_ <= x1_, "Bbox mismatch! T:{},B:{},X0:{},X1:{} ==> {}".format(
|
||||
tp, btm, x0, x1, b)
|
||||
tp_ = max(b["top"], tp)
|
||||
btm_ = min(b["bottom"], btm)
|
||||
assert tp_ <= btm_, "Bbox mismatch! T:{},B:{},X0:{},X1:{} => {}".format(
|
||||
tp, btm, x0, x1, b)
|
||||
ov = (btm_ - tp_) * (x1_ - x0_) if x1 - \
|
||||
x0 != 0 and btm - tp != 0 else 0
|
||||
if ov > 0 and ratio:
|
||||
ov /= (x1 - x0) * (btm - tp)
|
||||
return ov
|
||||
|
||||
@staticmethod
|
||||
def layouts_cleanup(boxes, layouts, far=2, thr=0.7):
|
||||
def not_overlapped(a, b):
|
||||
return any([a["x1"] < b["x0"],
|
||||
a["x0"] > b["x1"],
|
||||
a["bottom"] < b["top"],
|
||||
a["top"] > b["bottom"]])
|
||||
|
||||
i = 0
|
||||
while i + 1 < len(layouts):
|
||||
j = i + 1
|
||||
while j < min(i + far, len(layouts)) \
|
||||
and (layouts[i].get("type", "") != layouts[j].get("type", "")
|
||||
or not_overlapped(layouts[i], layouts[j])):
|
||||
j += 1
|
||||
if j >= min(i + far, len(layouts)):
|
||||
i += 1
|
||||
continue
|
||||
if Recognizer.overlapped_area(layouts[i], layouts[j]) < thr \
|
||||
and Recognizer.overlapped_area(layouts[j], layouts[i]) < thr:
|
||||
i += 1
|
||||
continue
|
||||
|
||||
if layouts[i].get("score") and layouts[j].get("score"):
|
||||
if layouts[i]["score"] > layouts[j]["score"]:
|
||||
layouts.pop(j)
|
||||
else:
|
||||
layouts.pop(i)
|
||||
continue
|
||||
|
||||
area_i, area_i_1 = 0, 0
|
||||
for b in boxes:
|
||||
if not not_overlapped(b, layouts[i]):
|
||||
area_i += Recognizer.overlapped_area(b, layouts[i], False)
|
||||
if not not_overlapped(b, layouts[j]):
|
||||
area_i_1 += Recognizer.overlapped_area(b, layouts[j], False)
|
||||
|
||||
if area_i > area_i_1:
|
||||
layouts.pop(j)
|
||||
else:
|
||||
layouts.pop(i)
|
||||
|
||||
return layouts
|
||||
|
||||
def create_inputs(self, imgs, im_info):
|
||||
"""generate input for different model type
|
||||
Args:
|
||||
imgs (list(numpy)): list of images (np.ndarray)
|
||||
im_info (list(dict)): list of image info
|
||||
Returns:
|
||||
inputs (dict): input of model
|
||||
"""
|
||||
inputs = {}
|
||||
|
||||
im_shape = []
|
||||
scale_factor = []
|
||||
if len(imgs) == 1:
|
||||
inputs['image'] = np.array((imgs[0],)).astype('float32')
|
||||
inputs['im_shape'] = np.array(
|
||||
(im_info[0]['im_shape'],)).astype('float32')
|
||||
inputs['scale_factor'] = np.array(
|
||||
(im_info[0]['scale_factor'],)).astype('float32')
|
||||
return inputs
|
||||
|
||||
im_shape = np.array([info['im_shape'] for info in im_info], dtype='float32')
|
||||
scale_factor = np.array([info['scale_factor'] for info in im_info], dtype='float32')
|
||||
|
||||
inputs['im_shape'] = np.concatenate(im_shape, axis=0)
|
||||
inputs['scale_factor'] = np.concatenate(scale_factor, axis=0)
|
||||
|
||||
imgs_shape = [[e.shape[1], e.shape[2]] for e in imgs]
|
||||
max_shape_h = max([e[0] for e in imgs_shape])
|
||||
max_shape_w = max([e[1] for e in imgs_shape])
|
||||
padding_imgs = []
|
||||
for img in imgs:
|
||||
im_c, im_h, im_w = img.shape[:]
|
||||
padding_im = np.zeros(
|
||||
(im_c, max_shape_h, max_shape_w), dtype=np.float32)
|
||||
padding_im[:, :im_h, :im_w] = img
|
||||
padding_imgs.append(padding_im)
|
||||
inputs['image'] = np.stack(padding_imgs, axis=0)
|
||||
return inputs
|
||||
|
||||
@staticmethod
|
||||
def find_overlapped(box, boxes_sorted_by_y, naive=False):
|
||||
if not boxes_sorted_by_y:
|
||||
return
|
||||
bxs = boxes_sorted_by_y
|
||||
s, e, ii = 0, len(bxs), 0
|
||||
while s < e and not naive:
|
||||
ii = (e + s) // 2
|
||||
pv = bxs[ii]
|
||||
if box["bottom"] < pv["top"]:
|
||||
e = ii
|
||||
continue
|
||||
if box["top"] > pv["bottom"]:
|
||||
s = ii + 1
|
||||
continue
|
||||
break
|
||||
while s < ii:
|
||||
if box["top"] > bxs[s]["bottom"]:
|
||||
s += 1
|
||||
break
|
||||
while e - 1 > ii:
|
||||
if box["bottom"] < bxs[e - 1]["top"]:
|
||||
e -= 1
|
||||
break
|
||||
|
||||
max_overlapped_i, max_overlapped = None, 0
|
||||
for i in range(s, e):
|
||||
ov = Recognizer.overlapped_area(bxs[i], box)
|
||||
if ov <= max_overlapped:
|
||||
continue
|
||||
max_overlapped_i = i
|
||||
max_overlapped = ov
|
||||
|
||||
return max_overlapped_i
|
||||
|
||||
@staticmethod
|
||||
def find_horizontally_tightest_fit(box, boxes):
|
||||
if not boxes:
|
||||
return
|
||||
min_dis, min_i = 1000000, None
|
||||
for i,b in enumerate(boxes):
|
||||
if box.get("layoutno", "0") != b.get("layoutno", "0"):
|
||||
continue
|
||||
dis = min(abs(box["x0"] - b["x0"]), abs(box["x1"] - b["x1"]), abs(box["x0"]+box["x1"] - b["x1"] - b["x0"])/2)
|
||||
if dis < min_dis:
|
||||
min_i = i
|
||||
min_dis = dis
|
||||
return min_i
|
||||
|
||||
@staticmethod
|
||||
def find_overlapped_with_threshold(box, boxes, thr=0.3):
|
||||
if not boxes:
|
||||
return
|
||||
max_overlapped_i, max_overlapped, _max_overlapped = None, thr, 0
|
||||
s, e = 0, len(boxes)
|
||||
for i in range(s, e):
|
||||
ov = Recognizer.overlapped_area(box, boxes[i])
|
||||
_ov = Recognizer.overlapped_area(boxes[i], box)
|
||||
if (ov, _ov) < (max_overlapped, _max_overlapped):
|
||||
continue
|
||||
max_overlapped_i = i
|
||||
max_overlapped = ov
|
||||
_max_overlapped = _ov
|
||||
|
||||
return max_overlapped_i
|
||||
|
||||
def preprocess(self, image_list):
|
||||
inputs = []
|
||||
if "scale_factor" in self.input_names:
|
||||
preprocess_ops = []
|
||||
for op_info in [
|
||||
{'interp': 2, 'keep_ratio': False, 'target_size': [800, 608], 'type': 'LinearResize'},
|
||||
{'is_scale': True, 'mean': [0.485, 0.456, 0.406], 'std': [0.229, 0.224, 0.225], 'type': 'StandardizeImage'},
|
||||
{'type': 'Permute'},
|
||||
{'stride': 32, 'type': 'PadStride'}
|
||||
]:
|
||||
new_op_info = op_info.copy()
|
||||
op_type = new_op_info.pop('type')
|
||||
preprocess_ops.append(getattr(operators, op_type)(**new_op_info))
|
||||
|
||||
for im_path in image_list:
|
||||
im, im_info = preprocess(im_path, preprocess_ops)
|
||||
inputs.append({"image": np.array((im,)).astype('float32'),
|
||||
"scale_factor": np.array((im_info["scale_factor"],)).astype('float32')})
|
||||
else:
|
||||
hh, ww = self.input_shape
|
||||
for img in image_list:
|
||||
h, w = img.shape[:2]
|
||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
||||
img = cv2.resize(np.array(img).astype('float32'), (ww, hh))
|
||||
# Scale input pixel values to 0 to 1
|
||||
img /= 255.0
|
||||
img = img.transpose(2, 0, 1)
|
||||
img = img[np.newaxis, :, :, :].astype(np.float32)
|
||||
inputs.append({self.input_names[0]: img, "scale_factor": [w/ww, h/hh]})
|
||||
return inputs
|
||||
|
||||
def postprocess(self, boxes, inputs, thr):
|
||||
if "scale_factor" in self.input_names:
|
||||
bb = []
|
||||
for b in boxes:
|
||||
clsid, bbox, score = int(b[0]), b[2:], b[1]
|
||||
if score < thr:
|
||||
continue
|
||||
if clsid >= len(self.label_list):
|
||||
continue
|
||||
bb.append({
|
||||
"type": self.label_list[clsid].lower(),
|
||||
"bbox": [float(t) for t in bbox.tolist()],
|
||||
"score": float(score)
|
||||
})
|
||||
return bb
|
||||
|
||||
def xywh2xyxy(x):
|
||||
# [x, y, w, h] to [x1, y1, x2, y2]
|
||||
y = np.copy(x)
|
||||
y[:, 0] = x[:, 0] - x[:, 2] / 2
|
||||
y[:, 1] = x[:, 1] - x[:, 3] / 2
|
||||
y[:, 2] = x[:, 0] + x[:, 2] / 2
|
||||
y[:, 3] = x[:, 1] + x[:, 3] / 2
|
||||
return y
|
||||
|
||||
def compute_iou(box, boxes):
|
||||
# Compute xmin, ymin, xmax, ymax for both boxes
|
||||
xmin = np.maximum(box[0], boxes[:, 0])
|
||||
ymin = np.maximum(box[1], boxes[:, 1])
|
||||
xmax = np.minimum(box[2], boxes[:, 2])
|
||||
ymax = np.minimum(box[3], boxes[:, 3])
|
||||
|
||||
# Compute intersection area
|
||||
intersection_area = np.maximum(0, xmax - xmin) * np.maximum(0, ymax - ymin)
|
||||
|
||||
# Compute union area
|
||||
box_area = (box[2] - box[0]) * (box[3] - box[1])
|
||||
boxes_area = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
|
||||
union_area = box_area + boxes_area - intersection_area
|
||||
|
||||
# Compute IoU
|
||||
iou = intersection_area / union_area
|
||||
|
||||
return iou
|
||||
|
||||
def iou_filter(boxes, scores, iou_threshold):
|
||||
sorted_indices = np.argsort(scores)[::-1]
|
||||
|
||||
keep_boxes = []
|
||||
while sorted_indices.size > 0:
|
||||
# Pick the last box
|
||||
box_id = sorted_indices[0]
|
||||
keep_boxes.append(box_id)
|
||||
|
||||
# Compute IoU of the picked box with the rest
|
||||
ious = compute_iou(boxes[box_id, :], boxes[sorted_indices[1:], :])
|
||||
|
||||
# Remove boxes with IoU over the threshold
|
||||
keep_indices = np.where(ious < iou_threshold)[0]
|
||||
|
||||
# print(keep_indices.shape, sorted_indices.shape)
|
||||
sorted_indices = sorted_indices[keep_indices + 1]
|
||||
|
||||
return keep_boxes
|
||||
|
||||
boxes = np.squeeze(boxes).T
|
||||
# Filter out object confidence scores below threshold
|
||||
scores = np.max(boxes[:, 4:], axis=1)
|
||||
boxes = boxes[scores > thr, :]
|
||||
scores = scores[scores > thr]
|
||||
if len(boxes) == 0:
|
||||
return []
|
||||
|
||||
# Get the class with the highest confidence
|
||||
class_ids = np.argmax(boxes[:, 4:], axis=1)
|
||||
boxes = boxes[:, :4]
|
||||
input_shape = np.array([inputs["scale_factor"][0], inputs["scale_factor"][1], inputs["scale_factor"][0], inputs["scale_factor"][1]])
|
||||
boxes = np.multiply(boxes, input_shape, dtype=np.float32)
|
||||
boxes = xywh2xyxy(boxes)
|
||||
|
||||
unique_class_ids = np.unique(class_ids)
|
||||
indices = []
|
||||
for class_id in unique_class_ids:
|
||||
class_indices = np.where(class_ids == class_id)[0]
|
||||
class_boxes = boxes[class_indices, :]
|
||||
class_scores = scores[class_indices]
|
||||
class_keep_boxes = iou_filter(class_boxes, class_scores, 0.2)
|
||||
indices.extend(class_indices[class_keep_boxes])
|
||||
|
||||
return [{
|
||||
"type": self.label_list[class_ids[i]].lower(),
|
||||
"bbox": [float(t) for t in boxes[i].tolist()],
|
||||
"score": float(scores[i])
|
||||
} for i in indices]
|
||||
|
||||
def close(self):
|
||||
logging.info("Close recognizer.")
|
||||
if hasattr(self, "ort_sess"):
|
||||
del self.ort_sess
|
||||
gc.collect()
|
||||
|
||||
def __call__(self, image_list, thr=0.7, batch_size=16):
|
||||
res = []
|
||||
images = []
|
||||
for i in range(len(image_list)):
|
||||
if not isinstance(image_list[i], np.ndarray):
|
||||
images.append(np.array(image_list[i]))
|
||||
else:
|
||||
images.append(image_list[i])
|
||||
|
||||
batch_loop_cnt = math.ceil(float(len(images)) / batch_size)
|
||||
for i in range(batch_loop_cnt):
|
||||
start_index = i * batch_size
|
||||
end_index = min((i + 1) * batch_size, len(images))
|
||||
batch_image_list = images[start_index:end_index]
|
||||
inputs = self.preprocess(batch_image_list)
|
||||
logging.debug("preprocess")
|
||||
for ins in inputs:
|
||||
bb = self.postprocess(self.ort_sess.run(None, {k:v for k,v in ins.items() if k in self.input_names}, self.run_options)[0], ins, thr)
|
||||
res.append(bb)
|
||||
|
||||
#seeit.save_results(image_list, res, self.label_list, threshold=thr)
|
||||
|
||||
return res
|
||||
|
||||
def __del__(self):
|
||||
self.close()
|
||||
|
||||
|
||||
71
api/app/core/rag/deepdoc/vision/seeit.py
Normal file
71
api/app/core/rag/deepdoc/vision/seeit.py
Normal file
@@ -0,0 +1,71 @@
|
||||
import logging
|
||||
import os
|
||||
import PIL
|
||||
from PIL import ImageDraw
|
||||
|
||||
|
||||
def save_results(image_list, results, labels, output_dir='output/', threshold=0.5):
|
||||
if not os.path.exists(output_dir):
|
||||
os.makedirs(output_dir)
|
||||
for idx, im in enumerate(image_list):
|
||||
im = draw_box(im, results[idx], labels, threshold=threshold)
|
||||
|
||||
out_path = os.path.join(output_dir, f"{idx}.jpg")
|
||||
im.save(out_path, quality=95)
|
||||
logging.debug("save result to: " + out_path)
|
||||
|
||||
|
||||
def draw_box(im, result, labels, threshold=0.5):
|
||||
draw_thickness = min(im.size) // 320
|
||||
draw = ImageDraw.Draw(im)
|
||||
color_list = get_color_map_list(len(labels))
|
||||
clsid2color = {n.lower():color_list[i] for i,n in enumerate(labels)}
|
||||
result = [r for r in result if r["score"] >= threshold]
|
||||
|
||||
for dt in result:
|
||||
color = tuple(clsid2color[dt["type"]])
|
||||
xmin, ymin, xmax, ymax = dt["bbox"]
|
||||
draw.line(
|
||||
[(xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin),
|
||||
(xmin, ymin)],
|
||||
width=draw_thickness,
|
||||
fill=color)
|
||||
|
||||
# draw label
|
||||
text = "{} {:.4f}".format(dt["type"], dt["score"])
|
||||
tw, th = imagedraw_textsize_c(draw, text)
|
||||
draw.rectangle(
|
||||
[(xmin + 1, ymin - th), (xmin + tw + 1, ymin)], fill=color)
|
||||
draw.text((xmin + 1, ymin - th), text, fill=(255, 255, 255))
|
||||
return im
|
||||
|
||||
|
||||
def get_color_map_list(num_classes):
|
||||
"""
|
||||
Args:
|
||||
num_classes (int): number of class
|
||||
Returns:
|
||||
color_map (list): RGB color list
|
||||
"""
|
||||
color_map = num_classes * [0, 0, 0]
|
||||
for i in range(0, num_classes):
|
||||
j = 0
|
||||
lab = i
|
||||
while lab:
|
||||
color_map[i * 3] |= (((lab >> 0) & 1) << (7 - j))
|
||||
color_map[i * 3 + 1] |= (((lab >> 1) & 1) << (7 - j))
|
||||
color_map[i * 3 + 2] |= (((lab >> 2) & 1) << (7 - j))
|
||||
j += 1
|
||||
lab >>= 3
|
||||
color_map = [color_map[i:i + 3] for i in range(0, len(color_map), 3)]
|
||||
return color_map
|
||||
|
||||
|
||||
def imagedraw_textsize_c(draw, text):
|
||||
if int(PIL.__version__.split('.')[0]) < 10:
|
||||
tw, th = draw.textsize(text)
|
||||
else:
|
||||
left, top, right, bottom = draw.textbbox((0, 0), text)
|
||||
tw, th = right - left, bottom - top
|
||||
|
||||
return tw, th
|
||||
77
api/app/core/rag/deepdoc/vision/t_ocr.py
Normal file
77
api/app/core/rag/deepdoc/vision/t_ocr.py
Normal file
@@ -0,0 +1,77 @@
|
||||
import os
|
||||
import sys
|
||||
sys.path.insert(
|
||||
0,
|
||||
os.path.abspath(
|
||||
os.path.join(
|
||||
os.path.dirname(
|
||||
os.path.abspath(__file__)),
|
||||
'../../')))
|
||||
|
||||
from .seeit import draw_box
|
||||
from . import OCR, init_in_out
|
||||
import argparse
|
||||
import numpy as np
|
||||
import trio
|
||||
|
||||
# os.environ['CUDA_VISIBLE_DEVICES'] = '0,2' #2 gpus, uncontinuous
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = '0' #1 gpu
|
||||
# os.environ['CUDA_VISIBLE_DEVICES'] = '' #cpu
|
||||
|
||||
|
||||
def main(args):
|
||||
import torch.cuda
|
||||
|
||||
cuda_devices = torch.cuda.device_count()
|
||||
limiter = [trio.CapacityLimiter(1) for _ in range(cuda_devices)] if cuda_devices > 1 else None
|
||||
ocr = OCR()
|
||||
images, outputs = init_in_out(args)
|
||||
|
||||
def __ocr(i, id, img):
|
||||
print("Task {} start".format(i))
|
||||
bxs = ocr(np.array(img), id)
|
||||
bxs = [(line[0], line[1][0]) for line in bxs]
|
||||
bxs = [{
|
||||
"text": t,
|
||||
"bbox": [b[0][0], b[0][1], b[1][0], b[-1][1]],
|
||||
"type": "ocr",
|
||||
"score": 1} for b, t in bxs if b[0][0] <= b[1][0] and b[0][1] <= b[-1][1]]
|
||||
img = draw_box(images[i], bxs, ["ocr"], 1.)
|
||||
img.save(outputs[i], quality=95)
|
||||
with open(outputs[i] + ".txt", "w+", encoding='utf-8') as f:
|
||||
f.write("\n".join([o["text"] for o in bxs]))
|
||||
|
||||
print("Task {} done".format(i))
|
||||
|
||||
async def __ocr_thread(i, id, img, limiter = None):
|
||||
if limiter:
|
||||
async with limiter:
|
||||
print("Task {} use device {}".format(i, id))
|
||||
await trio.to_thread.run_sync(lambda: __ocr(i, id, img))
|
||||
else:
|
||||
__ocr(i, id, img)
|
||||
|
||||
async def __ocr_launcher():
|
||||
if cuda_devices > 1:
|
||||
async with trio.open_nursery() as nursery:
|
||||
for i, img in enumerate(images):
|
||||
nursery.start_soon(__ocr_thread, i, i % cuda_devices, img, limiter[i % cuda_devices])
|
||||
await trio.sleep(0.1)
|
||||
else:
|
||||
for i, img in enumerate(images):
|
||||
await __ocr_thread(i, 0, img)
|
||||
|
||||
trio.run(__ocr_launcher)
|
||||
|
||||
print("OCR tasks are all done")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--inputs',
|
||||
help="Directory where to store images or PDFs, or a file path to a single image or PDF",
|
||||
required=True)
|
||||
parser.add_argument('--output_dir', help="Directory where to store the output images. Default: './ocr_outputs'",
|
||||
default="./ocr_outputs")
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
170
api/app/core/rag/deepdoc/vision/t_recognizer.py
Normal file
170
api/app/core/rag/deepdoc/vision/t_recognizer.py
Normal file
@@ -0,0 +1,170 @@
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
|
||||
sys.path.insert(
|
||||
0,
|
||||
os.path.abspath(
|
||||
os.path.join(
|
||||
os.path.dirname(
|
||||
os.path.abspath(__file__)),
|
||||
'../../')))
|
||||
|
||||
from .seeit import draw_box
|
||||
from . import LayoutRecognizer, TableStructureRecognizer, OCR, init_in_out
|
||||
import argparse
|
||||
import re
|
||||
import numpy as np
|
||||
|
||||
|
||||
def main(args):
|
||||
images, outputs = init_in_out(args)
|
||||
if args.mode.lower() == "layout":
|
||||
detr = LayoutRecognizer("layout")
|
||||
layouts = detr.forward(images, thr=float(args.threshold))
|
||||
if args.mode.lower() == "tsr":
|
||||
detr = TableStructureRecognizer()
|
||||
ocr = OCR()
|
||||
layouts = detr(images, thr=float(args.threshold))
|
||||
for i, lyt in enumerate(layouts):
|
||||
if args.mode.lower() == "tsr":
|
||||
#lyt = [t for t in lyt if t["type"] == "table column"]
|
||||
html = get_table_html(images[i], lyt, ocr)
|
||||
with open(outputs[i] + ".html", "w+", encoding='utf-8') as f:
|
||||
f.write(html)
|
||||
lyt = [{
|
||||
"type": t["label"],
|
||||
"bbox": [t["x0"], t["top"], t["x1"], t["bottom"]],
|
||||
"score": t["score"]
|
||||
} for t in lyt]
|
||||
img = draw_box(images[i], lyt, detr.labels, float(args.threshold))
|
||||
img.save(outputs[i], quality=95)
|
||||
logging.info("save result to: " + outputs[i])
|
||||
|
||||
|
||||
def get_table_html(img, tb_cpns, ocr):
|
||||
boxes = ocr(np.array(img))
|
||||
boxes = LayoutRecognizer.sort_Y_firstly(
|
||||
[{"x0": b[0][0], "x1": b[1][0],
|
||||
"top": b[0][1], "text": t[0],
|
||||
"bottom": b[-1][1],
|
||||
"layout_type": "table",
|
||||
"page_number": 0} for b, t in boxes if b[0][0] <= b[1][0] and b[0][1] <= b[-1][1]],
|
||||
np.mean([b[-1][1] - b[0][1] for b, _ in boxes]) / 3
|
||||
)
|
||||
|
||||
def gather(kwd, fzy=10, ption=0.6):
|
||||
nonlocal boxes
|
||||
eles = LayoutRecognizer.sort_Y_firstly(
|
||||
[r for r in tb_cpns if re.match(kwd, r["label"])], fzy)
|
||||
eles = LayoutRecognizer.layouts_cleanup(boxes, eles, 5, ption)
|
||||
return LayoutRecognizer.sort_Y_firstly(eles, 0)
|
||||
|
||||
headers = gather(r".*header$")
|
||||
rows = gather(r".* (row|header)")
|
||||
spans = gather(r".*spanning")
|
||||
clmns = sorted([r for r in tb_cpns if re.match(
|
||||
r"table column$", r["label"])], key=lambda x: x["x0"])
|
||||
clmns = LayoutRecognizer.layouts_cleanup(boxes, clmns, 5, 0.5)
|
||||
|
||||
for b in boxes:
|
||||
ii = LayoutRecognizer.find_overlapped_with_threshold(b, rows, thr=0.3)
|
||||
if ii is not None:
|
||||
b["R"] = ii
|
||||
b["R_top"] = rows[ii]["top"]
|
||||
b["R_bott"] = rows[ii]["bottom"]
|
||||
|
||||
ii = LayoutRecognizer.find_overlapped_with_threshold(b, headers, thr=0.3)
|
||||
if ii is not None:
|
||||
b["H_top"] = headers[ii]["top"]
|
||||
b["H_bott"] = headers[ii]["bottom"]
|
||||
b["H_left"] = headers[ii]["x0"]
|
||||
b["H_right"] = headers[ii]["x1"]
|
||||
b["H"] = ii
|
||||
|
||||
ii = LayoutRecognizer.find_horizontally_tightest_fit(b, clmns)
|
||||
if ii is not None:
|
||||
b["C"] = ii
|
||||
b["C_left"] = clmns[ii]["x0"]
|
||||
b["C_right"] = clmns[ii]["x1"]
|
||||
|
||||
ii = LayoutRecognizer.find_overlapped_with_threshold(b, spans, thr=0.3)
|
||||
if ii is not None:
|
||||
b["H_top"] = spans[ii]["top"]
|
||||
b["H_bott"] = spans[ii]["bottom"]
|
||||
b["H_left"] = spans[ii]["x0"]
|
||||
b["H_right"] = spans[ii]["x1"]
|
||||
b["SP"] = ii
|
||||
|
||||
html = """
|
||||
<html>
|
||||
<head>
|
||||
<style>
|
||||
._table_1nkzy_11 {
|
||||
margin: auto;
|
||||
width: 70%%;
|
||||
padding: 10px;
|
||||
}
|
||||
._table_1nkzy_11 p {
|
||||
margin-bottom: 50px;
|
||||
border: 1px solid #e1e1e1;
|
||||
}
|
||||
|
||||
caption {
|
||||
color: #6ac1ca;
|
||||
font-size: 20px;
|
||||
height: 50px;
|
||||
line-height: 50px;
|
||||
font-weight: 600;
|
||||
margin-bottom: 10px;
|
||||
}
|
||||
|
||||
._table_1nkzy_11 table {
|
||||
width: 100%%;
|
||||
border-collapse: collapse;
|
||||
}
|
||||
|
||||
th {
|
||||
color: #fff;
|
||||
background-color: #6ac1ca;
|
||||
}
|
||||
|
||||
td:hover {
|
||||
background: #c1e8e8;
|
||||
}
|
||||
|
||||
tr:nth-child(even) {
|
||||
background-color: #f2f2f2;
|
||||
}
|
||||
|
||||
._table_1nkzy_11 th,
|
||||
._table_1nkzy_11 td {
|
||||
text-align: center;
|
||||
border: 1px solid #ddd;
|
||||
padding: 8px;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
%s
|
||||
</body>
|
||||
</html>
|
||||
""" % TableStructureRecognizer.construct_table(boxes, html=True)
|
||||
return html
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--inputs',
|
||||
help="Directory where to store images or PDFs, or a file path to a single image or PDF",
|
||||
required=True)
|
||||
parser.add_argument('--output_dir', help="Directory where to store the output images. Default: './layouts_outputs'",
|
||||
default="./layouts_outputs")
|
||||
parser.add_argument(
|
||||
'--threshold',
|
||||
help="A threshold to filter out detections. Default: 0.5",
|
||||
default=0.5)
|
||||
parser.add_argument('--mode', help="Task mode: layout recognition or table structure recognition", choices=["layout", "tsr"],
|
||||
default="layout")
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
597
api/app/core/rag/deepdoc/vision/table_structure_recognizer.py
Normal file
597
api/app/core/rag/deepdoc/vision/table_structure_recognizer.py
Normal file
@@ -0,0 +1,597 @@
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from collections import Counter
|
||||
|
||||
import numpy as np
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
from app.core.rag.common.file_utils import get_project_base_directory
|
||||
from app.core.rag.nlp import rag_tokenizer
|
||||
|
||||
from .recognizer import Recognizer
|
||||
|
||||
|
||||
class TableStructureRecognizer(Recognizer):
|
||||
labels = [
|
||||
"table",
|
||||
"table column",
|
||||
"table row",
|
||||
"table column header",
|
||||
"table projected row header",
|
||||
"table spanning cell",
|
||||
]
|
||||
|
||||
def __init__(self):
|
||||
try:
|
||||
super().__init__(self.labels, "tsr", os.path.join(get_project_base_directory(), "res/deepdoc"))
|
||||
except Exception:
|
||||
super().__init__(
|
||||
self.labels,
|
||||
"tsr",
|
||||
snapshot_download(
|
||||
repo_id="InfiniFlow/deepdoc",
|
||||
local_dir=os.path.join(get_project_base_directory(), "res/deepdoc"),
|
||||
local_dir_use_symlinks=False,
|
||||
),
|
||||
)
|
||||
|
||||
def __call__(self, images, thr=0.2):
|
||||
table_structure_recognizer_type = os.getenv("TABLE_STRUCTURE_RECOGNIZER_TYPE", "onnx").lower()
|
||||
if table_structure_recognizer_type not in ["onnx", "ascend"]:
|
||||
raise RuntimeError("Unsupported table structure recognizer type.")
|
||||
|
||||
if table_structure_recognizer_type == "onnx":
|
||||
logging.debug("Using Onnx table structure recognizer")
|
||||
tbls = super().__call__(images, thr)
|
||||
else: # ascend
|
||||
logging.debug("Using Ascend table structure recognizer")
|
||||
tbls = self._run_ascend_tsr(images, thr)
|
||||
|
||||
res = []
|
||||
# align left&right for rows, align top&bottom for columns
|
||||
for tbl in tbls:
|
||||
lts = [
|
||||
{
|
||||
"label": b["type"],
|
||||
"score": b["score"],
|
||||
"x0": b["bbox"][0],
|
||||
"x1": b["bbox"][2],
|
||||
"top": b["bbox"][1],
|
||||
"bottom": b["bbox"][-1],
|
||||
}
|
||||
for b in tbl
|
||||
]
|
||||
if not lts:
|
||||
continue
|
||||
|
||||
left = [b["x0"] for b in lts if b["label"].find("row") > 0 or b["label"].find("header") > 0]
|
||||
right = [b["x1"] for b in lts if b["label"].find("row") > 0 or b["label"].find("header") > 0]
|
||||
if not left:
|
||||
continue
|
||||
left = np.mean(left) if len(left) > 4 else np.min(left)
|
||||
right = np.mean(right) if len(right) > 4 else np.max(right)
|
||||
for b in lts:
|
||||
if b["label"].find("row") > 0 or b["label"].find("header") > 0:
|
||||
if b["x0"] > left:
|
||||
b["x0"] = left
|
||||
if b["x1"] < right:
|
||||
b["x1"] = right
|
||||
|
||||
top = [b["top"] for b in lts if b["label"] == "table column"]
|
||||
bottom = [b["bottom"] for b in lts if b["label"] == "table column"]
|
||||
if not top:
|
||||
res.append(lts)
|
||||
continue
|
||||
top = np.median(top) if len(top) > 4 else np.min(top)
|
||||
bottom = np.median(bottom) if len(bottom) > 4 else np.max(bottom)
|
||||
for b in lts:
|
||||
if b["label"] == "table column":
|
||||
if b["top"] > top:
|
||||
b["top"] = top
|
||||
if b["bottom"] < bottom:
|
||||
b["bottom"] = bottom
|
||||
|
||||
res.append(lts)
|
||||
return res
|
||||
|
||||
@staticmethod
|
||||
def is_caption(bx):
|
||||
patt = [r"[图表]+[ 0-9::]{2,}"]
|
||||
if any([re.match(p, bx["text"].strip()) for p in patt]) or bx.get("layout_type", "").find("caption") >= 0:
|
||||
return True
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def blockType(b):
|
||||
patt = [
|
||||
("^(20|19)[0-9]{2}[年/-][0-9]{1,2}[月/-][0-9]{1,2}日*$", "Dt"),
|
||||
(r"^(20|19)[0-9]{2}年$", "Dt"),
|
||||
(r"^(20|19)[0-9]{2}[年-][0-9]{1,2}月*$", "Dt"),
|
||||
("^[0-9]{1,2}[月-][0-9]{1,2}日*$", "Dt"),
|
||||
(r"^第*[一二三四1-4]季度$", "Dt"),
|
||||
(r"^(20|19)[0-9]{2}年*[一二三四1-4]季度$", "Dt"),
|
||||
(r"^(20|19)[0-9]{2}[ABCDE]$", "Dt"),
|
||||
("^[0-9.,+%/ -]+$", "Nu"),
|
||||
(r"^[0-9A-Z/\._~-]+$", "Ca"),
|
||||
(r"^[A-Z]*[a-z' -]+$", "En"),
|
||||
(r"^[0-9.,+-]+[0-9A-Za-z/$¥%<>()()' -]+$", "NE"),
|
||||
(r"^.{1}$", "Sg"),
|
||||
]
|
||||
for p, n in patt:
|
||||
if re.search(p, b["text"].strip()):
|
||||
return n
|
||||
tks = [t for t in rag_tokenizer.tokenize(b["text"]).split() if len(t) > 1]
|
||||
if len(tks) > 3:
|
||||
if len(tks) < 12:
|
||||
return "Tx"
|
||||
else:
|
||||
return "Lx"
|
||||
|
||||
if len(tks) == 1 and rag_tokenizer.tag(tks[0]) == "nr":
|
||||
return "Nr"
|
||||
|
||||
return "Ot"
|
||||
|
||||
@staticmethod
|
||||
def construct_table(boxes, is_english=False, html=True, **kwargs):
|
||||
cap = ""
|
||||
i = 0
|
||||
while i < len(boxes):
|
||||
if TableStructureRecognizer.is_caption(boxes[i]):
|
||||
if is_english:
|
||||
cap + " "
|
||||
cap += boxes[i]["text"]
|
||||
boxes.pop(i)
|
||||
i -= 1
|
||||
i += 1
|
||||
|
||||
if not boxes:
|
||||
return []
|
||||
for b in boxes:
|
||||
b["btype"] = TableStructureRecognizer.blockType(b)
|
||||
max_type = Counter([b["btype"] for b in boxes]).items()
|
||||
max_type = max(max_type, key=lambda x: x[1])[0] if max_type else ""
|
||||
logging.debug("MAXTYPE: " + max_type)
|
||||
|
||||
rowh = [b["R_bott"] - b["R_top"] for b in boxes if "R" in b]
|
||||
rowh = np.min(rowh) if rowh else 0
|
||||
boxes = Recognizer.sort_R_firstly(boxes, rowh / 2)
|
||||
# for b in boxes:print(b)
|
||||
boxes[0]["rn"] = 0
|
||||
rows = [[boxes[0]]]
|
||||
btm = boxes[0]["bottom"]
|
||||
for b in boxes[1:]:
|
||||
b["rn"] = len(rows) - 1
|
||||
lst_r = rows[-1]
|
||||
if lst_r[-1].get("R", "") != b.get("R", "") or (b["top"] >= btm - 3 and lst_r[-1].get("R", "-1") != b.get("R", "-2")): # new row
|
||||
btm = b["bottom"]
|
||||
b["rn"] += 1
|
||||
rows.append([b])
|
||||
continue
|
||||
btm = (btm + b["bottom"]) / 2.0
|
||||
rows[-1].append(b)
|
||||
|
||||
colwm = [b["C_right"] - b["C_left"] for b in boxes if "C" in b]
|
||||
colwm = np.min(colwm) if colwm else 0
|
||||
crosspage = len(set([b["page_number"] for b in boxes])) > 1
|
||||
if crosspage:
|
||||
boxes = Recognizer.sort_X_firstly(boxes, colwm / 2)
|
||||
else:
|
||||
boxes = Recognizer.sort_C_firstly(boxes, colwm / 2)
|
||||
boxes[0]["cn"] = 0
|
||||
cols = [[boxes[0]]]
|
||||
right = boxes[0]["x1"]
|
||||
for b in boxes[1:]:
|
||||
b["cn"] = len(cols) - 1
|
||||
lst_c = cols[-1]
|
||||
if (int(b.get("C", "1")) - int(lst_c[-1].get("C", "1")) == 1 and b["page_number"] == lst_c[-1]["page_number"]) or (
|
||||
b["x0"] >= right and lst_c[-1].get("C", "-1") != b.get("C", "-2")
|
||||
): # new col
|
||||
right = b["x1"]
|
||||
b["cn"] += 1
|
||||
cols.append([b])
|
||||
continue
|
||||
right = (right + b["x1"]) / 2.0
|
||||
cols[-1].append(b)
|
||||
|
||||
tbl = [[[] for _ in range(len(cols))] for _ in range(len(rows))]
|
||||
for b in boxes:
|
||||
tbl[b["rn"]][b["cn"]].append(b)
|
||||
|
||||
if len(rows) >= 4:
|
||||
# remove single in column
|
||||
j = 0
|
||||
while j < len(tbl[0]):
|
||||
e, ii = 0, 0
|
||||
for i in range(len(tbl)):
|
||||
if tbl[i][j]:
|
||||
e += 1
|
||||
ii = i
|
||||
if e > 1:
|
||||
break
|
||||
if e > 1:
|
||||
j += 1
|
||||
continue
|
||||
f = (j > 0 and tbl[ii][j - 1] and tbl[ii][j - 1][0].get("text")) or j == 0
|
||||
ff = (j + 1 < len(tbl[ii]) and tbl[ii][j + 1] and tbl[ii][j + 1][0].get("text")) or j + 1 >= len(tbl[ii])
|
||||
if f and ff:
|
||||
j += 1
|
||||
continue
|
||||
bx = tbl[ii][j][0]
|
||||
logging.debug("Relocate column single: " + bx["text"])
|
||||
# j column only has one value
|
||||
left, right = 100000, 100000
|
||||
if j > 0 and not f:
|
||||
for i in range(len(tbl)):
|
||||
if tbl[i][j - 1]:
|
||||
left = min(left, np.min([bx["x0"] - a["x1"] for a in tbl[i][j - 1]]))
|
||||
if j + 1 < len(tbl[0]) and not ff:
|
||||
for i in range(len(tbl)):
|
||||
if tbl[i][j + 1]:
|
||||
right = min(right, np.min([a["x0"] - bx["x1"] for a in tbl[i][j + 1]]))
|
||||
assert left < 100000 or right < 100000
|
||||
if left < right:
|
||||
for jj in range(j, len(tbl[0])):
|
||||
for i in range(len(tbl)):
|
||||
for a in tbl[i][jj]:
|
||||
a["cn"] -= 1
|
||||
if tbl[ii][j - 1]:
|
||||
tbl[ii][j - 1].extend(tbl[ii][j])
|
||||
else:
|
||||
tbl[ii][j - 1] = tbl[ii][j]
|
||||
for i in range(len(tbl)):
|
||||
tbl[i].pop(j)
|
||||
|
||||
else:
|
||||
for jj in range(j + 1, len(tbl[0])):
|
||||
for i in range(len(tbl)):
|
||||
for a in tbl[i][jj]:
|
||||
a["cn"] -= 1
|
||||
if tbl[ii][j + 1]:
|
||||
tbl[ii][j + 1].extend(tbl[ii][j])
|
||||
else:
|
||||
tbl[ii][j + 1] = tbl[ii][j]
|
||||
for i in range(len(tbl)):
|
||||
tbl[i].pop(j)
|
||||
cols.pop(j)
|
||||
assert len(cols) == len(tbl[0]), "Column NO. miss matched: %d vs %d" % (len(cols), len(tbl[0]))
|
||||
|
||||
if len(cols) >= 4:
|
||||
# remove single in row
|
||||
i = 0
|
||||
while i < len(tbl):
|
||||
e, jj = 0, 0
|
||||
for j in range(len(tbl[i])):
|
||||
if tbl[i][j]:
|
||||
e += 1
|
||||
jj = j
|
||||
if e > 1:
|
||||
break
|
||||
if e > 1:
|
||||
i += 1
|
||||
continue
|
||||
f = (i > 0 and tbl[i - 1][jj] and tbl[i - 1][jj][0].get("text")) or i == 0
|
||||
ff = (i + 1 < len(tbl) and tbl[i + 1][jj] and tbl[i + 1][jj][0].get("text")) or i + 1 >= len(tbl)
|
||||
if f and ff:
|
||||
i += 1
|
||||
continue
|
||||
|
||||
bx = tbl[i][jj][0]
|
||||
logging.debug("Relocate row single: " + bx["text"])
|
||||
# i row only has one value
|
||||
up, down = 100000, 100000
|
||||
if i > 0 and not f:
|
||||
for j in range(len(tbl[i - 1])):
|
||||
if tbl[i - 1][j]:
|
||||
up = min(up, np.min([bx["top"] - a["bottom"] for a in tbl[i - 1][j]]))
|
||||
if i + 1 < len(tbl) and not ff:
|
||||
for j in range(len(tbl[i + 1])):
|
||||
if tbl[i + 1][j]:
|
||||
down = min(down, np.min([a["top"] - bx["bottom"] for a in tbl[i + 1][j]]))
|
||||
assert up < 100000 or down < 100000
|
||||
if up < down:
|
||||
for ii in range(i, len(tbl)):
|
||||
for j in range(len(tbl[ii])):
|
||||
for a in tbl[ii][j]:
|
||||
a["rn"] -= 1
|
||||
if tbl[i - 1][jj]:
|
||||
tbl[i - 1][jj].extend(tbl[i][jj])
|
||||
else:
|
||||
tbl[i - 1][jj] = tbl[i][jj]
|
||||
tbl.pop(i)
|
||||
|
||||
else:
|
||||
for ii in range(i + 1, len(tbl)):
|
||||
for j in range(len(tbl[ii])):
|
||||
for a in tbl[ii][j]:
|
||||
a["rn"] -= 1
|
||||
if tbl[i + 1][jj]:
|
||||
tbl[i + 1][jj].extend(tbl[i][jj])
|
||||
else:
|
||||
tbl[i + 1][jj] = tbl[i][jj]
|
||||
tbl.pop(i)
|
||||
rows.pop(i)
|
||||
|
||||
# which rows are headers
|
||||
hdset = set([])
|
||||
for i in range(len(tbl)):
|
||||
cnt, h = 0, 0
|
||||
for j, arr in enumerate(tbl[i]):
|
||||
if not arr:
|
||||
continue
|
||||
cnt += 1
|
||||
if max_type == "Nu" and arr[0]["btype"] == "Nu":
|
||||
continue
|
||||
if any([a.get("H") for a in arr]) or (max_type == "Nu" and arr[0]["btype"] != "Nu"):
|
||||
h += 1
|
||||
if h / cnt > 0.5:
|
||||
hdset.add(i)
|
||||
|
||||
if html:
|
||||
return TableStructureRecognizer.__html_table(cap, hdset, TableStructureRecognizer.__cal_spans(boxes, rows, cols, tbl, True))
|
||||
|
||||
return TableStructureRecognizer.__desc_table(cap, hdset, TableStructureRecognizer.__cal_spans(boxes, rows, cols, tbl, False), is_english)
|
||||
|
||||
@staticmethod
|
||||
def __html_table(cap, hdset, tbl):
|
||||
# constrcut HTML
|
||||
html = "<table>"
|
||||
if cap:
|
||||
html += f"<caption>{cap}</caption>"
|
||||
for i in range(len(tbl)):
|
||||
row = "<tr>"
|
||||
txts = []
|
||||
for j, arr in enumerate(tbl[i]):
|
||||
if arr is None:
|
||||
continue
|
||||
if not arr:
|
||||
row += "<td></td>" if i not in hdset else "<th></th>"
|
||||
continue
|
||||
txt = ""
|
||||
if arr:
|
||||
h = min(np.min([c["bottom"] - c["top"] for c in arr]) / 2, 10)
|
||||
txt = " ".join([c["text"] for c in Recognizer.sort_Y_firstly(arr, h)])
|
||||
txts.append(txt)
|
||||
sp = ""
|
||||
if arr[0].get("colspan"):
|
||||
sp = "colspan={}".format(arr[0]["colspan"])
|
||||
if arr[0].get("rowspan"):
|
||||
sp += " rowspan={}".format(arr[0]["rowspan"])
|
||||
if i in hdset:
|
||||
row += f"<th {sp} >" + txt + "</th>"
|
||||
else:
|
||||
row += f"<td {sp} >" + txt + "</td>"
|
||||
|
||||
if i in hdset:
|
||||
if all([t in hdset for t in txts]):
|
||||
continue
|
||||
for t in txts:
|
||||
hdset.add(t)
|
||||
|
||||
if row != "<tr>":
|
||||
row += "</tr>"
|
||||
else:
|
||||
row = ""
|
||||
html += "\n" + row
|
||||
html += "\n</table>"
|
||||
return html
|
||||
|
||||
@staticmethod
|
||||
def __desc_table(cap, hdr_rowno, tbl, is_english):
|
||||
# get text of every colomn in header row to become header text
|
||||
clmno = len(tbl[0])
|
||||
rowno = len(tbl)
|
||||
headers = {}
|
||||
hdrset = set()
|
||||
lst_hdr = []
|
||||
de = "的" if not is_english else " for "
|
||||
for r in sorted(list(hdr_rowno)):
|
||||
headers[r] = ["" for _ in range(clmno)]
|
||||
for i in range(clmno):
|
||||
if not tbl[r][i]:
|
||||
continue
|
||||
txt = " ".join([a["text"].strip() for a in tbl[r][i]])
|
||||
headers[r][i] = txt
|
||||
hdrset.add(txt)
|
||||
if all([not t for t in headers[r]]):
|
||||
del headers[r]
|
||||
hdr_rowno.remove(r)
|
||||
continue
|
||||
for j in range(clmno):
|
||||
if headers[r][j]:
|
||||
continue
|
||||
if j >= len(lst_hdr):
|
||||
break
|
||||
headers[r][j] = lst_hdr[j]
|
||||
lst_hdr = headers[r]
|
||||
for i in range(rowno):
|
||||
if i not in hdr_rowno:
|
||||
continue
|
||||
for j in range(i + 1, rowno):
|
||||
if j not in hdr_rowno:
|
||||
break
|
||||
for k in range(clmno):
|
||||
if not headers[j - 1][k]:
|
||||
continue
|
||||
if headers[j][k].find(headers[j - 1][k]) >= 0:
|
||||
continue
|
||||
if len(headers[j][k]) > len(headers[j - 1][k]):
|
||||
headers[j][k] += (de if headers[j][k] else "") + headers[j - 1][k]
|
||||
else:
|
||||
headers[j][k] = headers[j - 1][k] + (de if headers[j - 1][k] else "") + headers[j][k]
|
||||
|
||||
logging.debug(f">>>>>>>>>>>>>>>>>{cap}:SIZE:{rowno}X{clmno} Header: {hdr_rowno}")
|
||||
row_txt = []
|
||||
for i in range(rowno):
|
||||
if i in hdr_rowno:
|
||||
continue
|
||||
rtxt = []
|
||||
|
||||
def append(delimer):
|
||||
nonlocal rtxt, row_txt
|
||||
rtxt = delimer.join(rtxt)
|
||||
if row_txt and len(row_txt[-1]) + len(rtxt) < 64:
|
||||
row_txt[-1] += "\n" + rtxt
|
||||
else:
|
||||
row_txt.append(rtxt)
|
||||
|
||||
r = 0
|
||||
if len(headers.items()):
|
||||
_arr = [(i - r, r) for r, _ in headers.items() if r < i]
|
||||
if _arr:
|
||||
_, r = min(_arr, key=lambda x: x[0])
|
||||
|
||||
if r not in headers and clmno <= 2:
|
||||
for j in range(clmno):
|
||||
if not tbl[i][j]:
|
||||
continue
|
||||
txt = "".join([a["text"].strip() for a in tbl[i][j]])
|
||||
if txt:
|
||||
rtxt.append(txt)
|
||||
if rtxt:
|
||||
append(":")
|
||||
continue
|
||||
|
||||
for j in range(clmno):
|
||||
if not tbl[i][j]:
|
||||
continue
|
||||
txt = "".join([a["text"].strip() for a in tbl[i][j]])
|
||||
if not txt:
|
||||
continue
|
||||
ctt = headers[r][j] if r in headers else ""
|
||||
if ctt:
|
||||
ctt += ":"
|
||||
ctt += txt
|
||||
if ctt:
|
||||
rtxt.append(ctt)
|
||||
|
||||
if rtxt:
|
||||
row_txt.append("; ".join(rtxt))
|
||||
|
||||
if cap:
|
||||
if is_english:
|
||||
from_ = " in "
|
||||
else:
|
||||
from_ = "来自"
|
||||
row_txt = [t + f"\t——{from_}“{cap}”" for t in row_txt]
|
||||
return row_txt
|
||||
|
||||
@staticmethod
|
||||
def __cal_spans(boxes, rows, cols, tbl, html=True):
|
||||
# caculate span
|
||||
clft = [np.mean([c.get("C_left", c["x0"]) for c in cln]) for cln in cols]
|
||||
crgt = [np.mean([c.get("C_right", c["x1"]) for c in cln]) for cln in cols]
|
||||
rtop = [np.mean([c.get("R_top", c["top"]) for c in row]) for row in rows]
|
||||
rbtm = [np.mean([c.get("R_btm", c["bottom"]) for c in row]) for row in rows]
|
||||
for b in boxes:
|
||||
if "SP" not in b:
|
||||
continue
|
||||
b["colspan"] = [b["cn"]]
|
||||
b["rowspan"] = [b["rn"]]
|
||||
# col span
|
||||
for j in range(0, len(clft)):
|
||||
if j == b["cn"]:
|
||||
continue
|
||||
if clft[j] + (crgt[j] - clft[j]) / 2 < b["H_left"]:
|
||||
continue
|
||||
if crgt[j] - (crgt[j] - clft[j]) / 2 > b["H_right"]:
|
||||
continue
|
||||
b["colspan"].append(j)
|
||||
# row span
|
||||
for j in range(0, len(rtop)):
|
||||
if j == b["rn"]:
|
||||
continue
|
||||
if rtop[j] + (rbtm[j] - rtop[j]) / 2 < b["H_top"]:
|
||||
continue
|
||||
if rbtm[j] - (rbtm[j] - rtop[j]) / 2 > b["H_bott"]:
|
||||
continue
|
||||
b["rowspan"].append(j)
|
||||
|
||||
def join(arr):
|
||||
if not arr:
|
||||
return ""
|
||||
return "".join([t["text"] for t in arr])
|
||||
|
||||
# rm the spaning cells
|
||||
for i in range(len(tbl)):
|
||||
for j, arr in enumerate(tbl[i]):
|
||||
if not arr:
|
||||
continue
|
||||
if all(["rowspan" not in a and "colspan" not in a for a in arr]):
|
||||
continue
|
||||
rowspan, colspan = [], []
|
||||
for a in arr:
|
||||
if isinstance(a.get("rowspan", 0), list):
|
||||
rowspan.extend(a["rowspan"])
|
||||
if isinstance(a.get("colspan", 0), list):
|
||||
colspan.extend(a["colspan"])
|
||||
rowspan, colspan = set(rowspan), set(colspan)
|
||||
if len(rowspan) < 2 and len(colspan) < 2:
|
||||
for a in arr:
|
||||
if "rowspan" in a:
|
||||
del a["rowspan"]
|
||||
if "colspan" in a:
|
||||
del a["colspan"]
|
||||
continue
|
||||
rowspan, colspan = sorted(rowspan), sorted(colspan)
|
||||
rowspan = list(range(rowspan[0], rowspan[-1] + 1))
|
||||
colspan = list(range(colspan[0], colspan[-1] + 1))
|
||||
assert i in rowspan, rowspan
|
||||
assert j in colspan, colspan
|
||||
arr = []
|
||||
for r in rowspan:
|
||||
for c in colspan:
|
||||
arr_txt = join(arr)
|
||||
if tbl[r][c] and join(tbl[r][c]) != arr_txt:
|
||||
arr.extend(tbl[r][c])
|
||||
tbl[r][c] = None if html else arr
|
||||
for a in arr:
|
||||
if len(rowspan) > 1:
|
||||
a["rowspan"] = len(rowspan)
|
||||
elif "rowspan" in a:
|
||||
del a["rowspan"]
|
||||
if len(colspan) > 1:
|
||||
a["colspan"] = len(colspan)
|
||||
elif "colspan" in a:
|
||||
del a["colspan"]
|
||||
tbl[rowspan[0]][colspan[0]] = arr
|
||||
|
||||
return tbl
|
||||
|
||||
def _run_ascend_tsr(self, image_list, thr=0.2, batch_size=16):
|
||||
import math
|
||||
|
||||
from ais_bench.infer.interface import InferSession
|
||||
|
||||
model_dir = os.path.join(get_project_base_directory(), "res/deepdoc")
|
||||
model_file_path = os.path.join(model_dir, "tsr.om")
|
||||
|
||||
if not os.path.exists(model_file_path):
|
||||
raise ValueError(f"Model file not found: {model_file_path}")
|
||||
|
||||
device_id = int(os.getenv("ASCEND_LAYOUT_RECOGNIZER_DEVICE_ID", 0))
|
||||
session = InferSession(device_id=device_id, model_path=model_file_path)
|
||||
|
||||
images = [np.array(im) if not isinstance(im, np.ndarray) else im for im in image_list]
|
||||
results = []
|
||||
|
||||
conf_thr = max(thr, 0.08)
|
||||
|
||||
batch_loop_cnt = math.ceil(float(len(images)) / batch_size)
|
||||
for bi in range(batch_loop_cnt):
|
||||
s = bi * batch_size
|
||||
e = min((bi + 1) * batch_size, len(images))
|
||||
batch_images = images[s:e]
|
||||
|
||||
inputs_list = self.preprocess(batch_images)
|
||||
for ins in inputs_list:
|
||||
feeds = []
|
||||
if "image" in ins:
|
||||
feeds.append(ins["image"])
|
||||
else:
|
||||
feeds.append(ins[self.input_names[0]])
|
||||
output_list = session.infer(feeds=feeds, mode="static")
|
||||
bb = self.postprocess(output_list, ins, conf_thr)
|
||||
results.append(bb)
|
||||
return results
|
||||
0
api/app/core/rag/graphrag/__init__.py
Normal file
0
api/app/core/rag/graphrag/__init__.py
Normal file
19
api/app/core/rag/graphrag/utils.py
Normal file
19
api/app/core/rag/graphrag/utils.py
Normal file
@@ -0,0 +1,19 @@
|
||||
import xxhash
|
||||
from app.aioRedis import aio_redis_set, aio_redis_get
|
||||
|
||||
def get_llm_cache(llmnm, txt, history, genconf):
|
||||
hasher = xxhash.xxh64()
|
||||
hasher.update((str(llmnm)+str(txt)+str(history)+str(genconf)).encode("utf-8"))
|
||||
|
||||
k = hasher.hexdigest()
|
||||
bin = aio_redis_get(k)
|
||||
if not bin:
|
||||
return None
|
||||
return bin
|
||||
|
||||
|
||||
def set_llm_cache(llmnm, txt, v, history, genconf):
|
||||
hasher = xxhash.xxh64()
|
||||
hasher.update((str(llmnm)+str(txt)+str(history)+str(genconf)).encode("utf-8"))
|
||||
k = hasher.hexdigest()
|
||||
aio_redis_set(k, v.encode("utf-8"), 24 * 3600)
|
||||
0
api/app/core/rag/llm/__init__.py
Normal file
0
api/app/core/rag/llm/__init__.py
Normal file
670
api/app/core/rag/llm/chat_model.py
Normal file
670
api/app/core/rag/llm/chat_model.py
Normal file
@@ -0,0 +1,670 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
import time
|
||||
from abc import ABC
|
||||
from copy import deepcopy
|
||||
from typing import Any, Protocol
|
||||
from urllib.parse import urljoin
|
||||
|
||||
import json_repair
|
||||
import openai
|
||||
from openai import OpenAI
|
||||
from openai.lib.azure import AzureOpenAI
|
||||
from strenum import StrEnum
|
||||
|
||||
from app.core.rag.nlp import is_chinese, is_english
|
||||
from app.core.rag.common.token_utils import num_tokens_from_string, total_token_count_from_response
|
||||
|
||||
|
||||
# Error message constants
|
||||
class LLMErrorCode(StrEnum):
|
||||
ERROR_RATE_LIMIT = "RATE_LIMIT_EXCEEDED"
|
||||
ERROR_AUTHENTICATION = "AUTH_ERROR"
|
||||
ERROR_INVALID_REQUEST = "INVALID_REQUEST"
|
||||
ERROR_SERVER = "SERVER_ERROR"
|
||||
ERROR_TIMEOUT = "TIMEOUT"
|
||||
ERROR_CONNECTION = "CONNECTION_ERROR"
|
||||
ERROR_MODEL = "MODEL_ERROR"
|
||||
ERROR_MAX_ROUNDS = "ERROR_MAX_ROUNDS"
|
||||
ERROR_CONTENT_FILTER = "CONTENT_FILTERED"
|
||||
ERROR_QUOTA = "QUOTA_EXCEEDED"
|
||||
ERROR_MAX_RETRIES = "MAX_RETRIES_EXCEEDED"
|
||||
ERROR_GENERIC = "GENERIC_ERROR"
|
||||
|
||||
|
||||
class ReActMode(StrEnum):
|
||||
FUNCTION_CALL = "function_call"
|
||||
REACT = "react"
|
||||
|
||||
|
||||
ERROR_PREFIX = "**ERROR**"
|
||||
LENGTH_NOTIFICATION_CN = "······\n由于大模型的上下文窗口大小限制,回答已经被大模型截断。"
|
||||
LENGTH_NOTIFICATION_EN = "...\nThe answer is truncated by your chosen LLM due to its limitation on context length."
|
||||
|
||||
|
||||
class ToolCallSession(Protocol):
|
||||
def tool_call(self, name: str, arguments: dict[str, Any]) -> str: ...
|
||||
|
||||
|
||||
class Base(ABC):
|
||||
def __init__(self, key, model_name, base_url, **kwargs):
|
||||
timeout = int(os.environ.get("LLM_TIMEOUT_SECONDS", 600))
|
||||
self.client = OpenAI(api_key=key, base_url=base_url, timeout=timeout)
|
||||
self.model_name = model_name
|
||||
# Configure retry parameters
|
||||
self.max_retries = kwargs.get("max_retries", int(os.environ.get("LLM_MAX_RETRIES", 5)))
|
||||
self.base_delay = kwargs.get("retry_interval", float(os.environ.get("LLM_BASE_DELAY", 2.0)))
|
||||
self.max_rounds = kwargs.get("max_rounds", 5)
|
||||
self.is_tools = False
|
||||
self.tools = []
|
||||
self.toolcall_sessions = {}
|
||||
|
||||
def _get_delay(self):
|
||||
"""Calculate retry delay time"""
|
||||
return self.base_delay * random.uniform(10, 150)
|
||||
|
||||
def _classify_error(self, error):
|
||||
"""Classify error based on error message content"""
|
||||
error_str = str(error).lower()
|
||||
|
||||
keywords_mapping = [
|
||||
(["quota", "capacity", "credit", "billing", "balance", "欠费"], LLMErrorCode.ERROR_QUOTA),
|
||||
(["rate limit", "429", "tpm limit", "too many requests", "requests per minute"], LLMErrorCode.ERROR_RATE_LIMIT),
|
||||
(["auth", "key", "apikey", "401", "forbidden", "permission"], LLMErrorCode.ERROR_AUTHENTICATION),
|
||||
(["invalid", "bad request", "400", "format", "malformed", "parameter"], LLMErrorCode.ERROR_INVALID_REQUEST),
|
||||
(["server", "503", "502", "504", "500", "unavailable"], LLMErrorCode.ERROR_SERVER),
|
||||
(["timeout", "timed out"], LLMErrorCode.ERROR_TIMEOUT),
|
||||
(["connect", "network", "unreachable", "dns"], LLMErrorCode.ERROR_CONNECTION),
|
||||
(["filter", "content", "policy", "blocked", "safety", "inappropriate"], LLMErrorCode.ERROR_CONTENT_FILTER),
|
||||
(["model", "not found", "does not exist", "not available"], LLMErrorCode.ERROR_MODEL),
|
||||
(["max rounds"], LLMErrorCode.ERROR_MODEL),
|
||||
]
|
||||
for words, code in keywords_mapping:
|
||||
if re.search("({})".format("|".join(words)), error_str):
|
||||
return code
|
||||
|
||||
return LLMErrorCode.ERROR_GENERIC
|
||||
|
||||
def _clean_conf(self, gen_conf):
|
||||
if "max_tokens" in gen_conf:
|
||||
del gen_conf["max_tokens"]
|
||||
|
||||
allowed_conf = {
|
||||
"temperature",
|
||||
"max_completion_tokens",
|
||||
"top_p",
|
||||
"stream",
|
||||
"stream_options",
|
||||
"stop",
|
||||
"n",
|
||||
"presence_penalty",
|
||||
"frequency_penalty",
|
||||
"functions",
|
||||
"function_call",
|
||||
"logit_bias",
|
||||
"user",
|
||||
"response_format",
|
||||
"seed",
|
||||
"tools",
|
||||
"tool_choice",
|
||||
"logprobs",
|
||||
"top_logprobs",
|
||||
"extra_headers"
|
||||
}
|
||||
|
||||
gen_conf = {k: v for k, v in gen_conf.items() if k in allowed_conf}
|
||||
|
||||
return gen_conf
|
||||
|
||||
def _chat(self, history, gen_conf, **kwargs):
|
||||
logging.info("[HISTORY]" + json.dumps(history, ensure_ascii=False, indent=2))
|
||||
if self.model_name.lower().find("qwq") >= 0:
|
||||
logging.info(f"[INFO] {self.model_name} detected as reasoning model, using _chat_streamly")
|
||||
|
||||
final_ans = ""
|
||||
tol_token = 0
|
||||
for delta, tol in self._chat_streamly(history, gen_conf, with_reasoning=False, **kwargs):
|
||||
if delta.startswith("<think>") or delta.endswith("</think>"):
|
||||
continue
|
||||
final_ans += delta
|
||||
tol_token = tol
|
||||
|
||||
if len(final_ans.strip()) == 0:
|
||||
final_ans = "**ERROR**: Empty response from reasoning model"
|
||||
|
||||
return final_ans.strip(), tol_token
|
||||
|
||||
if self.model_name.lower().find("qwen3") >= 0:
|
||||
kwargs["extra_body"] = {"enable_thinking": False}
|
||||
|
||||
response = self.client.chat.completions.create(model=self.model_name, messages=history, **gen_conf, **kwargs)
|
||||
|
||||
if not response.choices or not response.choices[0].message or not response.choices[0].message.content:
|
||||
return "", 0
|
||||
ans = response.choices[0].message.content.strip()
|
||||
if response.choices[0].finish_reason == "length":
|
||||
ans = self._length_stop(ans)
|
||||
return ans, total_token_count_from_response(response)
|
||||
|
||||
def _chat_streamly(self, history, gen_conf, **kwargs):
|
||||
logging.info("[HISTORY STREAMLY]" + json.dumps(history, ensure_ascii=False, indent=4))
|
||||
reasoning_start = False
|
||||
|
||||
if kwargs.get("stop") or "stop" in gen_conf:
|
||||
response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, **gen_conf, stop=kwargs.get("stop"))
|
||||
else:
|
||||
response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, **gen_conf)
|
||||
|
||||
for resp in response:
|
||||
if not resp.choices:
|
||||
continue
|
||||
if not resp.choices[0].delta.content:
|
||||
resp.choices[0].delta.content = ""
|
||||
if kwargs.get("with_reasoning", True) and hasattr(resp.choices[0].delta, "reasoning_content") and resp.choices[0].delta.reasoning_content:
|
||||
ans = ""
|
||||
if not reasoning_start:
|
||||
reasoning_start = True
|
||||
ans = "<think>"
|
||||
ans += resp.choices[0].delta.reasoning_content + "</think>"
|
||||
else:
|
||||
reasoning_start = False
|
||||
ans = resp.choices[0].delta.content
|
||||
|
||||
tol = total_token_count_from_response(resp)
|
||||
if not tol:
|
||||
tol = num_tokens_from_string(resp.choices[0].delta.content)
|
||||
|
||||
if resp.choices[0].finish_reason == "length":
|
||||
if is_chinese(ans):
|
||||
ans += LENGTH_NOTIFICATION_CN
|
||||
else:
|
||||
ans += LENGTH_NOTIFICATION_EN
|
||||
yield ans, tol
|
||||
|
||||
def _length_stop(self, ans):
|
||||
if is_chinese([ans]):
|
||||
return ans + LENGTH_NOTIFICATION_CN
|
||||
return ans + LENGTH_NOTIFICATION_EN
|
||||
|
||||
@property
|
||||
def _retryable_errors(self) -> set[str]:
|
||||
return {
|
||||
LLMErrorCode.ERROR_RATE_LIMIT,
|
||||
LLMErrorCode.ERROR_SERVER,
|
||||
}
|
||||
|
||||
def _should_retry(self, error_code: str) -> bool:
|
||||
return error_code in self._retryable_errors
|
||||
|
||||
def _exceptions(self, e, attempt) -> str | None:
|
||||
logging.exception("OpenAI chat_with_tools")
|
||||
# Classify the error
|
||||
error_code = self._classify_error(e)
|
||||
if attempt == self.max_retries:
|
||||
error_code = LLMErrorCode.ERROR_MAX_RETRIES
|
||||
|
||||
if self._should_retry(error_code):
|
||||
delay = self._get_delay()
|
||||
logging.warning(f"Error: {error_code}. Retrying in {delay:.2f} seconds... (Attempt {attempt + 1}/{self.max_retries})")
|
||||
time.sleep(delay)
|
||||
return None
|
||||
|
||||
return f"{ERROR_PREFIX}: {error_code} - {str(e)}"
|
||||
|
||||
def _verbose_tool_use(self, name, args, res):
|
||||
return "<tool_call>" + json.dumps({"name": name, "args": args, "result": res}, ensure_ascii=False, indent=2) + "</tool_call>"
|
||||
|
||||
def _append_history(self, hist, tool_call, tool_res):
|
||||
hist.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{
|
||||
"index": tool_call.index,
|
||||
"id": tool_call.id,
|
||||
"function": {
|
||||
"name": tool_call.function.name,
|
||||
"arguments": tool_call.function.arguments,
|
||||
},
|
||||
"type": "function",
|
||||
},
|
||||
],
|
||||
}
|
||||
)
|
||||
try:
|
||||
if isinstance(tool_res, dict):
|
||||
tool_res = json.dumps(tool_res, ensure_ascii=False)
|
||||
finally:
|
||||
hist.append({"role": "tool", "tool_call_id": tool_call.id, "content": str(tool_res)})
|
||||
return hist
|
||||
|
||||
def bind_tools(self, toolcall_session, tools):
|
||||
if not (toolcall_session and tools):
|
||||
return
|
||||
self.is_tools = True
|
||||
self.toolcall_session = toolcall_session
|
||||
self.tools = tools
|
||||
|
||||
def chat_with_tools(self, system: str, history: list, gen_conf: dict = {}):
|
||||
gen_conf = self._clean_conf(gen_conf)
|
||||
if system and history and history[0].get("role") != "system":
|
||||
history.insert(0, {"role": "system", "content": system})
|
||||
|
||||
ans = ""
|
||||
tk_count = 0
|
||||
hist = deepcopy(history)
|
||||
# Implement exponential backoff retry strategy
|
||||
for attempt in range(self.max_retries + 1):
|
||||
history = hist
|
||||
try:
|
||||
for _ in range(self.max_rounds + 1):
|
||||
logging.info(f"{self.tools=}")
|
||||
response = self.client.chat.completions.create(model=self.model_name, messages=history, tools=self.tools, tool_choice="auto", **gen_conf)
|
||||
tk_count += total_token_count_from_response(response)
|
||||
if any([not response.choices, not response.choices[0].message]):
|
||||
raise Exception(f"500 response structure error. Response: {response}")
|
||||
|
||||
if not hasattr(response.choices[0].message, "tool_calls") or not response.choices[0].message.tool_calls:
|
||||
if hasattr(response.choices[0].message, "reasoning_content") and response.choices[0].message.reasoning_content:
|
||||
ans += "<think>" + response.choices[0].message.reasoning_content + "</think>"
|
||||
|
||||
ans += response.choices[0].message.content
|
||||
if response.choices[0].finish_reason == "length":
|
||||
ans = self._length_stop(ans)
|
||||
|
||||
return ans, tk_count
|
||||
|
||||
for tool_call in response.choices[0].message.tool_calls:
|
||||
logging.info(f"Response {tool_call=}")
|
||||
name = tool_call.function.name
|
||||
try:
|
||||
args = json_repair.loads(tool_call.function.arguments)
|
||||
tool_response = self.toolcall_session.tool_call(name, args)
|
||||
history = self._append_history(history, tool_call, tool_response)
|
||||
ans += self._verbose_tool_use(name, args, tool_response)
|
||||
except Exception as e:
|
||||
logging.exception(msg=f"Wrong JSON argument format in LLM tool call response: {tool_call}")
|
||||
history.append({"role": "tool", "tool_call_id": tool_call.id, "content": f"Tool call error: \n{tool_call}\nException:\n" + str(e)})
|
||||
ans += self._verbose_tool_use(name, {}, str(e))
|
||||
|
||||
logging.warning(f"Exceed max rounds: {self.max_rounds}")
|
||||
history.append({"role": "user", "content": f"Exceed max rounds: {self.max_rounds}"})
|
||||
response, token_count = self._chat(history, gen_conf)
|
||||
ans += response
|
||||
tk_count += token_count
|
||||
return ans, tk_count
|
||||
except Exception as e:
|
||||
e = self._exceptions(e, attempt)
|
||||
if e:
|
||||
return e, tk_count
|
||||
|
||||
assert False, "Shouldn't be here."
|
||||
|
||||
def chat(self, system, history, gen_conf={}, **kwargs):
|
||||
if system and history and history[0].get("role") != "system":
|
||||
history.insert(0, {"role": "system", "content": system})
|
||||
gen_conf = self._clean_conf(gen_conf)
|
||||
|
||||
# Implement exponential backoff retry strategy
|
||||
for attempt in range(self.max_retries + 1):
|
||||
try:
|
||||
return self._chat(history, gen_conf, **kwargs)
|
||||
except Exception as e:
|
||||
e = self._exceptions(e, attempt)
|
||||
if e:
|
||||
return e, 0
|
||||
assert False, "Shouldn't be here."
|
||||
|
||||
def _wrap_toolcall_message(self, stream):
|
||||
final_tool_calls = {}
|
||||
|
||||
for chunk in stream:
|
||||
for tool_call in chunk.choices[0].delta.tool_calls or []:
|
||||
index = tool_call.index
|
||||
|
||||
if index not in final_tool_calls:
|
||||
final_tool_calls[index] = tool_call
|
||||
|
||||
final_tool_calls[index].function.arguments += tool_call.function.arguments
|
||||
|
||||
return final_tool_calls
|
||||
|
||||
def chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict = {}):
|
||||
gen_conf = self._clean_conf(gen_conf)
|
||||
tools = self.tools
|
||||
if system and history and history[0].get("role") != "system":
|
||||
history.insert(0, {"role": "system", "content": system})
|
||||
|
||||
total_tokens = 0
|
||||
hist = deepcopy(history)
|
||||
# Implement exponential backoff retry strategy
|
||||
for attempt in range(self.max_retries + 1):
|
||||
history = hist
|
||||
try:
|
||||
for _ in range(self.max_rounds + 1):
|
||||
reasoning_start = False
|
||||
logging.info(f"{tools=}")
|
||||
response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, tools=tools, tool_choice="auto", **gen_conf)
|
||||
final_tool_calls = {}
|
||||
answer = ""
|
||||
for resp in response:
|
||||
if resp.choices[0].delta.tool_calls:
|
||||
for tool_call in resp.choices[0].delta.tool_calls or []:
|
||||
index = tool_call.index
|
||||
|
||||
if index not in final_tool_calls:
|
||||
if not tool_call.function.arguments:
|
||||
tool_call.function.arguments = ""
|
||||
final_tool_calls[index] = tool_call
|
||||
else:
|
||||
final_tool_calls[index].function.arguments += tool_call.function.arguments if tool_call.function.arguments else ""
|
||||
continue
|
||||
|
||||
if any([not resp.choices, not resp.choices[0].delta, not hasattr(resp.choices[0].delta, "content")]):
|
||||
raise Exception("500 response structure error.")
|
||||
|
||||
if not resp.choices[0].delta.content:
|
||||
resp.choices[0].delta.content = ""
|
||||
|
||||
if hasattr(resp.choices[0].delta, "reasoning_content") and resp.choices[0].delta.reasoning_content:
|
||||
ans = ""
|
||||
if not reasoning_start:
|
||||
reasoning_start = True
|
||||
ans = "<think>"
|
||||
ans += resp.choices[0].delta.reasoning_content + "</think>"
|
||||
yield ans
|
||||
else:
|
||||
reasoning_start = False
|
||||
answer += resp.choices[0].delta.content
|
||||
yield resp.choices[0].delta.content
|
||||
|
||||
tol = total_token_count_from_response(resp)
|
||||
if not tol:
|
||||
total_tokens += num_tokens_from_string(resp.choices[0].delta.content)
|
||||
else:
|
||||
total_tokens = tol
|
||||
|
||||
finish_reason = resp.choices[0].finish_reason if hasattr(resp.choices[0], "finish_reason") else ""
|
||||
if finish_reason == "length":
|
||||
yield self._length_stop("")
|
||||
|
||||
if answer:
|
||||
yield total_tokens
|
||||
return
|
||||
|
||||
for tool_call in final_tool_calls.values():
|
||||
name = tool_call.function.name
|
||||
try:
|
||||
args = json_repair.loads(tool_call.function.arguments)
|
||||
yield self._verbose_tool_use(name, args, "Begin to call...")
|
||||
tool_response = self.toolcall_session.tool_call(name, args)
|
||||
history = self._append_history(history, tool_call, tool_response)
|
||||
yield self._verbose_tool_use(name, args, tool_response)
|
||||
except Exception as e:
|
||||
logging.exception(msg=f"Wrong JSON argument format in LLM tool call response: {tool_call}")
|
||||
history.append({"role": "tool", "tool_call_id": tool_call.id, "content": f"Tool call error: \n{tool_call}\nException:\n" + str(e)})
|
||||
yield self._verbose_tool_use(name, {}, str(e))
|
||||
|
||||
logging.warning(f"Exceed max rounds: {self.max_rounds}")
|
||||
history.append({"role": "user", "content": f"Exceed max rounds: {self.max_rounds}"})
|
||||
response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, **gen_conf)
|
||||
for resp in response:
|
||||
if any([not resp.choices, not resp.choices[0].delta, not hasattr(resp.choices[0].delta, "content")]):
|
||||
raise Exception("500 response structure error.")
|
||||
if not resp.choices[0].delta.content:
|
||||
resp.choices[0].delta.content = ""
|
||||
continue
|
||||
tol = total_token_count_from_response(resp)
|
||||
if not tol:
|
||||
total_tokens += num_tokens_from_string(resp.choices[0].delta.content)
|
||||
else:
|
||||
total_tokens = tol
|
||||
answer += resp.choices[0].delta.content
|
||||
yield resp.choices[0].delta.content
|
||||
|
||||
yield total_tokens
|
||||
return
|
||||
|
||||
except Exception as e:
|
||||
e = self._exceptions(e, attempt)
|
||||
if e:
|
||||
yield e
|
||||
yield total_tokens
|
||||
return
|
||||
|
||||
assert False, "Shouldn't be here."
|
||||
|
||||
def chat_streamly(self, system, history, gen_conf: dict = {}, **kwargs):
|
||||
if system and history and history[0].get("role") != "system":
|
||||
history.insert(0, {"role": "system", "content": system})
|
||||
gen_conf = self._clean_conf(gen_conf)
|
||||
ans = ""
|
||||
total_tokens = 0
|
||||
try:
|
||||
for delta_ans, tol in self._chat_streamly(history, gen_conf, **kwargs):
|
||||
yield delta_ans
|
||||
total_tokens += tol
|
||||
except openai.APIError as e:
|
||||
yield ans + "\n**ERROR**: " + str(e)
|
||||
|
||||
yield total_tokens
|
||||
|
||||
def _calculate_dynamic_ctx(self, history):
|
||||
"""Calculate dynamic context window size"""
|
||||
|
||||
def count_tokens(text):
|
||||
"""Calculate token count for text"""
|
||||
# Simple calculation: 1 token per ASCII character
|
||||
# 2 tokens for non-ASCII characters (Chinese, Japanese, Korean, etc.)
|
||||
total = 0
|
||||
for char in text:
|
||||
if ord(char) < 128: # ASCII characters
|
||||
total += 1
|
||||
else: # Non-ASCII characters (Chinese, Japanese, Korean, etc.)
|
||||
total += 2
|
||||
return total
|
||||
|
||||
# Calculate total tokens for all messages
|
||||
total_tokens = 0
|
||||
for message in history:
|
||||
content = message.get("content", "")
|
||||
# Calculate content tokens
|
||||
content_tokens = count_tokens(content)
|
||||
# Add role marker token overhead
|
||||
role_tokens = 4
|
||||
total_tokens += content_tokens + role_tokens
|
||||
|
||||
# Apply 1.2x buffer ratio
|
||||
total_tokens_with_buffer = int(total_tokens * 1.2)
|
||||
|
||||
if total_tokens_with_buffer <= 8192:
|
||||
ctx_size = 8192
|
||||
else:
|
||||
ctx_multiplier = (total_tokens_with_buffer // 8192) + 1
|
||||
ctx_size = ctx_multiplier * 8192
|
||||
|
||||
return ctx_size
|
||||
|
||||
|
||||
class GptTurbo(Base):
|
||||
_FACTORY_NAME = "OpenAI"
|
||||
|
||||
def __init__(self, key, model_name="gpt-3.5-turbo", base_url="https://api.openai.com/v1", **kwargs):
|
||||
if not base_url:
|
||||
base_url = "https://api.openai.com/v1"
|
||||
super().__init__(key, model_name, base_url, **kwargs)
|
||||
|
||||
|
||||
class XinferenceChat(Base):
|
||||
_FACTORY_NAME = "Xinference"
|
||||
|
||||
def __init__(self, key=None, model_name="", base_url="", **kwargs):
|
||||
if not base_url:
|
||||
raise ValueError("Local llm url cannot be None")
|
||||
base_url = urljoin(base_url, "v1")
|
||||
super().__init__(key, model_name, base_url, **kwargs)
|
||||
|
||||
|
||||
class HuggingFaceChat(Base):
|
||||
_FACTORY_NAME = "HuggingFace"
|
||||
|
||||
def __init__(self, key=None, model_name="", base_url="", **kwargs):
|
||||
if not base_url:
|
||||
raise ValueError("Local llm url cannot be None")
|
||||
base_url = urljoin(base_url, "v1")
|
||||
super().__init__(key, model_name.split("___")[0], base_url, **kwargs)
|
||||
|
||||
|
||||
class ModelScopeChat(Base):
|
||||
_FACTORY_NAME = "ModelScope"
|
||||
|
||||
def __init__(self, key=None, model_name="", base_url="", **kwargs):
|
||||
if not base_url:
|
||||
raise ValueError("Local llm url cannot be None")
|
||||
base_url = urljoin(base_url, "v1")
|
||||
super().__init__(key, model_name.split("___")[0], base_url, **kwargs)
|
||||
|
||||
|
||||
class AzureChat(Base):
|
||||
_FACTORY_NAME = "Azure-OpenAI"
|
||||
|
||||
def __init__(self, key, model_name, base_url, **kwargs):
|
||||
api_key = json.loads(key).get("api_key", "")
|
||||
api_version = json.loads(key).get("api_version", "2024-02-01")
|
||||
super().__init__(key, model_name, base_url, **kwargs)
|
||||
self.client = AzureOpenAI(api_key=api_key, azure_endpoint=base_url, api_version=api_version)
|
||||
self.model_name = model_name
|
||||
|
||||
@property
|
||||
def _retryable_errors(self) -> set[str]:
|
||||
return {
|
||||
LLMErrorCode.ERROR_RATE_LIMIT,
|
||||
LLMErrorCode.ERROR_SERVER,
|
||||
LLMErrorCode.ERROR_QUOTA,
|
||||
}
|
||||
|
||||
|
||||
class BaiChuanChat(Base):
|
||||
_FACTORY_NAME = "BaiChuan"
|
||||
|
||||
def __init__(self, key, model_name="Baichuan3-Turbo", base_url="https://api.baichuan-ai.com/v1", **kwargs):
|
||||
if not base_url:
|
||||
base_url = "https://api.baichuan-ai.com/v1"
|
||||
super().__init__(key, model_name, base_url, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def _format_params(params):
|
||||
return {
|
||||
"temperature": params.get("temperature", 0.3),
|
||||
"top_p": params.get("top_p", 0.85),
|
||||
}
|
||||
|
||||
def _clean_conf(self, gen_conf):
|
||||
return {
|
||||
"temperature": gen_conf.get("temperature", 0.3),
|
||||
"top_p": gen_conf.get("top_p", 0.85),
|
||||
}
|
||||
|
||||
def _chat(self, history, gen_conf={}, **kwargs):
|
||||
response = self.client.chat.completions.create(
|
||||
model=self.model_name,
|
||||
messages=history,
|
||||
extra_body={"tools": [{"type": "web_search", "web_search": {"enable": True, "search_mode": "performance_first"}}]},
|
||||
**gen_conf,
|
||||
)
|
||||
ans = response.choices[0].message.content.strip()
|
||||
if response.choices[0].finish_reason == "length":
|
||||
if is_chinese([ans]):
|
||||
ans += LENGTH_NOTIFICATION_CN
|
||||
else:
|
||||
ans += LENGTH_NOTIFICATION_EN
|
||||
return ans, total_token_count_from_response(response)
|
||||
|
||||
def chat_streamly(self, system, history, gen_conf={}, **kwargs):
|
||||
if system and history and history[0].get("role") != "system":
|
||||
history.insert(0, {"role": "system", "content": system})
|
||||
if "max_tokens" in gen_conf:
|
||||
del gen_conf["max_tokens"]
|
||||
ans = ""
|
||||
total_tokens = 0
|
||||
try:
|
||||
response = self.client.chat.completions.create(
|
||||
model=self.model_name,
|
||||
messages=history,
|
||||
extra_body={"tools": [{"type": "web_search", "web_search": {"enable": True, "search_mode": "performance_first"}}]},
|
||||
stream=True,
|
||||
**self._format_params(gen_conf),
|
||||
)
|
||||
for resp in response:
|
||||
if not resp.choices:
|
||||
continue
|
||||
if not resp.choices[0].delta.content:
|
||||
resp.choices[0].delta.content = ""
|
||||
ans = resp.choices[0].delta.content
|
||||
tol = total_token_count_from_response(resp)
|
||||
if not tol:
|
||||
total_tokens += num_tokens_from_string(resp.choices[0].delta.content)
|
||||
else:
|
||||
total_tokens = tol
|
||||
if resp.choices[0].finish_reason == "length":
|
||||
if is_chinese([ans]):
|
||||
ans += LENGTH_NOTIFICATION_CN
|
||||
else:
|
||||
ans += LENGTH_NOTIFICATION_EN
|
||||
yield ans
|
||||
|
||||
except Exception as e:
|
||||
yield ans + "\n**ERROR**: " + str(e)
|
||||
|
||||
yield total_tokens
|
||||
|
||||
|
||||
class LocalAIChat(Base):
|
||||
_FACTORY_NAME = "LocalAI"
|
||||
|
||||
def __init__(self, key, model_name, base_url=None, **kwargs):
|
||||
super().__init__(key, model_name, base_url=base_url, **kwargs)
|
||||
|
||||
if not base_url:
|
||||
raise ValueError("Local llm url cannot be None")
|
||||
base_url = urljoin(base_url, "v1")
|
||||
self.client = OpenAI(api_key="empty", base_url=base_url)
|
||||
self.model_name = model_name.split("___")[0]
|
||||
|
||||
|
||||
class VolcEngineChat(Base):
|
||||
_FACTORY_NAME = "VolcEngine"
|
||||
|
||||
def __init__(self, key, model_name, base_url="https://ark.cn-beijing.volces.com/api/v3", **kwargs):
|
||||
"""
|
||||
Since do not want to modify the original database fields, and the VolcEngine authentication method is quite special,
|
||||
Assemble ark_api_key, ep_id into api_key, store it as a dictionary type, and parse it for use
|
||||
model_name is for display only
|
||||
"""
|
||||
base_url = base_url if base_url else "https://ark.cn-beijing.volces.com/api/v3"
|
||||
ark_api_key = json.loads(key).get("ark_api_key", "")
|
||||
model_name = json.loads(key).get("ep_id", "") + json.loads(key).get("endpoint_id", "")
|
||||
super().__init__(ark_api_key, model_name, base_url, **kwargs)
|
||||
|
||||
|
||||
class OpenAI_APIChat(Base):
|
||||
_FACTORY_NAME = ["VLLM", "OpenAI-API-Compatible"]
|
||||
|
||||
def __init__(self, key, model_name, base_url, **kwargs):
|
||||
if not base_url:
|
||||
raise ValueError("url cannot be None")
|
||||
model_name = model_name.split("___")[0]
|
||||
super().__init__(key, model_name, base_url, **kwargs)
|
||||
|
||||
|
||||
class GPUStackChat(Base):
|
||||
_FACTORY_NAME = "GPUStack"
|
||||
|
||||
def __init__(self, key=None, model_name="", base_url="", **kwargs):
|
||||
if not base_url:
|
||||
raise ValueError("Local llm url cannot be None")
|
||||
base_url = urljoin(base_url, "v1")
|
||||
super().__init__(key, model_name, base_url, **kwargs)
|
||||
470
api/app/core/rag/llm/cv_model.py
Normal file
470
api/app/core/rag/llm/cv_model.py
Normal file
@@ -0,0 +1,470 @@
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
import logging
|
||||
from abc import ABC
|
||||
from copy import deepcopy
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from urllib.parse import urljoin
|
||||
import requests
|
||||
from openai import OpenAI
|
||||
from openai.lib.azure import AzureOpenAI
|
||||
from app.core.rag.nlp import is_english
|
||||
from app.core.rag.prompts.generator import vision_llm_describe_prompt
|
||||
from app.core.rag.common.token_utils import num_tokens_from_string, total_token_count_from_response
|
||||
|
||||
|
||||
class Base(ABC):
|
||||
def __init__(self, **kwargs):
|
||||
# Configure retry parameters
|
||||
self.max_retries = kwargs.get("max_retries", int(os.environ.get("LLM_MAX_RETRIES", 5)))
|
||||
self.base_delay = kwargs.get("retry_interval", float(os.environ.get("LLM_BASE_DELAY", 2.0)))
|
||||
self.max_rounds = kwargs.get("max_rounds", 5)
|
||||
self.is_tools = False
|
||||
self.tools = []
|
||||
self.toolcall_sessions = {}
|
||||
self.extra_body = None
|
||||
|
||||
def describe(self, image):
|
||||
raise NotImplementedError("Please implement encode method!")
|
||||
|
||||
def describe_with_prompt(self, image, prompt=None):
|
||||
raise NotImplementedError("Please implement encode method!")
|
||||
|
||||
def _form_history(self, system, history, images=None):
|
||||
hist = []
|
||||
if system:
|
||||
hist.append({"role": "system", "content": system})
|
||||
for h in history:
|
||||
if images and h["role"] == "user":
|
||||
h["content"] = self._image_prompt(h["content"], images)
|
||||
images = []
|
||||
hist.append(h)
|
||||
return hist
|
||||
|
||||
def _image_prompt(self, text, images):
|
||||
if not images:
|
||||
return text
|
||||
|
||||
if isinstance(images, str) or "bytes" in type(images).__name__:
|
||||
images = [images]
|
||||
|
||||
pmpt = [{"type": "text", "text": text}]
|
||||
for img in images:
|
||||
pmpt.append({
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": img if isinstance(img, str) and img.startswith("data:") else f"data:image/png;base64,{img}"
|
||||
}
|
||||
})
|
||||
return pmpt
|
||||
|
||||
def chat(self, system, history, gen_conf, images=None, **kwargs):
|
||||
try:
|
||||
response = self.client.chat.completions.create(
|
||||
model=self.model_name,
|
||||
messages=self._form_history(system, history, images),
|
||||
extra_body=self.extra_body,
|
||||
)
|
||||
return response.choices[0].message.content.strip(), response.usage.total_tokens
|
||||
except Exception as e:
|
||||
return "**ERROR**: " + str(e), 0
|
||||
|
||||
def chat_streamly(self, system, history, gen_conf, images=None, **kwargs):
|
||||
ans = ""
|
||||
tk_count = 0
|
||||
try:
|
||||
response = self.client.chat.completions.create(
|
||||
model=self.model_name,
|
||||
messages=self._form_history(system, history, images),
|
||||
stream=True,
|
||||
extra_body=self.extra_body,
|
||||
)
|
||||
for resp in response:
|
||||
if not resp.choices[0].delta.content:
|
||||
continue
|
||||
delta = resp.choices[0].delta.content
|
||||
ans = delta
|
||||
if resp.choices[0].finish_reason == "length":
|
||||
ans += "...\nFor the content length reason, it stopped, continue?" if is_english([ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
|
||||
if resp.choices[0].finish_reason == "stop":
|
||||
tk_count += resp.usage.total_tokens
|
||||
yield ans
|
||||
except Exception as e:
|
||||
yield ans + "\n**ERROR**: " + str(e)
|
||||
|
||||
yield tk_count
|
||||
|
||||
@staticmethod
|
||||
def image2base64_rawvalue(self, image):
|
||||
# Return a base64 string without data URL header
|
||||
if isinstance(image, bytes):
|
||||
b64 = base64.b64encode(image).decode("utf-8")
|
||||
return b64
|
||||
if isinstance(image, BytesIO):
|
||||
data = image.getvalue()
|
||||
b64 = base64.b64encode(data).decode("utf-8")
|
||||
return b64
|
||||
with BytesIO() as buffered:
|
||||
try:
|
||||
image.save(buffered, format="JPEG")
|
||||
except Exception:
|
||||
# reset buffer before saving PNG
|
||||
buffered.seek(0)
|
||||
buffered.truncate()
|
||||
image.save(buffered, format="PNG")
|
||||
data = buffered.getvalue()
|
||||
b64 = base64.b64encode(data).decode("utf-8")
|
||||
return b64
|
||||
|
||||
@staticmethod
|
||||
def image2base64(image):
|
||||
# Return a data URL with the correct MIME to avoid provider mismatches
|
||||
if isinstance(image, bytes):
|
||||
# Best-effort magic number sniffing
|
||||
mime = "image/png"
|
||||
if len(image) >= 2 and image[0] == 0xFF and image[1] == 0xD8:
|
||||
mime = "image/jpeg"
|
||||
b64 = base64.b64encode(image).decode("utf-8")
|
||||
return f"data:{mime};base64,{b64}"
|
||||
if isinstance(image, BytesIO):
|
||||
data = image.getvalue()
|
||||
mime = "image/png"
|
||||
if len(data) >= 2 and data[0] == 0xFF and data[1] == 0xD8:
|
||||
mime = "image/jpeg"
|
||||
b64 = base64.b64encode(data).decode("utf-8")
|
||||
return f"data:{mime};base64,{b64}"
|
||||
with BytesIO() as buffered:
|
||||
fmt = "jpeg"
|
||||
try:
|
||||
image.save(buffered, format="JPEG")
|
||||
except Exception:
|
||||
# reset buffer before saving PNG
|
||||
buffered.seek(0)
|
||||
buffered.truncate()
|
||||
image.save(buffered, format="PNG")
|
||||
fmt = "png"
|
||||
data = buffered.getvalue()
|
||||
b64 = base64.b64encode(data).decode("utf-8")
|
||||
mime = f"image/{fmt}"
|
||||
return f"data:{mime};base64,{b64}"
|
||||
|
||||
def prompt(self, b64):
|
||||
return [
|
||||
{
|
||||
"role": "user",
|
||||
"content": self._image_prompt(
|
||||
"请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。"
|
||||
if self.lang.lower() == "chinese"
|
||||
else "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out.",
|
||||
b64
|
||||
)
|
||||
}
|
||||
]
|
||||
|
||||
def vision_llm_prompt(self, b64, prompt=None):
|
||||
return [
|
||||
{
|
||||
"role": "user",
|
||||
"content": self._image_prompt(prompt if prompt else vision_llm_describe_prompt(), b64)
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
class GptV4(Base):
|
||||
_FACTORY_NAME = "OpenAI"
|
||||
|
||||
def __init__(self, key, model_name="gpt-4-vision-preview", lang="Chinese", base_url="https://api.openai.com/v1", **kwargs):
|
||||
if not base_url:
|
||||
base_url = "https://api.openai.com/v1"
|
||||
self.api_key = key
|
||||
self.client = OpenAI(api_key=key, base_url=base_url)
|
||||
self.model_name = model_name
|
||||
self.lang = lang
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def describe(self, image):
|
||||
b64 = self.image2base64(image)
|
||||
res = self.client.chat.completions.create(
|
||||
model=self.model_name,
|
||||
messages=self.prompt(b64),
|
||||
extra_body=self.extra_body,
|
||||
)
|
||||
return res.choices[0].message.content.strip(), total_token_count_from_response(res)
|
||||
|
||||
def describe_with_prompt(self, image, prompt=None):
|
||||
b64 = self.image2base64(image)
|
||||
res = self.client.chat.completions.create(
|
||||
model=self.model_name,
|
||||
messages=self.vision_llm_prompt(b64, prompt),
|
||||
extra_body=self.extra_body,
|
||||
)
|
||||
return res.choices[0].message.content.strip(),total_token_count_from_response(res)
|
||||
|
||||
|
||||
class AzureGptV4(GptV4):
|
||||
_FACTORY_NAME = "Azure-OpenAI"
|
||||
|
||||
def __init__(self, key, model_name, lang="Chinese", **kwargs):
|
||||
api_key = json.loads(key).get("api_key", "")
|
||||
api_version = json.loads(key).get("api_version", "2024-02-01")
|
||||
self.client = AzureOpenAI(api_key=api_key, azure_endpoint=kwargs["base_url"], api_version=api_version)
|
||||
self.model_name = model_name
|
||||
self.lang = lang
|
||||
Base.__init__(self, **kwargs)
|
||||
|
||||
|
||||
class QWenCV(GptV4):
|
||||
_FACTORY_NAME = "Tongyi-Qianwen"
|
||||
|
||||
def __init__(self, key, model_name="qwen-vl-chat-v1", lang="Chinese", base_url=None, **kwargs):
|
||||
if not base_url:
|
||||
base_url = "https://dashscope.aliyuncs.com/compatible-mode/v1"
|
||||
super().__init__(key, model_name, lang=lang, base_url=base_url, **kwargs)
|
||||
|
||||
def chat(self, system, history, gen_conf, images=None, video_bytes=None, filename="", **kwargs):
|
||||
if video_bytes:
|
||||
try:
|
||||
summary, summary_num_tokens = self._process_video(video_bytes, filename)
|
||||
return summary, summary_num_tokens
|
||||
except Exception as e:
|
||||
return "**ERROR**: " + str(e), 0
|
||||
|
||||
return "**ERROR**: Method chat not supported yet.", 0
|
||||
|
||||
def _process_video(self, video_bytes, filename):
|
||||
from dashscope import MultiModalConversation
|
||||
|
||||
video_suffix = Path(filename).suffix or ".mp4"
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=video_suffix) as tmp:
|
||||
tmp.write(video_bytes)
|
||||
tmp_path = tmp.name
|
||||
|
||||
video_path = f"file://{tmp_path}"
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"video": video_path,
|
||||
"fps": 2,
|
||||
},
|
||||
{
|
||||
"text": "Please summarize this video in proper sentences.",
|
||||
},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
def call_api():
|
||||
response = MultiModalConversation.call(
|
||||
api_key=self.api_key,
|
||||
model=self.model_name,
|
||||
messages=messages,
|
||||
)
|
||||
summary = response["output"]["choices"][0]["message"].content[0]["text"]
|
||||
return summary, num_tokens_from_string(summary)
|
||||
|
||||
try:
|
||||
return call_api()
|
||||
except Exception as e1:
|
||||
import dashscope
|
||||
|
||||
dashscope.base_http_api_url = "https://dashscope-intl.aliyuncs.com/api/v1"
|
||||
try:
|
||||
return call_api()
|
||||
except Exception as e2:
|
||||
raise RuntimeError(f"Both default and intl endpoint failed.\nFirst error: {e1}\nSecond error: {e2}")
|
||||
|
||||
|
||||
class XinferenceCV(GptV4):
|
||||
_FACTORY_NAME = "Xinference"
|
||||
|
||||
def __init__(self, key, model_name="", lang="Chinese", base_url="", **kwargs):
|
||||
base_url = urljoin(base_url, "v1")
|
||||
self.client = OpenAI(api_key=key, base_url=base_url)
|
||||
self.model_name = model_name
|
||||
self.lang = lang
|
||||
Base.__init__(self, **kwargs)
|
||||
|
||||
|
||||
class GPUStackCV(GptV4):
|
||||
_FACTORY_NAME = "GPUStack"
|
||||
|
||||
def __init__(self, key, model_name, lang="Chinese", base_url="", **kwargs):
|
||||
if not base_url:
|
||||
raise ValueError("Local llm url cannot be None")
|
||||
base_url = urljoin(base_url, "v1")
|
||||
self.client = OpenAI(api_key=key, base_url=base_url)
|
||||
self.model_name = model_name
|
||||
self.lang = lang
|
||||
Base.__init__(self, **kwargs)
|
||||
|
||||
|
||||
class OllamaCV(Base):
|
||||
_FACTORY_NAME = "Ollama"
|
||||
|
||||
def __init__(self, key, model_name, lang="Chinese", **kwargs):
|
||||
from ollama import Client
|
||||
self.client = Client(host=kwargs["base_url"])
|
||||
self.model_name = model_name
|
||||
self.lang = lang
|
||||
self.keep_alive = kwargs.get("ollama_keep_alive", int(os.environ.get("OLLAMA_KEEP_ALIVE", -1)))
|
||||
Base.__init__(self, **kwargs)
|
||||
|
||||
|
||||
def _clean_img(self, img):
|
||||
if not isinstance(img, str):
|
||||
return img
|
||||
|
||||
#remove the header like "data/*;base64,"
|
||||
if img.startswith("data:") and ";base64," in img:
|
||||
img = img.split(";base64,")[1]
|
||||
return img
|
||||
|
||||
def _clean_conf(self, gen_conf):
|
||||
options = {}
|
||||
if "temperature" in gen_conf:
|
||||
options["temperature"] = gen_conf["temperature"]
|
||||
if "top_p" in gen_conf:
|
||||
options["top_k"] = gen_conf["top_p"]
|
||||
if "presence_penalty" in gen_conf:
|
||||
options["presence_penalty"] = gen_conf["presence_penalty"]
|
||||
if "frequency_penalty" in gen_conf:
|
||||
options["frequency_penalty"] = gen_conf["frequency_penalty"]
|
||||
return options
|
||||
|
||||
def _form_history(self, system, history, images=None):
|
||||
hist = deepcopy(history)
|
||||
if system and hist[0]["role"] == "user":
|
||||
hist.insert(0, {"role": "system", "content": system})
|
||||
if not images:
|
||||
return hist
|
||||
temp_images = []
|
||||
for img in images:
|
||||
temp_images.append(self._clean_img(img))
|
||||
for his in hist:
|
||||
if his["role"] == "user":
|
||||
his["images"] = temp_images
|
||||
break
|
||||
return hist
|
||||
|
||||
def describe(self, image):
|
||||
prompt = self.prompt("")
|
||||
try:
|
||||
response = self.client.generate(
|
||||
model=self.model_name,
|
||||
prompt=prompt[0]["content"],
|
||||
images=[image],
|
||||
)
|
||||
ans = response["response"].strip()
|
||||
return ans, 128
|
||||
except Exception as e:
|
||||
return "**ERROR**: " + str(e), 0
|
||||
|
||||
def describe_with_prompt(self, image, prompt=None):
|
||||
vision_prompt = self.vision_llm_prompt("", prompt) if prompt else self.vision_llm_prompt("")
|
||||
try:
|
||||
response = self.client.generate(
|
||||
model=self.model_name,
|
||||
prompt=vision_prompt[0]["content"],
|
||||
images=[image],
|
||||
)
|
||||
ans = response["response"].strip()
|
||||
return ans, 128
|
||||
except Exception as e:
|
||||
return "**ERROR**: " + str(e), 0
|
||||
|
||||
def chat(self, system, history, gen_conf, images=None, **kwargs):
|
||||
try:
|
||||
response = self.client.chat(
|
||||
model=self.model_name,
|
||||
messages=self._form_history(system, history, images),
|
||||
options=self._clean_conf(gen_conf),
|
||||
keep_alive=self.keep_alive
|
||||
)
|
||||
|
||||
ans = response["message"]["content"].strip()
|
||||
return ans, response["eval_count"] + response.get("prompt_eval_count", 0)
|
||||
except Exception as e:
|
||||
return "**ERROR**: " + str(e), 0
|
||||
|
||||
def chat_streamly(self, system, history, gen_conf, images=None, **kwargs):
|
||||
ans = ""
|
||||
try:
|
||||
response = self.client.chat(
|
||||
model=self.model_name,
|
||||
messages=self._form_history(system, history, images),
|
||||
stream=True,
|
||||
options=self._clean_conf(gen_conf),
|
||||
keep_alive=self.keep_alive
|
||||
)
|
||||
for resp in response:
|
||||
if resp["done"]:
|
||||
yield resp.get("prompt_eval_count", 0) + resp.get("eval_count", 0)
|
||||
ans = resp["message"]["content"]
|
||||
yield ans
|
||||
except Exception as e:
|
||||
yield ans + "\n**ERROR**: " + str(e)
|
||||
yield 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# import sys
|
||||
# chunk(sys.argv[1], from_page=0, to_page=10, callback=dummy)
|
||||
|
||||
# # 准备配置vision_model信息
|
||||
# azure_config = {
|
||||
# "api_key": "xxxxx",
|
||||
# "api_version": "2024-02-01"
|
||||
# }
|
||||
# # 转换为 JSON 字符串,因为类中使用 json.loads(key) 解析
|
||||
# key = json.dumps(azure_config)
|
||||
# # 初始化 AzureGptV4
|
||||
# vision_model = AzureGptV4(
|
||||
# key=key, # JSON 字符串形式的配置
|
||||
# model_name="gpt-4o",
|
||||
# lang="Chinese", # 默认使用中文
|
||||
# base_url="https://fosun-openai-east-us-001.openai.azure.com/" # Azure OpenAI 端点
|
||||
# )
|
||||
# try:
|
||||
# # 测试图像描述功能
|
||||
# image_path = "/Users/sbtjfdn/Downloads/记忆科学/files/aippt.cn.png"
|
||||
# with open(image_path, "rb") as image_file:
|
||||
# image_data = image_file.read()
|
||||
#
|
||||
# # 使用 describe 方法
|
||||
# description, token_count = vision_model.describe(image_data)
|
||||
# # from app.core.rag.prompts.generator import vision_llm_figure_describe_prompt
|
||||
# # description, token_count = vision_model.describe_with_prompt(image_data, prompt=vision_llm_figure_describe_prompt())
|
||||
# print(f"描述: {description}")
|
||||
# print(f"使用的令牌数: {token_count}")
|
||||
#
|
||||
# except Exception as e:
|
||||
# print(f"初始化或处理过程中出错: {str(e)}")
|
||||
|
||||
# 准备配置vision_model信息
|
||||
# 初始化 QWenCV
|
||||
vision_model = QWenCV(
|
||||
key="sk-8e9e40cd171749858ce2d3722ea75669",
|
||||
model_name="qwen-vl-max",
|
||||
lang="Chinese", # 默认使用中文
|
||||
base_url="https://dashscope.aliyuncs.com/compatible-mode/v1"
|
||||
)
|
||||
try:
|
||||
# 测试图像描述功能
|
||||
image_path = "/Users/sbtjfdn/Downloads/记忆科学/files/10.png"
|
||||
with open(image_path, "rb") as image_file:
|
||||
image_data = image_file.read()
|
||||
|
||||
# 使用 describe 方法
|
||||
description, token_count = vision_model.describe(image_data)
|
||||
# from app.core.rag.prompts.generator import vision_llm_figure_describe_prompt
|
||||
# description, token_count = vision_model.describe_with_prompt(image_data, prompt=vision_llm_figure_describe_prompt())
|
||||
print(f"描述: {description}")
|
||||
print(f"使用的令牌数: {token_count}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"初始化或处理过程中出错: {str(e)}")
|
||||
179
api/app/core/rag/llm/sequence2txt_model.py
Normal file
179
api/app/core/rag/llm/sequence2txt_model.py
Normal file
@@ -0,0 +1,179 @@
|
||||
import base64
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from abc import ABC
|
||||
|
||||
import requests
|
||||
from openai import OpenAI
|
||||
from openai.lib.azure import AzureOpenAI
|
||||
|
||||
from app.core.rag.common.token_utils import num_tokens_from_string
|
||||
|
||||
|
||||
class Base(ABC):
|
||||
def __init__(self, key, model_name, **kwargs):
|
||||
"""
|
||||
Abstract base class constructor.
|
||||
Parameters are not stored; initialization is left to subclasses.
|
||||
"""
|
||||
pass
|
||||
|
||||
def transcription(self, audio_path, **kwargs):
|
||||
audio_file = open(audio_path, "rb")
|
||||
transcription = self.client.audio.transcriptions.create(model=self.model_name, file=audio_file)
|
||||
return transcription.text.strip(), num_tokens_from_string(transcription.text.strip())
|
||||
|
||||
def audio2base64(self, audio):
|
||||
if isinstance(audio, bytes):
|
||||
return base64.b64encode(audio).decode("utf-8")
|
||||
if isinstance(audio, io.BytesIO):
|
||||
return base64.b64encode(audio.getvalue()).decode("utf-8")
|
||||
raise TypeError("The input audio file should be in binary format.")
|
||||
|
||||
|
||||
class GPTSeq2txt(Base):
|
||||
_FACTORY_NAME = "OpenAI"
|
||||
|
||||
def __init__(self, key, model_name="whisper-1", base_url="https://api.openai.com/v1", **kwargs):
|
||||
if not base_url:
|
||||
base_url = "https://api.openai.com/v1"
|
||||
self.client = OpenAI(api_key=key, base_url=base_url)
|
||||
self.model_name = model_name
|
||||
|
||||
|
||||
class QWenSeq2txt(Base):
|
||||
_FACTORY_NAME = "Tongyi-Qianwen"
|
||||
|
||||
def __init__(self, key, model_name="qwen-audio-asr", **kwargs):
|
||||
import dashscope
|
||||
|
||||
dashscope.api_key = key
|
||||
self.model_name = model_name
|
||||
|
||||
def transcription(self, audio_path):
|
||||
if "paraformer" in self.model_name or "sensevoice" in self.model_name:
|
||||
return f"**ERROR**: model {self.model_name} is not suppported yet.", 0
|
||||
|
||||
from dashscope import MultiModalConversation
|
||||
|
||||
audio_path = f"file://{audio_path}"
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"audio": audio_path}],
|
||||
}
|
||||
]
|
||||
|
||||
response = None
|
||||
full_content = ""
|
||||
try:
|
||||
response = MultiModalConversation.call(model="qwen-audio-asr", messages=messages, result_format="message", stream=True)
|
||||
for response in response:
|
||||
try:
|
||||
full_content += response["output"]["choices"][0]["message"].content[0]["text"]
|
||||
except Exception:
|
||||
pass
|
||||
return full_content, num_tokens_from_string(full_content)
|
||||
except Exception as e:
|
||||
return "**ERROR**: " + str(e), 0
|
||||
|
||||
|
||||
class AzureSeq2txt(Base):
|
||||
_FACTORY_NAME = "Azure-OpenAI"
|
||||
|
||||
def __init__(self, key, model_name, lang="Chinese", **kwargs):
|
||||
self.client = AzureOpenAI(api_key=key, azure_endpoint=kwargs["base_url"], api_version="2024-02-01")
|
||||
self.model_name = model_name
|
||||
self.lang = lang
|
||||
|
||||
|
||||
class XinferenceSeq2txt(Base):
|
||||
_FACTORY_NAME = "Xinference"
|
||||
|
||||
def __init__(self, key, model_name="whisper-small", **kwargs):
|
||||
self.base_url = kwargs.get("base_url", None)
|
||||
self.model_name = model_name
|
||||
self.key = key
|
||||
|
||||
def transcription(self, audio, language="zh", prompt=None, response_format="json", temperature=0.7):
|
||||
if isinstance(audio, str):
|
||||
audio_file = open(audio, "rb")
|
||||
audio_data = audio_file.read()
|
||||
audio_file_name = audio.split("/")[-1]
|
||||
else:
|
||||
audio_data = audio
|
||||
audio_file_name = "audio.wav"
|
||||
|
||||
payload = {"model": self.model_name, "language": language, "prompt": prompt, "response_format": response_format, "temperature": temperature}
|
||||
|
||||
files = {"file": (audio_file_name, audio_data, "audio/wav")}
|
||||
|
||||
try:
|
||||
response = requests.post(f"{self.base_url}/v1/audio/transcriptions", files=files, data=payload)
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
|
||||
if "text" in result:
|
||||
transcription_text = result["text"].strip()
|
||||
return transcription_text, num_tokens_from_string(transcription_text)
|
||||
else:
|
||||
return "**ERROR**: Failed to retrieve transcription.", 0
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
return f"**ERROR**: {str(e)}", 0
|
||||
|
||||
|
||||
class GPUStackSeq2txt(Base):
|
||||
_FACTORY_NAME = "GPUStack"
|
||||
|
||||
def __init__(self, key, model_name, base_url):
|
||||
if not base_url:
|
||||
raise ValueError("url cannot be None")
|
||||
if base_url.split("/")[-1] != "v1":
|
||||
base_url = os.path.join(base_url, "v1")
|
||||
self.base_url = base_url
|
||||
self.model_name = model_name
|
||||
self.key = key
|
||||
|
||||
|
||||
class ZhipuSeq2txt(Base):
|
||||
_FACTORY_NAME = "ZHIPU-AI"
|
||||
|
||||
def __init__(self, key, model_name="glm-asr", base_url="https://open.bigmodel.cn/api/paas/v4", **kwargs):
|
||||
if not base_url:
|
||||
base_url = "https://open.bigmodel.cn/api/paas/v4"
|
||||
self.base_url = base_url
|
||||
self.api_key = key
|
||||
self.model_name = model_name
|
||||
self.gen_conf = kwargs.get("gen_conf", {})
|
||||
self.stream = kwargs.get("stream", False)
|
||||
|
||||
def transcription(self, audio_path):
|
||||
payload = {
|
||||
"model": self.model_name,
|
||||
"temperature": str(self.gen_conf.get("temperature", 0.75)) or "0.75",
|
||||
"stream": self.stream,
|
||||
}
|
||||
|
||||
headers = {"Authorization": f"Bearer {self.api_key}"}
|
||||
with open(audio_path, "rb") as audio_file:
|
||||
files = {"file": audio_file}
|
||||
|
||||
try:
|
||||
response = requests.post(
|
||||
url=f"{self.base_url}/audio/transcriptions",
|
||||
data=payload,
|
||||
files=files,
|
||||
headers=headers,
|
||||
)
|
||||
body = response.json()
|
||||
if response.status_code == 200:
|
||||
full_content = body["text"]
|
||||
return full_content, num_tokens_from_string(full_content)
|
||||
else:
|
||||
error = body["error"]
|
||||
return f"**ERROR**: code: {error['code']}, message: {error['message']}", 0
|
||||
except Exception as e:
|
||||
return "**ERROR**: " + str(e), 0
|
||||
0
api/app/core/rag/models/__init__.py
Normal file
0
api/app/core/rag/models/__init__.py
Normal file
72
api/app/core/rag/models/chunk.py
Normal file
72
api/app/core/rag/models/chunk.py
Normal file
@@ -0,0 +1,72 @@
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ChildDocumentChunk(BaseModel):
|
||||
"""Class for storing a piece of text and associated metadata."""
|
||||
|
||||
page_content: str
|
||||
|
||||
vector: list[float] | None = None
|
||||
|
||||
"""Arbitrary metadata about the page content (e.g., source, relationships to other
|
||||
documents, etc.).
|
||||
"""
|
||||
metadata: dict = Field(default_factory=dict)
|
||||
|
||||
|
||||
class DocumentChunk(BaseModel):
|
||||
"""Class for storing a piece of text and associated metadata."""
|
||||
|
||||
page_content: str
|
||||
|
||||
vector: list[float] | None = None
|
||||
|
||||
"""Arbitrary metadata about the page content (e.g., source, relationships to other
|
||||
documents, etc.).
|
||||
"""
|
||||
metadata: dict = Field(default_factory=dict)
|
||||
|
||||
children: list[ChildDocumentChunk] | None = None
|
||||
|
||||
|
||||
class GeneralStructureChunk(BaseModel):
|
||||
"""
|
||||
General Structure Chunk.
|
||||
"""
|
||||
|
||||
general_chunks: list[str]
|
||||
|
||||
|
||||
class ParentChildChunk(BaseModel):
|
||||
"""
|
||||
Parent Child Chunk.
|
||||
"""
|
||||
|
||||
parent_content: str
|
||||
child_contents: list[str]
|
||||
|
||||
|
||||
class ParentChildStructureChunk(BaseModel):
|
||||
"""
|
||||
Parent Child Structure Chunk.
|
||||
"""
|
||||
|
||||
parent_child_chunks: list[ParentChildChunk]
|
||||
parent_mode: str = "paragraph"
|
||||
|
||||
|
||||
class QAChunk(BaseModel):
|
||||
"""
|
||||
QA Chunk.
|
||||
"""
|
||||
|
||||
question: str
|
||||
answer: str
|
||||
|
||||
|
||||
class QAStructureChunk(BaseModel):
|
||||
"""
|
||||
QAStructureChunk.
|
||||
"""
|
||||
|
||||
qa_chunks: list[QAChunk]
|
||||
857
api/app/core/rag/nlp/__init__.py
Normal file
857
api/app/core/rag/nlp/__init__.py
Normal file
@@ -0,0 +1,857 @@
|
||||
import logging
|
||||
import random
|
||||
from collections import Counter
|
||||
|
||||
from app.core.rag.common.token_utils import num_tokens_from_string
|
||||
from . import rag_tokenizer
|
||||
import re
|
||||
import copy
|
||||
import roman_numbers as r
|
||||
from word2number import w2n
|
||||
from cn2an import cn2an
|
||||
from PIL import Image
|
||||
|
||||
import chardet
|
||||
|
||||
all_codecs = [
|
||||
'utf-8', 'gb2312', 'gbk', 'utf_16', 'ascii', 'big5', 'big5hkscs',
|
||||
'cp037', 'cp273', 'cp424', 'cp437',
|
||||
'cp500', 'cp720', 'cp737', 'cp775', 'cp850', 'cp852', 'cp855', 'cp856', 'cp857',
|
||||
'cp858', 'cp860', 'cp861', 'cp862', 'cp863', 'cp864', 'cp865', 'cp866', 'cp869',
|
||||
'cp874', 'cp875', 'cp932', 'cp949', 'cp950', 'cp1006', 'cp1026', 'cp1125',
|
||||
'cp1140', 'cp1250', 'cp1251', 'cp1252', 'cp1253', 'cp1254', 'cp1255', 'cp1256',
|
||||
'cp1257', 'cp1258', 'euc_jp', 'euc_jis_2004', 'euc_jisx0213', 'euc_kr',
|
||||
'gb18030', 'hz', 'iso2022_jp', 'iso2022_jp_1', 'iso2022_jp_2',
|
||||
'iso2022_jp_2004', 'iso2022_jp_3', 'iso2022_jp_ext', 'iso2022_kr', 'latin_1',
|
||||
'iso8859_2', 'iso8859_3', 'iso8859_4', 'iso8859_5', 'iso8859_6', 'iso8859_7',
|
||||
'iso8859_8', 'iso8859_9', 'iso8859_10', 'iso8859_11', 'iso8859_13',
|
||||
'iso8859_14', 'iso8859_15', 'iso8859_16', 'johab', 'koi8_r', 'koi8_t', 'koi8_u',
|
||||
'kz1048', 'mac_cyrillic', 'mac_greek', 'mac_iceland', 'mac_latin2', 'mac_roman',
|
||||
'mac_turkish', 'ptcp154', 'shift_jis', 'shift_jis_2004', 'shift_jisx0213',
|
||||
'utf_32', 'utf_32_be', 'utf_32_le', 'utf_16_be', 'utf_16_le', 'utf_7', 'windows-1250', 'windows-1251',
|
||||
'windows-1252', 'windows-1253', 'windows-1254', 'windows-1255', 'windows-1256',
|
||||
'windows-1257', 'windows-1258', 'latin-2'
|
||||
]
|
||||
|
||||
|
||||
def find_codec(blob):
|
||||
detected = chardet.detect(blob[:1024])
|
||||
if detected['confidence'] > 0.5:
|
||||
if detected['encoding'] == "ascii":
|
||||
return "utf-8"
|
||||
|
||||
for c in all_codecs:
|
||||
try:
|
||||
blob[:1024].decode(c)
|
||||
return c
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
blob.decode(c)
|
||||
return c
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return "utf-8"
|
||||
|
||||
|
||||
QUESTION_PATTERN = [
|
||||
r"第([零一二三四五六七八九十百0-9]+)问",
|
||||
r"第([零一二三四五六七八九十百0-9]+)条",
|
||||
r"[\((]([零一二三四五六七八九十百]+)[\))]",
|
||||
r"第([0-9]+)问",
|
||||
r"第([0-9]+)条",
|
||||
r"([0-9]{1,2})[\. 、]",
|
||||
r"([零一二三四五六七八九十百]+)[ 、]",
|
||||
r"[\((]([0-9]{1,2})[\))]",
|
||||
r"QUESTION (ONE|TWO|THREE|FOUR|FIVE|SIX|SEVEN|EIGHT|NINE|TEN)",
|
||||
r"QUESTION (I+V?|VI*|XI|IX|X)",
|
||||
r"QUESTION ([0-9]+)",
|
||||
]
|
||||
|
||||
|
||||
def has_qbullet(reg, box, last_box, last_index, last_bull, bull_x0_list):
|
||||
section, last_section = box['text'], last_box['text']
|
||||
q_reg = r'(\w|\W)*?(?:?|\?|\n|$)+'
|
||||
full_reg = reg + q_reg
|
||||
has_bull = re.match(full_reg, section)
|
||||
index_str = None
|
||||
if has_bull:
|
||||
if 'x0' not in last_box:
|
||||
last_box['x0'] = box['x0']
|
||||
if 'top' not in last_box:
|
||||
last_box['top'] = box['top']
|
||||
if last_bull and box['x0'] - last_box['x0'] > 10:
|
||||
return None, last_index
|
||||
if not last_bull and box['x0'] >= last_box['x0'] and box['top'] - last_box['top'] < 20:
|
||||
return None, last_index
|
||||
avg_bull_x0 = 0
|
||||
if bull_x0_list:
|
||||
avg_bull_x0 = sum(bull_x0_list) / len(bull_x0_list)
|
||||
else:
|
||||
avg_bull_x0 = box['x0']
|
||||
if box['x0'] - avg_bull_x0 > 10:
|
||||
return None, last_index
|
||||
index_str = has_bull.group(1)
|
||||
index = index_int(index_str)
|
||||
if last_section[-1] == ':' or last_section[-1] == ':':
|
||||
return None, last_index
|
||||
if not last_index or index >= last_index:
|
||||
bull_x0_list.append(box['x0'])
|
||||
return has_bull, index
|
||||
if section[-1] == '?' or section[-1] == '?':
|
||||
bull_x0_list.append(box['x0'])
|
||||
return has_bull, index
|
||||
if box['layout_type'] == 'title':
|
||||
bull_x0_list.append(box['x0'])
|
||||
return has_bull, index
|
||||
pure_section = section.lstrip(re.match(reg, section).group()).lower()
|
||||
ask_reg = r'(what|when|where|how|why|which|who|whose|为什么|为啥|哪)'
|
||||
if re.match(ask_reg, pure_section):
|
||||
bull_x0_list.append(box['x0'])
|
||||
return has_bull, index
|
||||
return None, last_index
|
||||
|
||||
|
||||
def index_int(index_str):
|
||||
res = -1
|
||||
try:
|
||||
res = int(index_str)
|
||||
except ValueError:
|
||||
try:
|
||||
res = w2n.word_to_num(index_str)
|
||||
except ValueError:
|
||||
try:
|
||||
res = cn2an(index_str)
|
||||
except ValueError:
|
||||
try:
|
||||
res = r.number(index_str)
|
||||
except ValueError:
|
||||
return -1
|
||||
return res
|
||||
|
||||
|
||||
def qbullets_category(sections):
|
||||
global QUESTION_PATTERN
|
||||
hits = [0] * len(QUESTION_PATTERN)
|
||||
for i, pro in enumerate(QUESTION_PATTERN):
|
||||
for sec in sections:
|
||||
if re.match(pro, sec) and not not_bullet(sec):
|
||||
hits[i] += 1
|
||||
break
|
||||
maxium = 0
|
||||
res = -1
|
||||
for i, h in enumerate(hits):
|
||||
if h <= maxium:
|
||||
continue
|
||||
res = i
|
||||
maxium = h
|
||||
return res, QUESTION_PATTERN[res]
|
||||
|
||||
|
||||
BULLET_PATTERN = [[
|
||||
r"第[零一二三四五六七八九十百0-9]+(分?编|部分)",
|
||||
r"第[零一二三四五六七八九十百0-9]+章",
|
||||
r"第[零一二三四五六七八九十百0-9]+节",
|
||||
r"第[零一二三四五六七八九十百0-9]+条",
|
||||
r"[\((][零一二三四五六七八九十百]+[\))]",
|
||||
], [
|
||||
r"第[0-9]+章",
|
||||
r"第[0-9]+节",
|
||||
r"[0-9]{,2}[\. 、]",
|
||||
r"[0-9]{,2}\.[0-9]{,2}[^a-zA-Z/%~-]",
|
||||
r"[0-9]{,2}\.[0-9]{,2}\.[0-9]{,2}",
|
||||
r"[0-9]{,2}\.[0-9]{,2}\.[0-9]{,2}\.[0-9]{,2}",
|
||||
], [
|
||||
r"第[零一二三四五六七八九十百0-9]+章",
|
||||
r"第[零一二三四五六七八九十百0-9]+节",
|
||||
r"[零一二三四五六七八九十百]+[ 、]",
|
||||
r"[\((][零一二三四五六七八九十百]+[\))]",
|
||||
r"[\((][0-9]{,2}[\))]",
|
||||
], [
|
||||
r"PART (ONE|TWO|THREE|FOUR|FIVE|SIX|SEVEN|EIGHT|NINE|TEN)",
|
||||
r"Chapter (I+V?|VI*|XI|IX|X)",
|
||||
r"Section [0-9]+",
|
||||
r"Article [0-9]+"
|
||||
], [
|
||||
r"^#[^#]",
|
||||
r"^##[^#]",
|
||||
r"^###.*",
|
||||
r"^####.*",
|
||||
r"^#####.*",
|
||||
r"^######.*",
|
||||
]
|
||||
]
|
||||
|
||||
|
||||
def random_choices(arr, k):
|
||||
k = min(len(arr), k)
|
||||
return random.choices(arr, k=k)
|
||||
|
||||
|
||||
def not_bullet(line):
|
||||
patt = [
|
||||
r"0", r"[0-9]+ +[0-9~个只-]", r"[0-9]+\.{2,}"
|
||||
]
|
||||
return any([re.match(r, line) for r in patt])
|
||||
|
||||
|
||||
def bullets_category(sections):
|
||||
global BULLET_PATTERN
|
||||
hits = [0] * len(BULLET_PATTERN)
|
||||
for i, pro in enumerate(BULLET_PATTERN):
|
||||
for sec in sections:
|
||||
sec = sec.strip()
|
||||
for p in pro:
|
||||
if re.match(p, sec) and not not_bullet(sec):
|
||||
hits[i] += 1
|
||||
break
|
||||
maxium = 0
|
||||
res = -1
|
||||
for i, h in enumerate(hits):
|
||||
if h <= maxium:
|
||||
continue
|
||||
res = i
|
||||
maxium = h
|
||||
return res
|
||||
|
||||
|
||||
def is_english(texts):
|
||||
if not texts:
|
||||
return False
|
||||
|
||||
pattern = re.compile(r"[`a-zA-Z0-9\s.,':;/\"?<>!\(\)\-]")
|
||||
|
||||
if isinstance(texts, str):
|
||||
texts = list(texts)
|
||||
elif isinstance(texts, list):
|
||||
texts = [t for t in texts if isinstance(t, str) and t.strip()]
|
||||
else:
|
||||
return False
|
||||
|
||||
if not texts:
|
||||
return False
|
||||
|
||||
eng = sum(1 for t in texts if pattern.fullmatch(t.strip()))
|
||||
return (eng / len(texts)) > 0.8
|
||||
|
||||
|
||||
def is_chinese(text):
|
||||
if not text:
|
||||
return False
|
||||
chinese = 0
|
||||
for ch in text:
|
||||
if '\u4e00' <= ch <= '\u9fff':
|
||||
chinese += 1
|
||||
if chinese / len(text) > 0.2:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def tokenize(d, t, eng):
|
||||
d["content_with_weight"] = t
|
||||
t = re.sub(r"</?(table|td|caption|tr|th)( [^<>]{0,12})?>", " ", t)
|
||||
d["content_ltks"] = rag_tokenizer.tokenize(t)
|
||||
d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
|
||||
|
||||
|
||||
def tokenize_chunks(chunks, doc, eng, pdf_parser=None):
|
||||
res = []
|
||||
# wrap up as es documents
|
||||
for ii, ck in enumerate(chunks):
|
||||
if len(ck.strip()) == 0:
|
||||
continue
|
||||
logging.debug("-- {}".format(ck))
|
||||
d = copy.deepcopy(doc)
|
||||
if pdf_parser:
|
||||
try:
|
||||
d["image"], poss = pdf_parser.crop(ck, need_position=True)
|
||||
add_positions(d, poss)
|
||||
ck = pdf_parser.remove_tag(ck)
|
||||
except NotImplementedError:
|
||||
pass
|
||||
else:
|
||||
add_positions(d, [[ii]*5])
|
||||
tokenize(d, ck, eng)
|
||||
res.append(d)
|
||||
return res
|
||||
|
||||
|
||||
def tokenize_chunks_with_images(chunks, doc, eng, images):
|
||||
res = []
|
||||
# wrap up as es documents
|
||||
for ii, (ck, image) in enumerate(zip(chunks, images)):
|
||||
if len(ck.strip()) == 0:
|
||||
continue
|
||||
logging.debug("-- {}".format(ck))
|
||||
d = copy.deepcopy(doc)
|
||||
d["image"] = image
|
||||
add_positions(d, [[ii]*5])
|
||||
tokenize(d, ck, eng)
|
||||
res.append(d)
|
||||
return res
|
||||
|
||||
|
||||
def tokenize_table(tbls, doc, eng, batch_size=10):
|
||||
res = []
|
||||
# add tables
|
||||
for (img, rows), poss in tbls:
|
||||
if not rows:
|
||||
continue
|
||||
if isinstance(rows, str):
|
||||
d = copy.deepcopy(doc)
|
||||
tokenize(d, rows, eng)
|
||||
d["content_with_weight"] = rows
|
||||
if img:
|
||||
d["image"] = img
|
||||
d["doc_type_kwd"] = "image"
|
||||
if poss:
|
||||
add_positions(d, poss)
|
||||
res.append(d)
|
||||
continue
|
||||
de = "; " if eng else "; "
|
||||
for i in range(0, len(rows), batch_size):
|
||||
d = copy.deepcopy(doc)
|
||||
r = de.join(rows[i:i + batch_size])
|
||||
tokenize(d, r, eng)
|
||||
if img:
|
||||
d["image"] = img
|
||||
d["doc_type_kwd"] = "image"
|
||||
add_positions(d, poss)
|
||||
res.append(d)
|
||||
return res
|
||||
|
||||
|
||||
def add_positions(d, poss):
|
||||
if not poss:
|
||||
return
|
||||
page_num_int = []
|
||||
position_int = []
|
||||
top_int = []
|
||||
for pn, left, right, top, bottom in poss:
|
||||
page_num_int.append(int(pn + 1))
|
||||
top_int.append(int(top))
|
||||
position_int.append((int(pn + 1), int(left), int(right), int(top), int(bottom)))
|
||||
d["page_num_int"] = page_num_int
|
||||
d["position_int"] = position_int
|
||||
d["top_int"] = top_int
|
||||
|
||||
|
||||
def remove_contents_table(sections, eng=False):
|
||||
i = 0
|
||||
while i < len(sections):
|
||||
def get(i):
|
||||
nonlocal sections
|
||||
return (sections[i] if isinstance(sections[i],
|
||||
type("")) else sections[i][0]).strip()
|
||||
|
||||
if not re.match(r"(contents|目录|目次|table of contents|致谢|acknowledge)$",
|
||||
re.sub(r"( | |\u3000)+", "", get(i).split("@@")[0], flags=re.IGNORECASE)):
|
||||
i += 1
|
||||
continue
|
||||
sections.pop(i)
|
||||
if i >= len(sections):
|
||||
break
|
||||
prefix = get(i)[:3] if not eng else " ".join(get(i).split()[:2])
|
||||
while not prefix:
|
||||
sections.pop(i)
|
||||
if i >= len(sections):
|
||||
break
|
||||
prefix = get(i)[:3] if not eng else " ".join(get(i).split()[:2])
|
||||
sections.pop(i)
|
||||
if i >= len(sections) or not prefix:
|
||||
break
|
||||
for j in range(i, min(i + 128, len(sections))):
|
||||
if not re.match(prefix, get(j)):
|
||||
continue
|
||||
for _ in range(i, j):
|
||||
sections.pop(i)
|
||||
break
|
||||
|
||||
|
||||
def make_colon_as_title(sections):
|
||||
if not sections:
|
||||
return []
|
||||
if isinstance(sections[0], type("")):
|
||||
return sections
|
||||
i = 0
|
||||
while i < len(sections):
|
||||
txt, layout = sections[i]
|
||||
i += 1
|
||||
txt = txt.split("@")[0].strip()
|
||||
if not txt:
|
||||
continue
|
||||
if txt[-1] not in "::":
|
||||
continue
|
||||
txt = txt[::-1]
|
||||
arr = re.split(r"([。?!!?;;]| \.)", txt)
|
||||
if len(arr) < 2 or len(arr[1]) < 32:
|
||||
continue
|
||||
sections.insert(i - 1, (arr[0][::-1], "title"))
|
||||
i += 1
|
||||
|
||||
|
||||
def title_frequency(bull, sections):
|
||||
bullets_size = len(BULLET_PATTERN[bull])
|
||||
levels = [bullets_size + 1 for _ in range(len(sections))]
|
||||
if not sections or bull < 0:
|
||||
return bullets_size + 1, levels
|
||||
|
||||
for i, (txt, layout) in enumerate(sections):
|
||||
for j, p in enumerate(BULLET_PATTERN[bull]):
|
||||
if re.match(p, txt.strip()) and not not_bullet(txt):
|
||||
levels[i] = j
|
||||
break
|
||||
else:
|
||||
if re.search(r"(title|head)", layout) and not not_title(txt.split("@")[0]):
|
||||
levels[i] = bullets_size
|
||||
most_level = bullets_size + 1
|
||||
for level, c in sorted(Counter(levels).items(), key=lambda x: x[1] * -1):
|
||||
if level <= bullets_size:
|
||||
most_level = level
|
||||
break
|
||||
return most_level, levels
|
||||
|
||||
|
||||
def not_title(txt):
|
||||
if re.match(r"第[零一二三四五六七八九十百0-9]+条", txt):
|
||||
return False
|
||||
if len(txt.split()) > 12 or (txt.find(" ") < 0 and len(txt) >= 32):
|
||||
return True
|
||||
return re.search(r"[,;,。;!!]", txt)
|
||||
|
||||
def tree_merge(bull, sections, depth):
|
||||
|
||||
if not sections or bull < 0:
|
||||
return sections
|
||||
if isinstance(sections[0], type("")):
|
||||
sections = [(s, "") for s in sections]
|
||||
|
||||
# filter out position information in pdf sections
|
||||
sections = [(t, o) for t, o in sections if
|
||||
t and len(t.split("@")[0].strip()) > 1 and not re.match(r"[0-9]+$", t.split("@")[0].strip())]
|
||||
|
||||
def get_level(bull, section):
|
||||
text, layout = section
|
||||
text = re.sub(r"\u3000", " ", text).strip()
|
||||
|
||||
for i, title in enumerate(BULLET_PATTERN[bull]):
|
||||
if re.match(title, text.strip()):
|
||||
return i+1, text
|
||||
else:
|
||||
if re.search(r"(title|head)", layout) and not not_title(text):
|
||||
return len(BULLET_PATTERN[bull])+1, text
|
||||
else:
|
||||
return len(BULLET_PATTERN[bull])+2, text
|
||||
level_set = set()
|
||||
lines = []
|
||||
for section in sections:
|
||||
level, text = get_level(bull, section)
|
||||
if not text.strip("\n"):
|
||||
continue
|
||||
|
||||
lines.append((level, text))
|
||||
level_set.add(level)
|
||||
|
||||
sorted_levels = sorted(list(level_set))
|
||||
|
||||
if depth <= len(sorted_levels):
|
||||
target_level = sorted_levels[depth - 1]
|
||||
else:
|
||||
target_level = sorted_levels[-1]
|
||||
|
||||
if target_level == len(BULLET_PATTERN[bull]) + 2:
|
||||
target_level = sorted_levels[-2] if len(sorted_levels) > 1 else sorted_levels[0]
|
||||
|
||||
root = Node(level=0, depth=target_level, texts=[])
|
||||
root.build_tree(lines)
|
||||
|
||||
return [("\n").join(element) for element in root.get_tree() if element]
|
||||
|
||||
def hierarchical_merge(bull, sections, depth):
|
||||
|
||||
if not sections or bull < 0:
|
||||
return []
|
||||
if isinstance(sections[0], type("")):
|
||||
sections = [(s, "") for s in sections]
|
||||
sections = [(t, o) for t, o in sections if
|
||||
t and len(t.split("@")[0].strip()) > 1 and not re.match(r"[0-9]+$", t.split("@")[0].strip())]
|
||||
bullets_size = len(BULLET_PATTERN[bull])
|
||||
levels = [[] for _ in range(bullets_size + 2)]
|
||||
|
||||
for i, (txt, layout) in enumerate(sections):
|
||||
for j, p in enumerate(BULLET_PATTERN[bull]):
|
||||
if re.match(p, txt.strip()):
|
||||
levels[j].append(i)
|
||||
break
|
||||
else:
|
||||
if re.search(r"(title|head)", layout) and not not_title(txt):
|
||||
levels[bullets_size].append(i)
|
||||
else:
|
||||
levels[bullets_size + 1].append(i)
|
||||
sections = [t for t, _ in sections]
|
||||
|
||||
# for s in sections: print("--", s)
|
||||
|
||||
def binary_search(arr, target):
|
||||
if not arr:
|
||||
return -1
|
||||
if target > arr[-1]:
|
||||
return len(arr) - 1
|
||||
if target < arr[0]:
|
||||
return -1
|
||||
s, e = 0, len(arr)
|
||||
while e - s > 1:
|
||||
i = (e + s) // 2
|
||||
if target > arr[i]:
|
||||
s = i
|
||||
continue
|
||||
elif target < arr[i]:
|
||||
e = i
|
||||
continue
|
||||
else:
|
||||
assert False
|
||||
return s
|
||||
|
||||
cks = []
|
||||
readed = [False] * len(sections)
|
||||
levels = levels[::-1]
|
||||
for i, arr in enumerate(levels[:depth]):
|
||||
for j in arr:
|
||||
if readed[j]:
|
||||
continue
|
||||
readed[j] = True
|
||||
cks.append([j])
|
||||
if i + 1 == len(levels) - 1:
|
||||
continue
|
||||
for ii in range(i + 1, len(levels)):
|
||||
jj = binary_search(levels[ii], j)
|
||||
if jj < 0:
|
||||
continue
|
||||
if levels[ii][jj] > cks[-1][-1]:
|
||||
cks[-1].pop(-1)
|
||||
cks[-1].append(levels[ii][jj])
|
||||
for ii in cks[-1]:
|
||||
readed[ii] = True
|
||||
|
||||
if not cks:
|
||||
return cks
|
||||
|
||||
for i in range(len(cks)):
|
||||
cks[i] = [sections[j] for j in cks[i][::-1]]
|
||||
logging.debug("\n* ".join(cks[i]))
|
||||
|
||||
res = [[]]
|
||||
num = [0]
|
||||
for ck in cks:
|
||||
if len(ck) == 1:
|
||||
n = num_tokens_from_string(re.sub(r"@@[0-9]+.*", "", ck[0]))
|
||||
if n + num[-1] < 218:
|
||||
res[-1].append(ck[0])
|
||||
num[-1] += n
|
||||
continue
|
||||
res.append(ck)
|
||||
num.append(n)
|
||||
continue
|
||||
res.append(ck)
|
||||
num.append(218)
|
||||
|
||||
return res
|
||||
|
||||
|
||||
def naive_merge(sections: str | list, chunk_token_num=128, delimiter="\n。;!?", overlapped_percent=0):
|
||||
from app.core.rag.deepdoc.parser.pdf_parser import RAGPdfParser
|
||||
if not sections:
|
||||
return []
|
||||
if isinstance(sections, str):
|
||||
sections = [sections]
|
||||
if isinstance(sections[0], str):
|
||||
sections = [(s, "") for s in sections]
|
||||
cks = [""]
|
||||
tk_nums = [0]
|
||||
|
||||
def add_chunk(t, pos):
|
||||
nonlocal cks, tk_nums, delimiter
|
||||
tnum = num_tokens_from_string(t)
|
||||
if not pos:
|
||||
pos = ""
|
||||
if tnum < 8:
|
||||
pos = ""
|
||||
# Ensure that the length of the merged chunk does not exceed chunk_token_num
|
||||
if cks[-1] == "" or tk_nums[-1] > chunk_token_num * (100 - overlapped_percent)/100.:
|
||||
if cks:
|
||||
overlapped = RAGPdfParser.remove_tag(cks[-1])
|
||||
t = overlapped[int(len(overlapped)*(100-overlapped_percent)/100.):] + t
|
||||
if t.find(pos) < 0:
|
||||
t += pos
|
||||
cks.append(t)
|
||||
tk_nums.append(tnum)
|
||||
else:
|
||||
if cks[-1].find(pos) < 0:
|
||||
t += pos
|
||||
cks[-1] += t
|
||||
tk_nums[-1] += tnum
|
||||
|
||||
dels = get_delimiters(delimiter)
|
||||
for sec, pos in sections:
|
||||
if num_tokens_from_string(sec) < chunk_token_num:
|
||||
add_chunk("\n"+sec, pos)
|
||||
continue
|
||||
split_sec = re.split(r"(%s)" % dels, sec, flags=re.DOTALL)
|
||||
for sub_sec in split_sec:
|
||||
if re.match(f"^{dels}$", sub_sec):
|
||||
continue
|
||||
add_chunk("\n"+sub_sec, pos)
|
||||
|
||||
return cks
|
||||
|
||||
|
||||
def naive_merge_with_images(texts, images, chunk_token_num=128, delimiter="\n。;!?", overlapped_percent=0):
|
||||
from app.core.rag.deepdoc.parser.pdf_parser import RAGPdfParser
|
||||
if not texts or len(texts) != len(images):
|
||||
return [], []
|
||||
cks = [""]
|
||||
result_images = [None]
|
||||
tk_nums = [0]
|
||||
|
||||
def add_chunk(t, image, pos=""):
|
||||
nonlocal cks, result_images, tk_nums, delimiter
|
||||
tnum = num_tokens_from_string(t)
|
||||
if not pos:
|
||||
pos = ""
|
||||
if tnum < 8:
|
||||
pos = ""
|
||||
# Ensure that the length of the merged chunk does not exceed chunk_token_num
|
||||
if cks[-1] == "" or tk_nums[-1] > chunk_token_num * (100 - overlapped_percent)/100.:
|
||||
if cks:
|
||||
overlapped = RAGPdfParser.remove_tag(cks[-1])
|
||||
t = overlapped[int(len(overlapped)*(100-overlapped_percent)/100.):] + t
|
||||
if t.find(pos) < 0:
|
||||
t += pos
|
||||
cks.append(t)
|
||||
result_images.append(image)
|
||||
tk_nums.append(tnum)
|
||||
else:
|
||||
if cks[-1].find(pos) < 0:
|
||||
t += pos
|
||||
cks[-1] += t
|
||||
if result_images[-1] is None:
|
||||
result_images[-1] = image
|
||||
else:
|
||||
result_images[-1] = concat_img(result_images[-1], image)
|
||||
tk_nums[-1] += tnum
|
||||
|
||||
dels = get_delimiters(delimiter)
|
||||
for text, image in zip(texts, images):
|
||||
# if text is tuple, unpack it
|
||||
if isinstance(text, tuple):
|
||||
text_str = text[0]
|
||||
text_pos = text[1] if len(text) > 1 else ""
|
||||
split_sec = re.split(r"(%s)" % dels, text_str)
|
||||
for sub_sec in split_sec:
|
||||
if re.match(f"^{dels}$", sub_sec):
|
||||
continue
|
||||
add_chunk("\n"+sub_sec, image, text_pos)
|
||||
else:
|
||||
split_sec = re.split(r"(%s)" % dels, text)
|
||||
for sub_sec in split_sec:
|
||||
if re.match(f"^{dels}$", sub_sec):
|
||||
continue
|
||||
add_chunk("\n"+sub_sec, image)
|
||||
|
||||
return cks, result_images
|
||||
|
||||
def docx_question_level(p, bull=-1):
|
||||
txt = re.sub(r"\u3000", " ", p.text).strip()
|
||||
if p.style.name.startswith('Heading'):
|
||||
return int(p.style.name.split(' ')[-1]), txt
|
||||
else:
|
||||
if bull < 0:
|
||||
return 0, txt
|
||||
for j, title in enumerate(BULLET_PATTERN[bull]):
|
||||
if re.match(title, txt):
|
||||
return j + 1, txt
|
||||
return len(BULLET_PATTERN[bull])+1, txt
|
||||
|
||||
|
||||
def concat_img(img1, img2):
|
||||
if img1 and not img2:
|
||||
return img1
|
||||
if not img1 and img2:
|
||||
return img2
|
||||
if not img1 and not img2:
|
||||
return None
|
||||
|
||||
if img1 is img2:
|
||||
return img1
|
||||
|
||||
if isinstance(img1, Image.Image) and isinstance(img2, Image.Image):
|
||||
pixel_data1 = img1.tobytes()
|
||||
pixel_data2 = img2.tobytes()
|
||||
if pixel_data1 == pixel_data2:
|
||||
return img1
|
||||
|
||||
width1, height1 = img1.size
|
||||
width2, height2 = img2.size
|
||||
|
||||
new_width = max(width1, width2)
|
||||
new_height = height1 + height2
|
||||
new_image = Image.new('RGB', (new_width, new_height))
|
||||
|
||||
new_image.paste(img1, (0, 0))
|
||||
new_image.paste(img2, (0, height1))
|
||||
return new_image
|
||||
|
||||
|
||||
def naive_merge_docx(sections, chunk_token_num=128, delimiter="\n。;!?"):
|
||||
if not sections:
|
||||
return [], []
|
||||
|
||||
cks = [""]
|
||||
images = [None]
|
||||
tk_nums = [0]
|
||||
|
||||
def add_chunk(t, image, pos=""):
|
||||
nonlocal cks, tk_nums, delimiter
|
||||
tnum = num_tokens_from_string(t)
|
||||
if tnum < 8:
|
||||
pos = ""
|
||||
if cks[-1] == "" or tk_nums[-1] > chunk_token_num:
|
||||
if t.find(pos) < 0:
|
||||
t += pos
|
||||
cks.append(t)
|
||||
images.append(image)
|
||||
tk_nums.append(tnum)
|
||||
else:
|
||||
if cks[-1].find(pos) < 0:
|
||||
t += pos
|
||||
cks[-1] += t
|
||||
images[-1] = concat_img(images[-1], image)
|
||||
tk_nums[-1] += tnum
|
||||
|
||||
dels = get_delimiters(delimiter)
|
||||
line = ""
|
||||
for sec, image in sections:
|
||||
if not image:
|
||||
line += sec + "\n"
|
||||
continue
|
||||
split_sec = re.split(r"(%s)" % dels, line + sec)
|
||||
for sub_sec in split_sec:
|
||||
if re.match(f"^{dels}$", sub_sec):
|
||||
continue
|
||||
add_chunk("\n"+sub_sec, image,"")
|
||||
line = ""
|
||||
|
||||
if line:
|
||||
split_sec = re.split(r"(%s)" % dels, line)
|
||||
for sub_sec in split_sec:
|
||||
if re.match(f"^{dels}$", sub_sec):
|
||||
continue
|
||||
add_chunk("\n"+sub_sec, image,"")
|
||||
|
||||
return cks, images
|
||||
|
||||
|
||||
def extract_between(text: str, start_tag: str, end_tag: str) -> list[str]:
|
||||
pattern = re.escape(start_tag) + r"(.*?)" + re.escape(end_tag)
|
||||
return re.findall(pattern, text, flags=re.DOTALL)
|
||||
|
||||
|
||||
def get_delimiters(delimiters: str):
|
||||
dels = []
|
||||
s = 0
|
||||
for m in re.finditer(r"`([^`]+)`", delimiters, re.I):
|
||||
f, t = m.span()
|
||||
dels.append(m.group(1))
|
||||
dels.extend(list(delimiters[s: f]))
|
||||
s = t
|
||||
if s < len(delimiters):
|
||||
dels.extend(list(delimiters[s:]))
|
||||
|
||||
dels.sort(key=lambda x: -len(x))
|
||||
dels = [re.escape(d) for d in dels if d]
|
||||
dels = [d for d in dels if d]
|
||||
dels_pattern = "|".join(dels)
|
||||
|
||||
return dels_pattern
|
||||
|
||||
class Node:
|
||||
def __init__(self, level, depth=-1, texts=None):
|
||||
self.level = level
|
||||
self.depth = depth
|
||||
self.texts = texts or []
|
||||
self.children = []
|
||||
|
||||
def add_child(self, child_node):
|
||||
self.children.append(child_node)
|
||||
|
||||
def get_children(self):
|
||||
return self.children
|
||||
|
||||
def get_level(self):
|
||||
return self.level
|
||||
|
||||
def get_texts(self):
|
||||
return self.texts
|
||||
|
||||
def set_texts(self, texts):
|
||||
self.texts = texts
|
||||
|
||||
def add_text(self, text):
|
||||
self.texts.append(text)
|
||||
|
||||
def clear_text(self):
|
||||
self.texts = []
|
||||
|
||||
def __repr__(self):
|
||||
return f"Node(level={self.level}, texts={self.texts}, children={len(self.children)})"
|
||||
|
||||
def build_tree(self, lines):
|
||||
stack = [self]
|
||||
for level, text in lines:
|
||||
if self.depth != -1 and level > self.depth:
|
||||
# Beyond target depth: merge content into the current leaf instead of creating deeper nodes
|
||||
stack[-1].add_text(text)
|
||||
continue
|
||||
|
||||
# Move up until we find the proper parent whose level is strictly smaller than current
|
||||
while len(stack) > 1 and level <= stack[-1].get_level():
|
||||
stack.pop()
|
||||
|
||||
node = Node(level=level, texts=[text])
|
||||
# Attach as child of current parent and descend
|
||||
stack[-1].add_child(node)
|
||||
stack.append(node)
|
||||
|
||||
return self
|
||||
|
||||
def get_tree(self):
|
||||
tree_list = []
|
||||
self._dfs(self, tree_list, [])
|
||||
return tree_list
|
||||
|
||||
def _dfs(self, node, tree_list, titles):
|
||||
level = node.get_level()
|
||||
texts = node.get_texts()
|
||||
child = node.get_children()
|
||||
|
||||
if level == 0 and texts:
|
||||
tree_list.append("\n".join(titles+texts))
|
||||
|
||||
# Titles within configured depth are accumulated into the current path
|
||||
if 1 <= level <= self.depth:
|
||||
path_titles = titles + texts
|
||||
else:
|
||||
path_titles = titles
|
||||
|
||||
# Body outside the depth limit becomes its own chunk under the current title path
|
||||
if level > self.depth and texts:
|
||||
tree_list.append("\n".join(path_titles + texts))
|
||||
|
||||
# A leaf title within depth emits its title path as a chunk (header-only section)
|
||||
elif not child and (1 <= level <= self.depth):
|
||||
tree_list.append("\n".join(path_titles))
|
||||
|
||||
# Recurse into children with the updated title path
|
||||
for c in child:
|
||||
self._dfs(c, tree_list, path_titles)
|
||||
261
api/app/core/rag/nlp/query.py
Normal file
261
api/app/core/rag/nlp/query.py
Normal file
@@ -0,0 +1,261 @@
|
||||
import logging
|
||||
import json
|
||||
import re
|
||||
from collections import defaultdict
|
||||
|
||||
from app.core.rag.utils.doc_store_conn import MatchTextExpr
|
||||
from . import rag_tokenizer, term_weight, synonym
|
||||
|
||||
|
||||
class FulltextQueryer:
|
||||
def __init__(self):
|
||||
self.tw = term_weight.Dealer()
|
||||
self.syn = synonym.Dealer()
|
||||
self.query_fields = [
|
||||
"title_tks^10",
|
||||
"title_sm_tks^5",
|
||||
"important_kwd^30",
|
||||
"important_tks^20",
|
||||
"question_tks^20",
|
||||
"content_ltks^2",
|
||||
"content_sm_ltks",
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def subSpecialChar(line):
|
||||
return re.sub(r"([:\{\}/\[\]\-\*\"\(\)\|\+~\^])", r"\\\1", line).strip()
|
||||
|
||||
@staticmethod
|
||||
def isChinese(line):
|
||||
arr = re.split(r"[ \t]+", line)
|
||||
if len(arr) <= 3:
|
||||
return True
|
||||
e = 0
|
||||
for t in arr:
|
||||
if not re.match(r"[a-zA-Z]+$", t):
|
||||
e += 1
|
||||
return e * 1.0 / len(arr) >= 0.7
|
||||
|
||||
@staticmethod
|
||||
def rmWWW(txt):
|
||||
patts = [
|
||||
(
|
||||
r"是*(怎么办|什么样的|哪家|一下|那家|请问|啥样|咋样了|什么时候|何时|何地|何人|是否|是不是|多少|哪里|怎么|哪儿|怎么样|如何|哪些|是啥|啥是|啊|吗|呢|吧|咋|什么|有没有|呀|谁|哪位|哪个)是*",
|
||||
"",
|
||||
),
|
||||
(r"(^| )(what|who|how|which|where|why)('re|'s)? ", " "),
|
||||
(
|
||||
r"(^| )('s|'re|is|are|were|was|do|does|did|don't|doesn't|didn't|has|have|be|there|you|me|your|my|mine|just|please|may|i|should|would|wouldn't|will|won't|done|go|for|with|so|the|a|an|by|i'm|it's|he's|she's|they|they're|you're|as|by|on|in|at|up|out|down|of|to|or|and|if) ",
|
||||
" ")
|
||||
]
|
||||
otxt = txt
|
||||
for r, p in patts:
|
||||
txt = re.sub(r, p, txt, flags=re.IGNORECASE)
|
||||
if not txt:
|
||||
txt = otxt
|
||||
return txt
|
||||
|
||||
@staticmethod
|
||||
def add_space_between_eng_zh(txt):
|
||||
# (ENG/ENG+NUM) + ZH
|
||||
txt = re.sub(r'([A-Za-z]+[0-9]+)([\u4e00-\u9fa5]+)', r'\1 \2', txt)
|
||||
# ENG + ZH
|
||||
txt = re.sub(r'([A-Za-z])([\u4e00-\u9fa5]+)', r'\1 \2', txt)
|
||||
# ZH + (ENG/ENG+NUM)
|
||||
txt = re.sub(r'([\u4e00-\u9fa5]+)([A-Za-z]+[0-9]+)', r'\1 \2', txt)
|
||||
txt = re.sub(r'([\u4e00-\u9fa5]+)([A-Za-z])', r'\1 \2', txt)
|
||||
return txt
|
||||
|
||||
def question(self, txt, tbl="qa", min_match: float = 0.6):
|
||||
txt = FulltextQueryer.add_space_between_eng_zh(txt)
|
||||
txt = re.sub(
|
||||
r"[ :|\r\n\t,,。??/`!!&^%%()\[\]{}<>]+",
|
||||
" ",
|
||||
rag_tokenizer.tradi2simp(rag_tokenizer.strQ2B(txt.lower())),
|
||||
).strip()
|
||||
otxt = txt
|
||||
txt = FulltextQueryer.rmWWW(txt)
|
||||
|
||||
if not self.isChinese(txt):
|
||||
txt = FulltextQueryer.rmWWW(txt)
|
||||
tks = rag_tokenizer.tokenize(txt).split()
|
||||
keywords = [t for t in tks if t]
|
||||
tks_w = self.tw.weights(tks, preprocess=False)
|
||||
tks_w = [(re.sub(r"[ \\\"'^]", "", tk), w) for tk, w in tks_w]
|
||||
tks_w = [(re.sub(r"^[a-z0-9]$", "", tk), w) for tk, w in tks_w if tk]
|
||||
tks_w = [(re.sub(r"^[\+-]", "", tk), w) for tk, w in tks_w if tk]
|
||||
tks_w = [(tk.strip(), w) for tk, w in tks_w if tk.strip()]
|
||||
syns = []
|
||||
for tk, w in tks_w[:256]:
|
||||
syn = self.syn.lookup(tk)
|
||||
syn = rag_tokenizer.tokenize(" ".join(syn)).split()
|
||||
keywords.extend(syn)
|
||||
syn = ["\"{}\"^{:.4f}".format(s, w / 4.) for s in syn if s.strip()]
|
||||
syns.append(" ".join(syn))
|
||||
|
||||
q = ["({}^{:.4f}".format(tk, w) + " {})".format(syn) for (tk, w), syn in zip(tks_w, syns) if
|
||||
tk and not re.match(r"[.^+\(\)-]", tk)]
|
||||
for i in range(1, len(tks_w)):
|
||||
left, right = tks_w[i - 1][0].strip(), tks_w[i][0].strip()
|
||||
if not left or not right:
|
||||
continue
|
||||
q.append(
|
||||
'"%s %s"^%.4f'
|
||||
% (
|
||||
tks_w[i - 1][0],
|
||||
tks_w[i][0],
|
||||
max(tks_w[i - 1][1], tks_w[i][1]) * 2,
|
||||
)
|
||||
)
|
||||
if not q:
|
||||
q.append(txt)
|
||||
query = " ".join(q)
|
||||
return MatchTextExpr(
|
||||
self.query_fields, query, 100
|
||||
), keywords
|
||||
|
||||
def need_fine_grained_tokenize(tk):
|
||||
if len(tk) < 3:
|
||||
return False
|
||||
if re.match(r"[0-9a-z\.\+#_\*-]+$", tk):
|
||||
return False
|
||||
return True
|
||||
|
||||
txt = FulltextQueryer.rmWWW(txt)
|
||||
qs, keywords = [], []
|
||||
for tt in self.tw.split(txt)[:256]: # .split():
|
||||
if not tt:
|
||||
continue
|
||||
keywords.append(tt)
|
||||
twts = self.tw.weights([tt])
|
||||
syns = self.syn.lookup(tt)
|
||||
if syns and len(keywords) < 32:
|
||||
keywords.extend(syns)
|
||||
logging.debug(json.dumps(twts, ensure_ascii=False))
|
||||
tms = []
|
||||
for tk, w in sorted(twts, key=lambda x: x[1] * -1):
|
||||
sm = (
|
||||
rag_tokenizer.fine_grained_tokenize(tk).split()
|
||||
if need_fine_grained_tokenize(tk)
|
||||
else []
|
||||
)
|
||||
sm = [
|
||||
re.sub(
|
||||
r"[ ,\./;'\[\]\\`~!@#$%\^&\*\(\)=\+_<>\?:\"\{\}\|,。;‘’【】、!¥……()——《》?:“”-]+",
|
||||
"",
|
||||
m,
|
||||
)
|
||||
for m in sm
|
||||
]
|
||||
sm = [FulltextQueryer.subSpecialChar(m) for m in sm if len(m) > 1]
|
||||
sm = [m for m in sm if len(m) > 1]
|
||||
|
||||
if len(keywords) < 32:
|
||||
keywords.append(re.sub(r"[ \\\"']+", "", tk))
|
||||
keywords.extend(sm)
|
||||
|
||||
tk_syns = self.syn.lookup(tk)
|
||||
tk_syns = [FulltextQueryer.subSpecialChar(s) for s in tk_syns]
|
||||
if len(keywords) < 32:
|
||||
keywords.extend([s for s in tk_syns if s])
|
||||
tk_syns = [rag_tokenizer.fine_grained_tokenize(s) for s in tk_syns if s]
|
||||
tk_syns = [f"\"{s}\"" if s.find(" ") > 0 else s for s in tk_syns]
|
||||
|
||||
if len(keywords) >= 32:
|
||||
break
|
||||
|
||||
tk = FulltextQueryer.subSpecialChar(tk)
|
||||
if tk.find(" ") > 0:
|
||||
tk = '"%s"' % tk
|
||||
if tk_syns:
|
||||
tk = f"({tk} OR (%s)^0.2)" % " ".join(tk_syns)
|
||||
if sm:
|
||||
tk = f'{tk} OR "%s" OR ("%s"~2)^0.5' % (" ".join(sm), " ".join(sm))
|
||||
if tk.strip():
|
||||
tms.append((tk, w))
|
||||
|
||||
tms = " ".join([f"({t})^{w}" for t, w in tms])
|
||||
|
||||
if len(twts) > 1:
|
||||
tms += ' ("%s"~2)^1.5' % rag_tokenizer.tokenize(tt)
|
||||
|
||||
syns = " OR ".join(
|
||||
[
|
||||
'"%s"'
|
||||
% rag_tokenizer.tokenize(FulltextQueryer.subSpecialChar(s))
|
||||
for s in syns
|
||||
]
|
||||
)
|
||||
if syns and tms:
|
||||
tms = f"({tms})^5 OR ({syns})^0.7"
|
||||
|
||||
qs.append(tms)
|
||||
|
||||
if qs:
|
||||
query = " OR ".join([f"({t})" for t in qs if t])
|
||||
if not query:
|
||||
query = otxt
|
||||
return MatchTextExpr(
|
||||
self.query_fields, query, 100, {"minimum_should_match": min_match}
|
||||
), keywords
|
||||
return None, keywords
|
||||
|
||||
def hybrid_similarity(self, avec, bvecs, atks, btkss, tkweight=0.3, vtweight=0.7):
|
||||
from sklearn.metrics.pairwise import cosine_similarity as CosineSimilarity
|
||||
import numpy as np
|
||||
|
||||
sims = CosineSimilarity([avec], bvecs)
|
||||
tksim = self.token_similarity(atks, btkss)
|
||||
if np.sum(sims[0]) == 0:
|
||||
return np.array(tksim), tksim, sims[0]
|
||||
return np.array(sims[0]) * vtweight + np.array(tksim) * tkweight, tksim, sims[0]
|
||||
|
||||
def token_similarity(self, atks, btkss):
|
||||
def toDict(tks):
|
||||
if isinstance(tks, str):
|
||||
tks = tks.split()
|
||||
d = defaultdict(int)
|
||||
wts = self.tw.weights(tks, preprocess=False)
|
||||
for i, (t, c) in enumerate(wts):
|
||||
d[t] += c
|
||||
return d
|
||||
|
||||
atks = toDict(atks)
|
||||
btkss = [toDict(tks) for tks in btkss]
|
||||
return [self.similarity(atks, btks) for btks in btkss]
|
||||
|
||||
def similarity(self, qtwt, dtwt):
|
||||
if isinstance(dtwt, type("")):
|
||||
dtwt = {t: w for t, w in self.tw.weights(self.tw.split(dtwt), preprocess=False)}
|
||||
if isinstance(qtwt, type("")):
|
||||
qtwt = {t: w for t, w in self.tw.weights(self.tw.split(qtwt), preprocess=False)}
|
||||
s = 1e-9
|
||||
for k, v in qtwt.items():
|
||||
if k in dtwt:
|
||||
s += v #* dtwt[k]
|
||||
q = 1e-9
|
||||
for k, v in qtwt.items():
|
||||
q += v #* v
|
||||
return s/q #math.sqrt(3. * (s / q / math.log10( len(dtwt.keys()) + 512 )))
|
||||
|
||||
def paragraph(self, content_tks: str, keywords: list = [], keywords_topn=30):
|
||||
if isinstance(content_tks, str):
|
||||
content_tks = [c.strip() for c in content_tks.strip() if c.strip()]
|
||||
tks_w = self.tw.weights(content_tks, preprocess=False)
|
||||
|
||||
keywords = [f'"{k.strip()}"' for k in keywords]
|
||||
for tk, w in sorted(tks_w, key=lambda x: x[1] * -1)[:keywords_topn]:
|
||||
tk_syns = self.syn.lookup(tk)
|
||||
tk_syns = [FulltextQueryer.subSpecialChar(s) for s in tk_syns]
|
||||
tk_syns = [rag_tokenizer.fine_grained_tokenize(s) for s in tk_syns if s]
|
||||
tk_syns = [f"\"{s}\"" if s.find(" ") > 0 else s for s in tk_syns]
|
||||
tk = FulltextQueryer.subSpecialChar(tk)
|
||||
if tk.find(" ") > 0:
|
||||
tk = '"%s"' % tk
|
||||
if tk_syns:
|
||||
tk = f"({tk} OR (%s)^0.2)" % " ".join(tk_syns)
|
||||
if tk:
|
||||
keywords.append(f"{tk}^{w}")
|
||||
|
||||
return MatchTextExpr(self.query_fields, " ".join(keywords), 100,
|
||||
{"minimum_should_match": min(3, len(keywords) // 10)})
|
||||
499
api/app/core/rag/nlp/rag_tokenizer.py
Normal file
499
api/app/core/rag/nlp/rag_tokenizer.py
Normal file
@@ -0,0 +1,499 @@
|
||||
import logging
|
||||
import copy
|
||||
import datrie
|
||||
import math
|
||||
import os
|
||||
import re
|
||||
import string
|
||||
import sys
|
||||
from hanziconv import HanziConv
|
||||
from nltk import word_tokenize
|
||||
from nltk.stem import PorterStemmer, WordNetLemmatizer
|
||||
from app.core.rag.common.file_utils import get_project_base_directory
|
||||
|
||||
|
||||
class RagTokenizer:
|
||||
def key_(self, line):
|
||||
return str(line.lower().encode("utf-8"))[2:-1]
|
||||
|
||||
def rkey_(self, line):
|
||||
return str(("DD" + (line[::-1].lower())).encode("utf-8"))[2:-1]
|
||||
|
||||
def loadDict_(self, fnm):
|
||||
logging.info(f"[HUQIE]:Build trie from {fnm}")
|
||||
try:
|
||||
of = open(fnm, "r", encoding='utf-8')
|
||||
while True:
|
||||
line = of.readline()
|
||||
if not line:
|
||||
break
|
||||
line = re.sub(r"[\r\n]+", "", line)
|
||||
line = re.split(r"[ \t]", line)
|
||||
k = self.key_(line[0])
|
||||
F = int(math.log(float(line[1]) / self.DENOMINATOR) + .5)
|
||||
if k not in self.trie_ or self.trie_[k][0] < F:
|
||||
self.trie_[self.key_(line[0])] = (F, line[2])
|
||||
self.trie_[self.rkey_(line[0])] = 1
|
||||
|
||||
trie_file_name = os.path.join(get_project_base_directory(), "res", "huqie") + ".txt.trie"
|
||||
logging.info(f"[HUQIE]:Build trie cache to {trie_file_name}")
|
||||
self.trie_.save(trie_file_name)
|
||||
of.close()
|
||||
except Exception:
|
||||
logging.exception(f"[HUQIE]:Build trie {fnm} failed")
|
||||
|
||||
def __init__(self, debug=False):
|
||||
self.DEBUG = debug
|
||||
self.DENOMINATOR = 1000000
|
||||
|
||||
self.stemmer = PorterStemmer()
|
||||
self.lemmatizer = WordNetLemmatizer()
|
||||
|
||||
self.SPLIT_CHAR = r"([ ,\.<>/?;:'\[\]\\`!@#$%^&*\(\)\{\}\|_+=《》,。?、;‘’:“”【】~!¥%……()——-]+|[a-zA-Z0-9,\.-]+)"
|
||||
|
||||
trie_file_name = os.path.join(get_project_base_directory(), "res", "huqie") + ".txt.trie"
|
||||
# check if trie file existence
|
||||
if os.path.exists(trie_file_name):
|
||||
try:
|
||||
# load trie from file
|
||||
self.trie_ = datrie.Trie.load(trie_file_name)
|
||||
return
|
||||
except Exception:
|
||||
# fail to load trie from file, build default trie
|
||||
logging.exception(f"[HUQIE]:Fail to load trie file {trie_file_name}, build the default trie file")
|
||||
self.trie_ = datrie.Trie(string.printable)
|
||||
else:
|
||||
# file not exist, build default trie
|
||||
logging.info(f"[HUQIE]:Trie file {trie_file_name} not found, build the default trie file")
|
||||
self.trie_ = datrie.Trie(string.printable)
|
||||
|
||||
# load data from dict file and save to trie file
|
||||
self.loadDict_(os.path.join(get_project_base_directory(), "app/core/rag/res", "huqie") + ".txt")
|
||||
|
||||
def loadUserDict(self, fnm):
|
||||
try:
|
||||
self.trie_ = datrie.Trie.load(fnm + ".trie")
|
||||
return
|
||||
except Exception:
|
||||
self.trie_ = datrie.Trie(string.printable)
|
||||
self.loadDict_(fnm)
|
||||
|
||||
def addUserDict(self, fnm):
|
||||
self.loadDict_(fnm)
|
||||
|
||||
def _strQ2B(self, ustring):
|
||||
"""Convert full-width characters to half-width characters"""
|
||||
rstring = ""
|
||||
for uchar in ustring:
|
||||
inside_code = ord(uchar)
|
||||
if inside_code == 0x3000:
|
||||
inside_code = 0x0020
|
||||
else:
|
||||
inside_code -= 0xfee0
|
||||
if inside_code < 0x0020 or inside_code > 0x7e: # After the conversion, if it's not a half-width character, return the original character.
|
||||
rstring += uchar
|
||||
else:
|
||||
rstring += chr(inside_code)
|
||||
return rstring
|
||||
|
||||
def _tradi2simp(self, line):
|
||||
return HanziConv.toSimplified(line)
|
||||
|
||||
def dfs_(self, chars, s, preTks, tkslist, _depth=0, _memo=None):
|
||||
if _memo is None:
|
||||
_memo = {}
|
||||
MAX_DEPTH = 10
|
||||
if _depth > MAX_DEPTH:
|
||||
if s < len(chars):
|
||||
copy_pretks = copy.deepcopy(preTks)
|
||||
remaining = "".join(chars[s:])
|
||||
copy_pretks.append((remaining, (-12, '')))
|
||||
tkslist.append(copy_pretks)
|
||||
return s
|
||||
|
||||
state_key = (s, tuple(tk[0] for tk in preTks)) if preTks else (s, None)
|
||||
if state_key in _memo:
|
||||
return _memo[state_key]
|
||||
|
||||
res = s
|
||||
if s >= len(chars):
|
||||
tkslist.append(preTks)
|
||||
_memo[state_key] = s
|
||||
return s
|
||||
if s < len(chars) - 4:
|
||||
is_repetitive = True
|
||||
char_to_check = chars[s]
|
||||
for i in range(1, 5):
|
||||
if s + i >= len(chars) or chars[s + i] != char_to_check:
|
||||
is_repetitive = False
|
||||
break
|
||||
if is_repetitive:
|
||||
end = s
|
||||
while end < len(chars) and chars[end] == char_to_check:
|
||||
end += 1
|
||||
mid = s + min(10, end - s)
|
||||
t = "".join(chars[s:mid])
|
||||
k = self.key_(t)
|
||||
copy_pretks = copy.deepcopy(preTks)
|
||||
if k in self.trie_:
|
||||
copy_pretks.append((t, self.trie_[k]))
|
||||
else:
|
||||
copy_pretks.append((t, (-12, '')))
|
||||
next_res = self.dfs_(chars, mid, copy_pretks, tkslist, _depth + 1, _memo)
|
||||
res = max(res, next_res)
|
||||
_memo[state_key] = res
|
||||
return res
|
||||
|
||||
S = s + 1
|
||||
if s + 2 <= len(chars):
|
||||
t1 = "".join(chars[s:s + 1])
|
||||
t2 = "".join(chars[s:s + 2])
|
||||
if self.trie_.has_keys_with_prefix(self.key_(t1)) and not self.trie_.has_keys_with_prefix(self.key_(t2)):
|
||||
S = s + 2
|
||||
if len(preTks) > 2 and len(preTks[-1][0]) == 1 and len(preTks[-2][0]) == 1 and len(preTks[-3][0]) == 1:
|
||||
t1 = preTks[-1][0] + "".join(chars[s:s + 1])
|
||||
if self.trie_.has_keys_with_prefix(self.key_(t1)):
|
||||
S = s + 2
|
||||
|
||||
for e in range(S, len(chars) + 1):
|
||||
t = "".join(chars[s:e])
|
||||
k = self.key_(t)
|
||||
if e > s + 1 and not self.trie_.has_keys_with_prefix(k):
|
||||
break
|
||||
if k in self.trie_:
|
||||
pretks = copy.deepcopy(preTks)
|
||||
pretks.append((t, self.trie_[k]))
|
||||
res = max(res, self.dfs_(chars, e, pretks, tkslist, _depth + 1, _memo))
|
||||
|
||||
if res > s:
|
||||
_memo[state_key] = res
|
||||
return res
|
||||
|
||||
t = "".join(chars[s:s + 1])
|
||||
k = self.key_(t)
|
||||
copy_pretks = copy.deepcopy(preTks)
|
||||
if k in self.trie_:
|
||||
copy_pretks.append((t, self.trie_[k]))
|
||||
else:
|
||||
copy_pretks.append((t, (-12, '')))
|
||||
result = self.dfs_(chars, s + 1, copy_pretks, tkslist, _depth + 1, _memo)
|
||||
_memo[state_key] = result
|
||||
return result
|
||||
|
||||
def freq(self, tk):
|
||||
k = self.key_(tk)
|
||||
if k not in self.trie_:
|
||||
return 0
|
||||
return int(math.exp(self.trie_[k][0]) * self.DENOMINATOR + 0.5)
|
||||
|
||||
def tag(self, tk):
|
||||
k = self.key_(tk)
|
||||
if k not in self.trie_:
|
||||
return ""
|
||||
return self.trie_[k][1]
|
||||
|
||||
def score_(self, tfts):
|
||||
B = 30
|
||||
F, L, tks = 0, 0, []
|
||||
for tk, (freq, tag) in tfts:
|
||||
F += freq
|
||||
L += 0 if len(tk) < 2 else 1
|
||||
tks.append(tk)
|
||||
#F /= len(tks)
|
||||
L /= len(tks)
|
||||
logging.debug("[SC] {} {} {} {} {}".format(tks, len(tks), L, F, B / len(tks) + L + F))
|
||||
return tks, B / len(tks) + L + F
|
||||
|
||||
def sortTks_(self, tkslist):
|
||||
res = []
|
||||
for tfts in tkslist:
|
||||
tks, s = self.score_(tfts)
|
||||
res.append((tks, s))
|
||||
return sorted(res, key=lambda x: x[1], reverse=True)
|
||||
|
||||
def merge_(self, tks):
|
||||
# if split chars is part of token
|
||||
res = []
|
||||
tks = re.sub(r"[ ]+", " ", tks).split()
|
||||
s = 0
|
||||
while True:
|
||||
if s >= len(tks):
|
||||
break
|
||||
E = s + 1
|
||||
for e in range(s + 2, min(len(tks) + 2, s + 6)):
|
||||
tk = "".join(tks[s:e])
|
||||
if re.search(self.SPLIT_CHAR, tk) and self.freq(tk):
|
||||
E = e
|
||||
res.append("".join(tks[s:E]))
|
||||
s = E
|
||||
|
||||
return " ".join(res)
|
||||
|
||||
def maxForward_(self, line):
|
||||
res = []
|
||||
s = 0
|
||||
while s < len(line):
|
||||
e = s + 1
|
||||
t = line[s:e]
|
||||
while e < len(line) and self.trie_.has_keys_with_prefix(
|
||||
self.key_(t)):
|
||||
e += 1
|
||||
t = line[s:e]
|
||||
|
||||
while e - 1 > s and self.key_(t) not in self.trie_:
|
||||
e -= 1
|
||||
t = line[s:e]
|
||||
|
||||
if self.key_(t) in self.trie_:
|
||||
res.append((t, self.trie_[self.key_(t)]))
|
||||
else:
|
||||
res.append((t, (0, '')))
|
||||
|
||||
s = e
|
||||
|
||||
return self.score_(res)
|
||||
|
||||
def maxBackward_(self, line):
|
||||
res = []
|
||||
s = len(line) - 1
|
||||
while s >= 0:
|
||||
e = s + 1
|
||||
t = line[s:e]
|
||||
while s > 0 and self.trie_.has_keys_with_prefix(self.rkey_(t)):
|
||||
s -= 1
|
||||
t = line[s:e]
|
||||
|
||||
while s + 1 < e and self.key_(t) not in self.trie_:
|
||||
s += 1
|
||||
t = line[s:e]
|
||||
|
||||
if self.key_(t) in self.trie_:
|
||||
res.append((t, self.trie_[self.key_(t)]))
|
||||
else:
|
||||
res.append((t, (0, '')))
|
||||
|
||||
s -= 1
|
||||
|
||||
return self.score_(res[::-1])
|
||||
|
||||
def english_normalize_(self, tks):
|
||||
return [self.stemmer.stem(self.lemmatizer.lemmatize(t)) if re.match(r"[a-zA-Z_-]+$", t) else t for t in tks]
|
||||
|
||||
def _split_by_lang(self, line):
|
||||
txt_lang_pairs = []
|
||||
arr = re.split(self.SPLIT_CHAR, line)
|
||||
for a in arr:
|
||||
if not a:
|
||||
continue
|
||||
s = 0
|
||||
e = s + 1
|
||||
zh = is_chinese(a[s])
|
||||
while e < len(a):
|
||||
_zh = is_chinese(a[e])
|
||||
if _zh == zh:
|
||||
e += 1
|
||||
continue
|
||||
txt_lang_pairs.append((a[s: e], zh))
|
||||
s = e
|
||||
e = s + 1
|
||||
zh = _zh
|
||||
if s >= len(a):
|
||||
continue
|
||||
txt_lang_pairs.append((a[s: e], zh))
|
||||
return txt_lang_pairs
|
||||
|
||||
def tokenize(self, line):
|
||||
line = re.sub(r"\W+", " ", line)
|
||||
line = self._strQ2B(line).lower()
|
||||
line = self._tradi2simp(line)
|
||||
|
||||
arr = self._split_by_lang(line)
|
||||
res = []
|
||||
for L,lang in arr:
|
||||
if not lang:
|
||||
res.extend([self.stemmer.stem(self.lemmatizer.lemmatize(t)) for t in word_tokenize(L)])
|
||||
continue
|
||||
if len(L) < 2 or re.match(
|
||||
r"[a-z\.-]+$", L) or re.match(r"[0-9\.-]+$", L):
|
||||
res.append(L)
|
||||
continue
|
||||
|
||||
# use maxforward for the first time
|
||||
tks, s = self.maxForward_(L)
|
||||
tks1, s1 = self.maxBackward_(L)
|
||||
if self.DEBUG:
|
||||
logging.debug("[FW] {} {}".format(tks, s))
|
||||
logging.debug("[BW] {} {}".format(tks1, s1))
|
||||
|
||||
i, j, _i, _j = 0, 0, 0, 0
|
||||
same = 0
|
||||
while i + same < len(tks1) and j + same < len(tks) and tks1[i + same] == tks[j + same]:
|
||||
same += 1
|
||||
if same > 0:
|
||||
res.append(" ".join(tks[j: j + same]))
|
||||
_i = i + same
|
||||
_j = j + same
|
||||
j = _j + 1
|
||||
i = _i + 1
|
||||
|
||||
while i < len(tks1) and j < len(tks):
|
||||
tk1, tk = "".join(tks1[_i:i]), "".join(tks[_j:j])
|
||||
if tk1 != tk:
|
||||
if len(tk1) > len(tk):
|
||||
j += 1
|
||||
else:
|
||||
i += 1
|
||||
continue
|
||||
|
||||
if tks1[i] != tks[j]:
|
||||
i += 1
|
||||
j += 1
|
||||
continue
|
||||
# backward tokens from_i to i are different from forward tokens from _j to j.
|
||||
tkslist = []
|
||||
self.dfs_("".join(tks[_j:j]), 0, [], tkslist)
|
||||
res.append(" ".join(self.sortTks_(tkslist)[0][0]))
|
||||
|
||||
same = 1
|
||||
while i + same < len(tks1) and j + same < len(tks) and tks1[i + same] == tks[j + same]:
|
||||
same += 1
|
||||
res.append(" ".join(tks[j: j + same]))
|
||||
_i = i + same
|
||||
_j = j + same
|
||||
j = _j + 1
|
||||
i = _i + 1
|
||||
|
||||
if _i < len(tks1):
|
||||
assert _j < len(tks)
|
||||
assert "".join(tks1[_i:]) == "".join(tks[_j:])
|
||||
tkslist = []
|
||||
self.dfs_("".join(tks[_j:]), 0, [], tkslist)
|
||||
res.append(" ".join(self.sortTks_(tkslist)[0][0]))
|
||||
|
||||
res = " ".join(res)
|
||||
logging.debug("[TKS] {}".format(self.merge_(res)))
|
||||
return self.merge_(res)
|
||||
|
||||
def fine_grained_tokenize(self, tks):
|
||||
tks = tks.split()
|
||||
zh_num = len([1 for c in tks if c and is_chinese(c[0])])
|
||||
if zh_num < len(tks) * 0.2:
|
||||
res = []
|
||||
for tk in tks:
|
||||
res.extend(tk.split("/"))
|
||||
return " ".join(res)
|
||||
|
||||
res = []
|
||||
for tk in tks:
|
||||
if len(tk) < 3 or re.match(r"[0-9,\.-]+$", tk):
|
||||
res.append(tk)
|
||||
continue
|
||||
tkslist = []
|
||||
if len(tk) > 10:
|
||||
tkslist.append(tk)
|
||||
else:
|
||||
self.dfs_(tk, 0, [], tkslist)
|
||||
if len(tkslist) < 2:
|
||||
res.append(tk)
|
||||
continue
|
||||
stk = self.sortTks_(tkslist)[1][0]
|
||||
if len(stk) == len(tk):
|
||||
stk = tk
|
||||
else:
|
||||
if re.match(r"[a-z\.-]+$", tk):
|
||||
for t in stk:
|
||||
if len(t) < 3:
|
||||
stk = tk
|
||||
break
|
||||
else:
|
||||
stk = " ".join(stk)
|
||||
else:
|
||||
stk = " ".join(stk)
|
||||
|
||||
res.append(stk)
|
||||
|
||||
return " ".join(self.english_normalize_(res))
|
||||
|
||||
|
||||
def is_chinese(s):
|
||||
if s >= u'\u4e00' and s <= u'\u9fa5':
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
def is_number(s):
|
||||
if s >= u'\u0030' and s <= u'\u0039':
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
def is_alphabet(s):
|
||||
if (s >= u'\u0041' and s <= u'\u005a') or (
|
||||
s >= u'\u0061' and s <= u'\u007a'):
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
def naiveQie(txt):
|
||||
tks = []
|
||||
for t in txt.split():
|
||||
if tks and re.match(r".*[a-zA-Z]$", tks[-1]
|
||||
) and re.match(r".*[a-zA-Z]$", t):
|
||||
tks.append(" ")
|
||||
tks.append(t)
|
||||
return tks
|
||||
|
||||
|
||||
tokenizer = RagTokenizer()
|
||||
tokenize = tokenizer.tokenize
|
||||
fine_grained_tokenize = tokenizer.fine_grained_tokenize
|
||||
tag = tokenizer.tag
|
||||
freq = tokenizer.freq
|
||||
loadUserDict = tokenizer.loadUserDict
|
||||
addUserDict = tokenizer.addUserDict
|
||||
tradi2simp = tokenizer._tradi2simp
|
||||
strQ2B = tokenizer._strQ2B
|
||||
|
||||
if __name__ == '__main__':
|
||||
tknzr = RagTokenizer(debug=True)
|
||||
# huqie.addUserDict("/tmp/tmp.new.tks.dict")
|
||||
tks = tknzr.tokenize(
|
||||
"哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈")
|
||||
logging.info(tknzr.fine_grained_tokenize(tks))
|
||||
tks = tknzr.tokenize(
|
||||
"公开征求意见稿提出,境外投资者可使用自有人民币或外汇投资。使用外汇投资的,可通过债券持有人在香港人民币业务清算行及香港地区经批准可进入境内银行间外汇市场进行交易的境外人民币业务参加行(以下统称香港结算行)办理外汇资金兑换。香港结算行由此所产生的头寸可到境内银行间外汇市场平盘。使用外汇投资的,在其投资的债券到期或卖出后,原则上应兑换回外汇。")
|
||||
logging.info(tknzr.fine_grained_tokenize(tks))
|
||||
tks = tknzr.tokenize(
|
||||
"多校划片就是一个小区对应多个小学初中,让买了学区房的家庭也不确定到底能上哪个学校。目的是通过这种方式为学区房降温,把就近入学落到实处。南京市长江大桥")
|
||||
logging.info(tknzr.fine_grained_tokenize(tks))
|
||||
tks = tknzr.tokenize(
|
||||
"实际上当时他们已经将业务中心偏移到安全部门和针对政府企业的部门 Scripts are compiled and cached aaaaaaaaa")
|
||||
logging.info(tknzr.fine_grained_tokenize(tks))
|
||||
tks = tknzr.tokenize("虽然我不怎么玩")
|
||||
logging.info(tknzr.fine_grained_tokenize(tks))
|
||||
tks = tknzr.tokenize("蓝月亮如何在外资夹击中生存,那是全宇宙最有意思的")
|
||||
logging.info(tknzr.fine_grained_tokenize(tks))
|
||||
tks = tknzr.tokenize(
|
||||
"涡轮增压发动机num最大功率,不像别的共享买车锁电子化的手段,我们接过来是否有意义,黄黄爱美食,不过,今天阿奇要讲到的这家农贸市场,说实话,还真蛮有特色的!不仅环境好,还打出了")
|
||||
logging.info(tknzr.fine_grained_tokenize(tks))
|
||||
tks = tknzr.tokenize("这周日你去吗?这周日你有空吗?")
|
||||
logging.info(tknzr.fine_grained_tokenize(tks))
|
||||
tks = tknzr.tokenize("Unity3D开发经验 测试开发工程师 c++双11双11 985 211 ")
|
||||
logging.info(tknzr.fine_grained_tokenize(tks))
|
||||
tks = tknzr.tokenize(
|
||||
"数据分析项目经理|数据分析挖掘|数据分析方向|商品数据分析|搜索数据分析 sql python hive tableau Cocos2d-")
|
||||
logging.info(tknzr.fine_grained_tokenize(tks))
|
||||
if len(sys.argv) < 2:
|
||||
sys.exit()
|
||||
tknzr.DEBUG = False
|
||||
tknzr.loadUserDict(sys.argv[1])
|
||||
of = open(sys.argv[2], "r")
|
||||
while True:
|
||||
line = of.readline()
|
||||
if not line:
|
||||
break
|
||||
logging.info(tknzr.tokenize(line))
|
||||
of.close()
|
||||
192
api/app/core/rag/nlp/search.py
Normal file
192
api/app/core/rag/nlp/search.py
Normal file
@@ -0,0 +1,192 @@
|
||||
import uuid
|
||||
from typing import Dict, List, Any
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from langchain_core.documents import Document
|
||||
from app.db import get_db
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.core.models import RedBearLLM, RedBearRerank
|
||||
from app.models.models_model import ModelApiKey
|
||||
from app.models import knowledge_model
|
||||
from app.core.rag.models.chunk import DocumentChunk
|
||||
from app.repositories import knowledge_repository, knowledgeshare_repository
|
||||
from app.services.model_service import ModelConfigService
|
||||
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
|
||||
|
||||
|
||||
def knowledge_retrieval(
|
||||
query: str,
|
||||
config: Dict[str, Any],
|
||||
user_ids: List[str] = None,
|
||||
) -> list[DocumentChunk]:
|
||||
"""
|
||||
Knowledge retrieval with multiple knowledge bases and reranking
|
||||
|
||||
Args:
|
||||
query: Search query string
|
||||
config: Configuration dictionary containing:
|
||||
- knowledge_bases: List of knowledge base configs with:
|
||||
- kb_id: Knowledge base ID
|
||||
- similarity_threshold: float
|
||||
- vector_similarity_weight: float
|
||||
- top_k: int
|
||||
- retrieve_type: "participle" or "semantic" or "hybrid"
|
||||
- merge_strategy: "weight" or other strategies
|
||||
- reranker_id: UUID of the reranker to use
|
||||
- reranker_top_k: int
|
||||
|
||||
Returns:
|
||||
Rearranged document block list (in descending order of relevance)
|
||||
"""
|
||||
db = next(get_db()) # Manually call the generator
|
||||
try:
|
||||
# parse configuration
|
||||
knowledge_bases = config.get("knowledge_bases", [])
|
||||
merge_strategy = config.get("merge_strategy", "weight")
|
||||
reranker_id = config.get("reranker_id")
|
||||
reranker_top_k = config.get("reranker_top_k", 1024)
|
||||
|
||||
file_names_filter=[]
|
||||
if user_ids:
|
||||
file_names_filter.extend([f"{user_id}.txt" for user_id in user_ids])
|
||||
|
||||
if not knowledge_bases:
|
||||
return []
|
||||
|
||||
all_results = []
|
||||
# Search each knowledge base
|
||||
for kb_config in knowledge_bases:
|
||||
kb_id = kb_config["kb_id"]
|
||||
try:
|
||||
# Check whether the knowledge base exists and is available
|
||||
db_knowledge = knowledge_repository.get_knowledge_by_id(db, knowledge_id=kb_id)
|
||||
if db_knowledge and db_knowledge.chunk_num > 0 and db_knowledge.status == 1:
|
||||
# Process shared knowledge base
|
||||
if db_knowledge.permission_id.lower() == knowledge_model.PermissionType.Share:
|
||||
knowledgeshare = knowledgeshare_repository.get_knowledgeshare_by_id(db=db,
|
||||
knowledgeshare_id=db_knowledge.id)
|
||||
if knowledgeshare:
|
||||
db_knowledge = knowledge_repository.get_knowledge_by_id(db,
|
||||
knowledge_id=knowledgeshare.source_kb_id)
|
||||
if not (db_knowledge and db_knowledge.chunk_num > 0 and db_knowledge.status == 1):
|
||||
continue
|
||||
else:
|
||||
continue
|
||||
|
||||
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
|
||||
# Retrieve according to the configured retrieval type
|
||||
match kb_config["retrieve_type"]:
|
||||
case "participle":
|
||||
rs = vector_service.search_by_full_text(
|
||||
query=query,
|
||||
top_k=kb_config["top_k"],
|
||||
score_threshold=kb_config["similarity_threshold"],
|
||||
file_names_filter=file_names_filter
|
||||
)
|
||||
case "semantic":
|
||||
rs = vector_service.search_by_vector(
|
||||
query=query,
|
||||
top_k=kb_config["top_k"],
|
||||
score_threshold=kb_config["vector_similarity_weight"],
|
||||
file_names_filter=file_names_filter
|
||||
)
|
||||
case _: # hybrid
|
||||
rs1 = vector_service.search_by_vector(
|
||||
query=query,
|
||||
top_k=kb_config["top_k"],
|
||||
score_threshold=kb_config["vector_similarity_weight"],
|
||||
file_names_filter=file_names_filter
|
||||
)
|
||||
rs2 = vector_service.search_by_full_text(
|
||||
query=query,
|
||||
top_k=kb_config["top_k"],
|
||||
score_threshold=kb_config["similarity_threshold"],
|
||||
file_names_filter=file_names_filter
|
||||
)
|
||||
|
||||
# Deduplication of merge results
|
||||
seen_ids = set()
|
||||
unique_rs = []
|
||||
for doc in rs1 + rs2:
|
||||
if doc.metadata["doc_id"] not in seen_ids:
|
||||
seen_ids.add(doc.metadata["doc_id"])
|
||||
unique_rs.append(doc)
|
||||
rs = unique_rs
|
||||
|
||||
all_results.extend(rs)
|
||||
except Exception as e:
|
||||
# Failure of retrieval in a single knowledge base does not affect other knowledge bases
|
||||
print(f"retrieval knowledge({kb_id}) failed: {str(e)}")
|
||||
continue
|
||||
|
||||
# Use the specified reranker for re-ranking
|
||||
if reranker_id:
|
||||
return rerank(db=db, reranker_id=reranker_id, query=query, docs=all_results, top_k=reranker_top_k)
|
||||
return all_results
|
||||
|
||||
except Exception as e:
|
||||
print(f"retrieval knowledge failed: {str(e)}")
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
def rerank(db: Session, reranker_id: uuid, query: str, docs: list[DocumentChunk], top_k: int) -> list[DocumentChunk]:
|
||||
"""
|
||||
Reorder the list of document blocks and return the top_k results most relevant to the query
|
||||
Args:
|
||||
reranker_id: reranker model id
|
||||
query: query string
|
||||
docs: List of document blocks to be rearranged
|
||||
top_k: Number of top-level documents returned
|
||||
|
||||
Returns:
|
||||
Rearranged document block list (in descending order of relevance)
|
||||
|
||||
Raises:
|
||||
ValueError: If the input document list is empty or top_k is invalid
|
||||
"""
|
||||
# 参数校验
|
||||
if not reranker_id:
|
||||
raise ValueError("reranker_id be empty")
|
||||
if not docs:
|
||||
raise ValueError("retrieval chunks be empty")
|
||||
if top_k <= 0:
|
||||
raise ValueError("top_k must be a positive integer")
|
||||
try:
|
||||
# initialize reranker
|
||||
config = ModelConfigService.get_model_by_id(db=db, model_id=reranker_id)
|
||||
apiConfig: ModelApiKey = config.api_keys[0]
|
||||
reranker = RedBearRerank(RedBearModelConfig(
|
||||
model_name=apiConfig.model_name,
|
||||
provider=apiConfig.provider,
|
||||
api_key=apiConfig.api_key,
|
||||
base_url=apiConfig.api_base
|
||||
))
|
||||
# Convert to LangChain Document object
|
||||
documents = [
|
||||
Document(
|
||||
page_content=doc.page_content, # Ensure that DocumentChunk possesses this attribute
|
||||
metadata=doc.metadata or {} # Deal with possible None metadata
|
||||
)
|
||||
for doc in docs
|
||||
]
|
||||
|
||||
# Perform reordering (compress_documents will automatically handle relevance scores and indexing)
|
||||
reranked_docs = list(reranker.compress_documents(documents, query))
|
||||
print(reranked_docs)
|
||||
|
||||
# Sort in descending order based on relevance score
|
||||
reranked_docs.sort(
|
||||
key=lambda x: x.metadata.get("relevance_score", 0),
|
||||
reverse=True
|
||||
)
|
||||
# Convert back to a list of DocumentChunk, and save the relevance_score to metadata["score"]
|
||||
result = []
|
||||
for item in reranked_docs[:top_k]:
|
||||
for doc in docs:
|
||||
if doc.page_content == item.page_content:
|
||||
doc.metadata["score"] = item.metadata["relevance_score"]
|
||||
result.append(doc)
|
||||
return result
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to rerank documents: {str(e)}") from e
|
||||
126
api/app/core/rag/nlp/surname.py
Normal file
126
api/app/core/rag/nlp/surname.py
Normal file
@@ -0,0 +1,126 @@
|
||||
m = set(["赵","钱","孙","李",
|
||||
"周","吴","郑","王",
|
||||
"冯","陈","褚","卫",
|
||||
"蒋","沈","韩","杨",
|
||||
"朱","秦","尤","许",
|
||||
"何","吕","施","张",
|
||||
"孔","曹","严","华",
|
||||
"金","魏","陶","姜",
|
||||
"戚","谢","邹","喻",
|
||||
"柏","水","窦","章",
|
||||
"云","苏","潘","葛",
|
||||
"奚","范","彭","郎",
|
||||
"鲁","韦","昌","马",
|
||||
"苗","凤","花","方",
|
||||
"俞","任","袁","柳",
|
||||
"酆","鲍","史","唐",
|
||||
"费","廉","岑","薛",
|
||||
"雷","贺","倪","汤",
|
||||
"滕","殷","罗","毕",
|
||||
"郝","邬","安","常",
|
||||
"乐","于","时","傅",
|
||||
"皮","卞","齐","康",
|
||||
"伍","余","元","卜",
|
||||
"顾","孟","平","黄",
|
||||
"和","穆","萧","尹",
|
||||
"姚","邵","湛","汪",
|
||||
"祁","毛","禹","狄",
|
||||
"米","贝","明","臧",
|
||||
"计","伏","成","戴",
|
||||
"谈","宋","茅","庞",
|
||||
"熊","纪","舒","屈",
|
||||
"项","祝","董","梁",
|
||||
"杜","阮","蓝","闵",
|
||||
"席","季","麻","强",
|
||||
"贾","路","娄","危",
|
||||
"江","童","颜","郭",
|
||||
"梅","盛","林","刁",
|
||||
"钟","徐","邱","骆",
|
||||
"高","夏","蔡","田",
|
||||
"樊","胡","凌","霍",
|
||||
"虞","万","支","柯",
|
||||
"昝","管","卢","莫",
|
||||
"经","房","裘","缪",
|
||||
"干","解","应","宗",
|
||||
"丁","宣","贲","邓",
|
||||
"郁","单","杭","洪",
|
||||
"包","诸","左","石",
|
||||
"崔","吉","钮","龚",
|
||||
"程","嵇","邢","滑",
|
||||
"裴","陆","荣","翁",
|
||||
"荀","羊","於","惠",
|
||||
"甄","曲","家","封",
|
||||
"芮","羿","储","靳",
|
||||
"汲","邴","糜","松",
|
||||
"井","段","富","巫",
|
||||
"乌","焦","巴","弓",
|
||||
"牧","隗","山","谷",
|
||||
"车","侯","宓","蓬",
|
||||
"全","郗","班","仰",
|
||||
"秋","仲","伊","宫",
|
||||
"宁","仇","栾","暴",
|
||||
"甘","钭","厉","戎",
|
||||
"祖","武","符","刘",
|
||||
"景","詹","束","龙",
|
||||
"叶","幸","司","韶",
|
||||
"郜","黎","蓟","薄",
|
||||
"印","宿","白","怀",
|
||||
"蒲","邰","从","鄂",
|
||||
"索","咸","籍","赖",
|
||||
"卓","蔺","屠","蒙",
|
||||
"池","乔","阴","鬱",
|
||||
"胥","能","苍","双",
|
||||
"闻","莘","党","翟",
|
||||
"谭","贡","劳","逄",
|
||||
"姬","申","扶","堵",
|
||||
"冉","宰","郦","雍",
|
||||
"郤","璩","桑","桂",
|
||||
"濮","牛","寿","通",
|
||||
"边","扈","燕","冀",
|
||||
"郏","浦","尚","农",
|
||||
"温","别","庄","晏",
|
||||
"柴","瞿","阎","充",
|
||||
"慕","连","茹","习",
|
||||
"宦","艾","鱼","容",
|
||||
"向","古","易","慎",
|
||||
"戈","廖","庾","终",
|
||||
"暨","居","衡","步",
|
||||
"都","耿","满","弘",
|
||||
"匡","国","文","寇",
|
||||
"广","禄","阙","东",
|
||||
"欧","殳","沃","利",
|
||||
"蔚","越","夔","隆",
|
||||
"师","巩","厍","聂",
|
||||
"晁","勾","敖","融",
|
||||
"冷","訾","辛","阚",
|
||||
"那","简","饶","空",
|
||||
"曾","母","沙","乜",
|
||||
"养","鞠","须","丰",
|
||||
"巢","关","蒯","相",
|
||||
"查","后","荆","红",
|
||||
"游","竺","权","逯",
|
||||
"盖","益","桓","公",
|
||||
"兰","原","乞","西","阿","肖","丑","位","曽","巨","德","代","圆","尉","仵","纳","仝","脱","丘","但","展","迪","付","覃","晗","特","隋","苑","奥","漆","谌","郄","练","扎","邝","渠","信","门","陳","化","原","密","泮","鹿","赫",
|
||||
"万俟","司马","上官","欧阳",
|
||||
"夏侯","诸葛","闻人","东方",
|
||||
"赫连","皇甫","尉迟","公羊",
|
||||
"澹台","公冶","宗政","濮阳",
|
||||
"淳于","单于","太叔","申屠",
|
||||
"公孙","仲孙","轩辕","令狐",
|
||||
"钟离","宇文","长孙","慕容",
|
||||
"鲜于","闾丘","司徒","司空",
|
||||
"亓官","司寇","仉督","子车",
|
||||
"颛孙","端木","巫马","公西",
|
||||
"漆雕","乐正","壤驷","公良",
|
||||
"拓跋","夹谷","宰父","榖梁",
|
||||
"晋","楚","闫","法","汝","鄢","涂","钦",
|
||||
"段干","百里","东郭","南门",
|
||||
"呼延","归","海","羊舌","微","生",
|
||||
"岳","帅","缑","亢","况","后","有","琴",
|
||||
"梁丘","左丘","东门","西门",
|
||||
"商","牟","佘","佴","伯","赏","南宫",
|
||||
"墨","哈","谯","笪","年","爱","阳","佟",
|
||||
"第五","言","福"])
|
||||
|
||||
def isit(n):return n.strip() in m
|
||||
|
||||
85
api/app/core/rag/nlp/synonym.py
Normal file
85
api/app/core/rag/nlp/synonym.py
Normal file
@@ -0,0 +1,85 @@
|
||||
import logging
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
import re
|
||||
from nltk.corpus import wordnet
|
||||
from app.core.rag.common.file_utils import get_project_base_directory
|
||||
|
||||
|
||||
class Dealer:
|
||||
def __init__(self, redis=None):
|
||||
|
||||
self.lookup_num = 100000000
|
||||
self.load_tm = time.time() - 1000000
|
||||
self.dictionary = None
|
||||
path = os.path.join(get_project_base_directory(), "app/core/rag/res", "synonym.json")
|
||||
try:
|
||||
self.dictionary = json.load(open(path, 'r'))
|
||||
self.dictionary = { (k.lower() if isinstance(k, str) else k): v for k, v in self.dictionary.items() }
|
||||
except Exception:
|
||||
logging.warning("Missing synonym.json")
|
||||
self.dictionary = {}
|
||||
|
||||
if not redis:
|
||||
logging.warning(
|
||||
"Realtime synonym is disabled, since no redis connection.")
|
||||
if not len(self.dictionary.keys()):
|
||||
logging.warning("Fail to load synonym")
|
||||
|
||||
self.redis = redis
|
||||
self.load()
|
||||
|
||||
def load(self):
|
||||
if not self.redis:
|
||||
return
|
||||
|
||||
if self.lookup_num < 100:
|
||||
return
|
||||
tm = time.time()
|
||||
if tm - self.load_tm < 3600:
|
||||
return
|
||||
|
||||
self.load_tm = time.time()
|
||||
self.lookup_num = 0
|
||||
d = self.redis.get("kevin_synonyms")
|
||||
if not d:
|
||||
return
|
||||
try:
|
||||
d = json.loads(d)
|
||||
self.dictionary = d
|
||||
except Exception as e:
|
||||
logging.error("Fail to load synonym!" + str(e))
|
||||
|
||||
|
||||
def lookup(self, tk, topn=8):
|
||||
if not tk or not isinstance(tk, str):
|
||||
return []
|
||||
|
||||
# 1) Check the custom dictionary first (both keys and tk are already lowercase)
|
||||
self.lookup_num += 1
|
||||
self.load()
|
||||
key = re.sub(r"[ \t]+", " ", tk.strip())
|
||||
res = self.dictionary.get(key, [])
|
||||
if isinstance(res, str):
|
||||
res = [res]
|
||||
if res: # Found in dictionary → return directly
|
||||
return res[:topn]
|
||||
|
||||
# 2) If not found and tk is purely alphabetical → fallback to WordNet
|
||||
if re.fullmatch(r"[a-z]+", tk):
|
||||
wn_set = {
|
||||
re.sub("_", " ", syn.name().split(".")[0])
|
||||
for syn in wordnet.synsets(tk)
|
||||
}
|
||||
wn_set.discard(tk) # Remove the original token itself
|
||||
wn_res = [t for t in wn_set if t]
|
||||
return wn_res[:topn]
|
||||
|
||||
# 3) Nothing found in either source
|
||||
return []
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
dl = Dealer()
|
||||
print(dl.dictionary)
|
||||
228
api/app/core/rag/nlp/term_weight.py
Normal file
228
api/app/core/rag/nlp/term_weight.py
Normal file
@@ -0,0 +1,228 @@
|
||||
import logging
|
||||
import math
|
||||
import json
|
||||
import re
|
||||
import os
|
||||
import numpy as np
|
||||
from . import rag_tokenizer
|
||||
from app.core.rag.common.file_utils import get_project_base_directory
|
||||
|
||||
|
||||
class Dealer:
|
||||
def __init__(self):
|
||||
self.stop_words = set(["请问",
|
||||
"您",
|
||||
"你",
|
||||
"我",
|
||||
"他",
|
||||
"是",
|
||||
"的",
|
||||
"就",
|
||||
"有",
|
||||
"于",
|
||||
"及",
|
||||
"即",
|
||||
"在",
|
||||
"为",
|
||||
"最",
|
||||
"有",
|
||||
"从",
|
||||
"以",
|
||||
"了",
|
||||
"将",
|
||||
"与",
|
||||
"吗",
|
||||
"吧",
|
||||
"中",
|
||||
"#",
|
||||
"什么",
|
||||
"怎么",
|
||||
"哪个",
|
||||
"哪些",
|
||||
"啥",
|
||||
"相关"])
|
||||
|
||||
def load_dict(fnm):
|
||||
res = {}
|
||||
f = open(fnm, "r")
|
||||
while True:
|
||||
line = f.readline()
|
||||
if not line:
|
||||
break
|
||||
arr = line.replace("\n", "").split("\t")
|
||||
if len(arr) < 2:
|
||||
res[arr[0]] = 0
|
||||
else:
|
||||
res[arr[0]] = int(arr[1])
|
||||
|
||||
c = 0
|
||||
for _, v in res.items():
|
||||
c += v
|
||||
if c == 0:
|
||||
return set(res.keys())
|
||||
return res
|
||||
|
||||
fnm = os.path.join(get_project_base_directory(), "app/core/rag/res")
|
||||
self.ne, self.df = {}, {}
|
||||
try:
|
||||
self.ne = json.load(open(os.path.join(fnm, "ner.json"), "r"))
|
||||
except Exception:
|
||||
logging.warning("Load ner.json FAIL!")
|
||||
try:
|
||||
self.df = load_dict(os.path.join(fnm, "term.freq"))
|
||||
except Exception:
|
||||
logging.warning("Load term.freq FAIL!")
|
||||
|
||||
def pretoken(self, txt, num=False, stpwd=True):
|
||||
patt = [
|
||||
r"[~—\t @#%!<>,\.\?\":;'\{\}\[\]_=\(\)\|,。?》•●○↓《;‘’:“”【¥ 】…¥!、·()×`&\\/「」\\]"
|
||||
]
|
||||
rewt = [
|
||||
]
|
||||
for p, r in rewt:
|
||||
txt = re.sub(p, r, txt)
|
||||
|
||||
res = []
|
||||
for t in rag_tokenizer.tokenize(txt).split():
|
||||
tk = t
|
||||
if (stpwd and tk in self.stop_words) or (
|
||||
re.match(r"[0-9]$", tk) and not num):
|
||||
continue
|
||||
for p in patt:
|
||||
if re.match(p, t):
|
||||
tk = "#"
|
||||
break
|
||||
#tk = re.sub(r"([\+\\-])", r"\\\1", tk)
|
||||
if tk != "#" and tk:
|
||||
res.append(tk)
|
||||
return res
|
||||
|
||||
def tokenMerge(self, tks):
|
||||
def oneTerm(t): return len(t) == 1 or re.match(r"[0-9a-z]{1,2}$", t)
|
||||
|
||||
res, i = [], 0
|
||||
while i < len(tks):
|
||||
j = i
|
||||
if i == 0 and oneTerm(tks[i]) and len(
|
||||
tks) > 1 and (len(tks[i + 1]) > 1 and not re.match(r"[0-9a-zA-Z]", tks[i + 1])): # 多 工位
|
||||
res.append(" ".join(tks[0:2]))
|
||||
i = 2
|
||||
continue
|
||||
|
||||
while j < len(
|
||||
tks) and tks[j] and tks[j] not in self.stop_words and oneTerm(tks[j]):
|
||||
j += 1
|
||||
if j - i > 1:
|
||||
if j - i < 5:
|
||||
res.append(" ".join(tks[i:j]))
|
||||
i = j
|
||||
else:
|
||||
res.append(" ".join(tks[i:i + 2]))
|
||||
i = i + 2
|
||||
else:
|
||||
if len(tks[i]) > 0:
|
||||
res.append(tks[i])
|
||||
i += 1
|
||||
return [t for t in res if t]
|
||||
|
||||
def ner(self, t):
|
||||
if not self.ne:
|
||||
return ""
|
||||
res = self.ne.get(t, "")
|
||||
if res:
|
||||
return res
|
||||
|
||||
def split(self, txt):
|
||||
tks = []
|
||||
for t in re.sub(r"[ \t]+", " ", txt).split():
|
||||
if tks and re.match(r".*[a-zA-Z]$", tks[-1]) and \
|
||||
re.match(r".*[a-zA-Z]$", t) and tks and \
|
||||
self.ne.get(t, "") != "func" and self.ne.get(tks[-1], "") != "func":
|
||||
tks[-1] = tks[-1] + " " + t
|
||||
else:
|
||||
tks.append(t)
|
||||
return tks
|
||||
|
||||
def weights(self, tks, preprocess=True):
|
||||
num_pattern = re.compile(r"[0-9,.]{2,}$")
|
||||
short_letter_pattern = re.compile(r"[a-z]{1,2}$")
|
||||
num_space_pattern = re.compile(r"[0-9. -]{2,}$")
|
||||
letter_pattern = re.compile(r"[a-z. -]+$")
|
||||
|
||||
def ner(t):
|
||||
if num_pattern.match(t):
|
||||
return 2
|
||||
if short_letter_pattern.match(t):
|
||||
return 0.01
|
||||
if not self.ne or t not in self.ne:
|
||||
return 1
|
||||
m = {"toxic": 2, "func": 1, "corp": 3, "loca": 3, "sch": 3, "stock": 3,
|
||||
"firstnm": 1}
|
||||
return m[self.ne[t]]
|
||||
|
||||
def postag(t):
|
||||
t = rag_tokenizer.tag(t)
|
||||
if t in set(["r", "c", "d"]):
|
||||
return 0.3
|
||||
if t in set(["ns", "nt"]):
|
||||
return 3
|
||||
if t in set(["n"]):
|
||||
return 2
|
||||
if re.match(r"[0-9-]+", t):
|
||||
return 2
|
||||
return 1
|
||||
|
||||
def freq(t):
|
||||
if num_space_pattern.match(t):
|
||||
return 3
|
||||
s = rag_tokenizer.freq(t)
|
||||
if not s and letter_pattern.match(t):
|
||||
return 300
|
||||
if not s:
|
||||
s = 0
|
||||
|
||||
if not s and len(t) >= 4:
|
||||
s = [tt for tt in rag_tokenizer.fine_grained_tokenize(t).split() if len(tt) > 1]
|
||||
if len(s) > 1:
|
||||
s = np.min([freq(tt) for tt in s]) / 6.
|
||||
else:
|
||||
s = 0
|
||||
|
||||
return max(s, 10)
|
||||
|
||||
def df(t):
|
||||
if num_space_pattern.match(t):
|
||||
return 5
|
||||
if t in self.df:
|
||||
return self.df[t] + 3
|
||||
elif letter_pattern.match(t):
|
||||
return 300
|
||||
elif len(t) >= 4:
|
||||
s = [tt for tt in rag_tokenizer.fine_grained_tokenize(t).split() if len(tt) > 1]
|
||||
if len(s) > 1:
|
||||
return max(3, np.min([df(tt) for tt in s]) / 6.)
|
||||
|
||||
return 3
|
||||
|
||||
def idf(s, N): return math.log10(10 + ((N - s + 0.5) / (s + 0.5)))
|
||||
|
||||
tw = []
|
||||
if not preprocess:
|
||||
idf1 = np.array([idf(freq(t), 10000000) for t in tks])
|
||||
idf2 = np.array([idf(df(t), 1000000000) for t in tks])
|
||||
wts = (0.3 * idf1 + 0.7 * idf2) * \
|
||||
np.array([ner(t) * postag(t) for t in tks])
|
||||
wts = [s for s in wts]
|
||||
tw = list(zip(tks, wts))
|
||||
else:
|
||||
for tk in tks:
|
||||
tt = self.tokenMerge(self.pretoken(tk, True))
|
||||
idf1 = np.array([idf(freq(t), 10000000) for t in tt])
|
||||
idf2 = np.array([idf(df(t), 1000000000) for t in tt])
|
||||
wts = (0.3 * idf1 + 0.7 * idf2) * \
|
||||
np.array([ner(t) * postag(t) for t in tt])
|
||||
wts = [s for s in wts]
|
||||
tw.extend(zip(tt, wts))
|
||||
|
||||
S = np.sum([s for _, s in tw])
|
||||
return [(t, s / S) for t, s in tw]
|
||||
6
api/app/core/rag/prompts/__init__.py
Normal file
6
api/app/core/rag/prompts/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from . import generator
|
||||
|
||||
__all__ = [name for name in dir(generator)
|
||||
if not name.startswith('_')]
|
||||
|
||||
globals().update({name: getattr(generator, name) for name in __all__})
|
||||
48
api/app/core/rag/prompts/analyze_task_system.md
Normal file
48
api/app/core/rag/prompts/analyze_task_system.md
Normal file
@@ -0,0 +1,48 @@
|
||||
You are an intelligent task analyzer that adapts analysis depth to task complexity.
|
||||
|
||||
**Analysis Framework**
|
||||
|
||||
**Step 1: Task Transmission Assessment**
|
||||
**Note**: This section is not subject to word count limitations when transmission is needed, as it serves critical handoff functions.
|
||||
|
||||
**Evaluate if task transmission information is needed:**
|
||||
- **Is this an initial step?** If yes, skip this section
|
||||
- **Are there upstream agents/steps?** If no, provide minimal transmission
|
||||
- **Is there critical state/context to preserve?** If yes, include full transmission
|
||||
|
||||
### If Task Transmission is Needed:
|
||||
- **Current State Summary**: [1-2 sentences on where we are]
|
||||
- **Key Data/Results**: [Critical findings that must carry forward]
|
||||
- **Context Dependencies**: [Essential context for next agent/step]
|
||||
- **Unresolved Items**: [Issues requiring continuation]
|
||||
- **Status for User**: [Clear status update in user terms]
|
||||
- **Technical State**: [System state for technical handoffs]
|
||||
|
||||
**Step 2: Complexity Classification**
|
||||
Classify as LOW / MEDIUM / HIGH:
|
||||
- **LOW**: Single-step tasks, direct queries, small talk
|
||||
- **MEDIUM**: Multi-step tasks within one domain
|
||||
- **HIGH**: Multi-domain coordination or complex reasoning
|
||||
|
||||
**Step 3: Adaptive Analysis**
|
||||
Scale depth to match complexity. Always stop once success criteria are met.
|
||||
|
||||
**For LOW (max 50 words for analysis only):**
|
||||
- Detect small talk; if true, output exactly: `Small talk — no further analysis needed`
|
||||
- One-sentence objective
|
||||
- Direct execution approach (1–2 steps)
|
||||
|
||||
**For MEDIUM (80–150 words for analysis only):**
|
||||
- Objective; Intent & Scope
|
||||
- 3–5 step minimal Plan (may mark parallel steps)
|
||||
- **Uncertainty & Probes** (at least one probe with a clear stop condition)
|
||||
- Success Criteria + basic Failure detection & fallback
|
||||
- **Source Plan** (how evidence will be obtained/verified)
|
||||
|
||||
**For HIGH (150–250 words for analysis only):**
|
||||
- Comprehensive objective analysis; Intent & Scope
|
||||
- 5–8 step Plan with dependencies/parallelism
|
||||
- **Uncertainty & Probes** (key unknowns → probe → stop condition)
|
||||
- Measurable Success Criteria; Failure detectors & fallbacks
|
||||
- **Source Plan** (evidence acquisition & validation)
|
||||
- **Reflection Hooks** (escalation/de-escalation triggers)
|
||||
9
api/app/core/rag/prompts/analyze_task_user.md
Normal file
9
api/app/core/rag/prompts/analyze_task_user.md
Normal file
@@ -0,0 +1,9 @@
|
||||
**Input Variables**
|
||||
- **{{ task }}** — the task/request to analyze
|
||||
- **{{ context }}** — background, history, situational context
|
||||
- **{{ agent_prompt }}** — special instructions/role hints
|
||||
- **{{ tools_desc }}** — available sub-agents and capabilities
|
||||
|
||||
**Final Output Rule**
|
||||
Return the Task Transmission section (if needed) followed by the concrete analysis and planning steps according to LOW / MEDIUM / HIGH complexity.
|
||||
Do not restate the framework, definitions, or rules. Output only the final structured result.
|
||||
14
api/app/core/rag/prompts/ask_summary.md
Normal file
14
api/app/core/rag/prompts/ask_summary.md
Normal file
@@ -0,0 +1,14 @@
|
||||
Role: You're a smart assistant. Your name is Miss R.
|
||||
Task: Summarize the information from knowledge bases and answer user's question.
|
||||
Requirements and restriction:
|
||||
- DO NOT make things up, especially for numbers.
|
||||
- If the information from knowledge is irrelevant with user's question, JUST SAY: Sorry, no relevant information provided.
|
||||
- Answer with markdown format text.
|
||||
- Answer in language of user's question.
|
||||
- DO NOT make things up, especially for numbers.
|
||||
|
||||
### Information from knowledge bases
|
||||
|
||||
{{ knowledge }}
|
||||
|
||||
The above is information from knowledge bases.
|
||||
53
api/app/core/rag/prompts/assign_toc_levels.md
Normal file
53
api/app/core/rag/prompts/assign_toc_levels.md
Normal file
@@ -0,0 +1,53 @@
|
||||
You are given a JSON array of TOC(tabel of content) items. Each item has at least {"title": string} and may include an existing title hierarchical level.
|
||||
|
||||
Task
|
||||
- For each item, assign a depth label using Arabic numerals only: top-level = 1, second-level = 2, third-level = 3, etc.
|
||||
- Multiple items may share the same depth (e.g., many 1s, many 2s).
|
||||
- Do not use dotted numbering (no 1.1/1.2). Use a single digit string per item indicating its depth only.
|
||||
- Preserve the original item order exactly. Do not insert, delete, or reorder.
|
||||
- Decide levels yourself to keep a coherent hierarchy. Keep peers at the same depth.
|
||||
|
||||
Output
|
||||
- Return a valid JSON array only (no extra text).
|
||||
- Each element must be {"level": "1|2|3", "title": <original title string>}.
|
||||
- title must be the original title string.
|
||||
|
||||
Examples
|
||||
|
||||
Example A (chapters with sections)
|
||||
Input:
|
||||
["Chapter 1 Methods", "Section 1 Definition", "Section 2 Process", "Chapter 2 Experiment"]
|
||||
|
||||
Output:
|
||||
[
|
||||
{"level":"1","title":"Chapter 1 Methods"},
|
||||
{"level":"2","title":"Section 1 Definition"},
|
||||
{"level":"2","title":"Section 2 Process"},
|
||||
{"level":"1","title":"Chapter 2 Experiment"}
|
||||
]
|
||||
|
||||
Example B (parts with chapters)
|
||||
Input:
|
||||
["Part I Theory", "Chapter 1 Basics", "Chapter 2 Methods", "Part II Applications", "Chapter 3 Case Studies"]
|
||||
|
||||
Output:
|
||||
[
|
||||
{"level":"1","title":"Part I Theory"},
|
||||
{"level":"2","title":"Chapter 1 Basics"},
|
||||
{"level":"2","title":"Chapter 2 Methods"},
|
||||
{"level":"1","title":"Part II Applications"},
|
||||
{"level":"2","title":"Chapter 3 Case Studies"}
|
||||
]
|
||||
|
||||
Example C (plain headings)
|
||||
Input:
|
||||
["Introduction", "Background and Motivation", "Related Work", "Methodology", "Evaluation"]
|
||||
|
||||
Output:
|
||||
[
|
||||
{"level":"1","title":"Introduction"},
|
||||
{"level":"2","title":"Background and Motivation"},
|
||||
{"level":"2","title":"Related Work"},
|
||||
{"level":"1","title":"Methodology"},
|
||||
{"level":"1","title":"Evaluation"}
|
||||
]
|
||||
13
api/app/core/rag/prompts/citation_plus.md
Normal file
13
api/app/core/rag/prompts/citation_plus.md
Normal file
@@ -0,0 +1,13 @@
|
||||
You are an agent for adding correct citations to the given text by user.
|
||||
You are given a piece of text within [ID:<ID>] tags, which was generated based on the provided sources.
|
||||
However, the sources are not cited in the [ID:<ID>].
|
||||
Your task is to enhance user trust by generating correct, appropriate citations for this report.
|
||||
|
||||
{{ example }}
|
||||
|
||||
<context>
|
||||
|
||||
{{ sources }}
|
||||
|
||||
</context>
|
||||
|
||||
109
api/app/core/rag/prompts/citation_prompt.md
Normal file
109
api/app/core/rag/prompts/citation_prompt.md
Normal file
@@ -0,0 +1,109 @@
|
||||
Based on the provided document or chat history, add citations to the input text using the format specified later.
|
||||
|
||||
# Citation Requirements:
|
||||
|
||||
## Technical Rules:
|
||||
- Use format: [ID:i] or [ID:i] [ID:j] for multiple sources
|
||||
- Place citations at the end of sentences, before punctuation
|
||||
- Maximum 4 citations per sentence
|
||||
- DO NOT cite content not from <context></context>
|
||||
- DO NOT modify whitespace or original text
|
||||
- STRICTLY prohibit non-standard formatting (~~, etc.)
|
||||
|
||||
## What MUST Be Cited:
|
||||
1. **Quantitative data**: Numbers, percentages, statistics, measurements
|
||||
2. **Temporal claims**: Dates, timeframes, sequences of events
|
||||
3. **Causal relationships**: Claims about cause and effect
|
||||
4. **Comparative statements**: Rankings, comparisons, superlatives
|
||||
5. **Technical definitions**: Specialized terms, concepts, methodologies
|
||||
6. **Direct attributions**: What someone said, did, or believes
|
||||
7. **Predictions/forecasts**: Future projections, trend analyses
|
||||
8. **Controversial claims**: Disputed facts, minority opinions
|
||||
|
||||
## What Should NOT Be Cited:
|
||||
- Common knowledge (e.g., "The sun rises in the east")
|
||||
- Transitional phrases
|
||||
- General introductions
|
||||
- Your own analysis or synthesis (unless directly from source)
|
||||
|
||||
# Comprehensive Examples:
|
||||
|
||||
## Example 1: Data and Statistics
|
||||
<context>
|
||||
ID: 45
|
||||
└── Content: The global smartphone market grew by 7.8% in Q3 2024, with Samsung holding 20.6% market share and Apple at 15.8%.
|
||||
|
||||
ID: 46
|
||||
└── Content: 5G adoption reached 1.5 billion users worldwide by October 2024.
|
||||
</context>
|
||||
|
||||
USER: How is the smartphone market performing?
|
||||
|
||||
ASSISTANT:
|
||||
The smartphone industry is showing strong recovery. The global smartphone market grew by 7.8% in Q3 2024 [ID:45]. Samsung leads with 20.6% market share, followed by Apple at 15.8% [ID:45]. This growth is partly driven by 5G technology, which reached 1.5 billion users worldwide by October 2024 [ID:46].
|
||||
|
||||
## Example 2: Causal Relationships and Predictions
|
||||
<context>
|
||||
ID: 78
|
||||
└── Content: Rising interest rates have led to a 15% decline in housing starts. Economists predict this trend will continue through 2025.
|
||||
|
||||
ID: 79
|
||||
└── Content: The Federal Reserve's aggressive monetary policy is the primary driver behind the real estate slowdown.
|
||||
</context>
|
||||
|
||||
USER: What's happening in the real estate market?
|
||||
|
||||
ASSISTANT:
|
||||
The real estate market is experiencing significant challenges. Rising interest rates have led to a 15% decline in housing starts [ID:78]. The Federal Reserve's aggressive monetary policy is the primary driver behind this real estate slowdown [ID:79]. Economists predict this trend will continue through 2025 [ID:78]. These factors suggest a prolonged period of market adjustment ahead.
|
||||
|
||||
## Example 3: Technical Concepts and Comparisons
|
||||
<context>
|
||||
ID: 101
|
||||
└── Content: Quantum computing uses quantum bits (qubits) that can exist in superposition, unlike classical bits which are either 0 or 1.
|
||||
|
||||
ID: 102
|
||||
└── Content: Current quantum computers can solve certain problems 158 million times faster than classical supercomputers.
|
||||
</context>
|
||||
|
||||
USER: Explain quantum computing advantages.
|
||||
|
||||
ASSISTANT:
|
||||
Quantum computing represents a fundamental shift in computational capability. Unlike classical computers, quantum computing uses quantum bits (qubits) that can exist in superposition, rather than being limited to just 0 or 1 [ID:101]. This unique property enables extraordinary performance gains. Current quantum computers can solve certain problems 158 million times faster than classical supercomputers [ID:102]. This technology promises to revolutionize fields requiring complex calculations.
|
||||
|
||||
## Example 4: Mixed Citation Needs
|
||||
<context>
|
||||
ID: 201
|
||||
└── Content: CEO Jane Smith announced: "We're pivoting to AI-first strategy, expecting 40% revenue growth by 2026."
|
||||
|
||||
ID: 202
|
||||
└── Content: Industry analysts remain skeptical, with Morgan Stanley rating the stock as "underperform."
|
||||
|
||||
ID: 203
|
||||
└── Content: The company's AI division currently represents only 5% of total revenue.
|
||||
</context>
|
||||
|
||||
USER: What's the company's new strategy?
|
||||
|
||||
ASSISTANT:
|
||||
The company is undergoing a major strategic transformation. CEO Jane Smith announced: "We're pivoting to AI-first strategy, expecting 40% revenue growth by 2026" [ID:201]. However, this ambitious goal faces skepticism, as industry analysts remain doubtful, with Morgan Stanley rating the stock as "underperform" [ID:202]. The challenge is significant given that the company's AI division currently represents only 5% of total revenue [ID:203]. Success will depend on execution of this bold vision.
|
||||
|
||||
## Example 5: When NOT to Cite
|
||||
<context>
|
||||
ID: 301
|
||||
└── Content: Paris is the capital of France. The Eiffel Tower was completed in 1889.
|
||||
</context>
|
||||
|
||||
USER: Tell me about Paris.
|
||||
|
||||
ASSISTANT:
|
||||
Paris is the capital of France. It's known for its rich history, culture, and architecture. The Eiffel Tower was completed in 1889 [ID:301]. The city attracts millions of tourists annually. Paris remains one of the world's most visited destinations.
|
||||
(Note: Only the specific date needs citation, not common knowledge about Paris)
|
||||
|
||||
--- Examples END ---
|
||||
|
||||
REMEMBER:
|
||||
- Cite FACTS, not opinions or transitions
|
||||
- Each citation supports the ENTIRE sentence
|
||||
- When in doubt, ask: "Would a fact-checker need to verify this?"
|
||||
- Place citations at sentence end, before punctuation
|
||||
- Format likes this is FORBIDDEN: [ID:0, ID:5, ID:...]. It MUST be seperated like, [ID:0][ID:5]...
|
||||
32
api/app/core/rag/prompts/content_tagging_prompt.md
Normal file
32
api/app/core/rag/prompts/content_tagging_prompt.md
Normal file
@@ -0,0 +1,32 @@
|
||||
## Role
|
||||
You are a text analyzer.
|
||||
|
||||
## Task
|
||||
Add tags (labels) to a given piece of text content based on the examples and the entire tag set.
|
||||
|
||||
## Steps
|
||||
- Review the tag/label set.
|
||||
- Review examples which all consist of both text content and assigned tags with relevance score in JSON format.
|
||||
- Summarize the text content, and tag it with the top {{ topn }} most relevant tags from the set of tags/labels and the corresponding relevance score.
|
||||
|
||||
## Requirements
|
||||
- The tags MUST be from the tag set.
|
||||
- The output MUST be in JSON format only, the key is tag and the value is its relevance score.
|
||||
- The relevance score must range from 1 to 10.
|
||||
- Output keywords ONLY.
|
||||
|
||||
# TAG SET
|
||||
{{ all_tags | join(', ') }}
|
||||
|
||||
{% for ex in examples %}
|
||||
# Examples {{ loop.index0 }}
|
||||
### Text Content
|
||||
{{ ex.content }}
|
||||
|
||||
Output:
|
||||
{{ ex.tags_json }}
|
||||
|
||||
{% endfor %}
|
||||
# Real Data
|
||||
### Text Content
|
||||
{{ content }}
|
||||
35
api/app/core/rag/prompts/cross_languages_sys_prompt.md
Normal file
35
api/app/core/rag/prompts/cross_languages_sys_prompt.md
Normal file
@@ -0,0 +1,35 @@
|
||||
## Role
|
||||
A streamlined multilingual translator.
|
||||
|
||||
## Behavior Rules
|
||||
1. Accept batch translation requests in the following format:
|
||||
**Input:** `[text]`
|
||||
**Target Languages:** comma-separated list
|
||||
|
||||
2. Maintain:
|
||||
- Original formatting (tables, lists, spacing)
|
||||
- Technical terminology accuracy
|
||||
- Cultural context appropriateness
|
||||
|
||||
3. Output translations in the following format:
|
||||
|
||||
[Translation in language1]
|
||||
###
|
||||
[Translation in language2]
|
||||
|
||||
---
|
||||
|
||||
## Example
|
||||
|
||||
**Input:**
|
||||
Hello World! Let's discuss AI safety.
|
||||
===
|
||||
Chinese, French, Japanese
|
||||
|
||||
**Output:**
|
||||
你好世界!让我们讨论人工智能安全问题。
|
||||
###
|
||||
Bonjour le monde ! Parlons de la sécurité de l'IA.
|
||||
###
|
||||
こんにちは世界!AIの安全性について話し合いましょう。
|
||||
|
||||
7
api/app/core/rag/prompts/cross_languages_user_prompt.md
Normal file
7
api/app/core/rag/prompts/cross_languages_user_prompt.md
Normal file
@@ -0,0 +1,7 @@
|
||||
**Input:**
|
||||
{{ query }}
|
||||
===
|
||||
{{ languages | join(', ') }}
|
||||
|
||||
**Output:**
|
||||
|
||||
62
api/app/core/rag/prompts/full_question_prompt.md
Normal file
62
api/app/core/rag/prompts/full_question_prompt.md
Normal file
@@ -0,0 +1,62 @@
|
||||
## Role
|
||||
A helpful assistant.
|
||||
|
||||
## Task & Steps
|
||||
1. Generate a full user question that would follow the conversation.
|
||||
2. If the user's question involves relative dates, convert them into absolute dates based on today ({{ today }}).
|
||||
- "yesterday" = {{ yesterday }}, "tomorrow" = {{ tomorrow }}
|
||||
|
||||
## Requirements & Restrictions
|
||||
- If the user's latest question is already complete, don't do anything — just return the original question.
|
||||
- DON'T generate anything except a refined question.
|
||||
{% if language %}
|
||||
- Text generated MUST be in {{ language }}.
|
||||
{% else %}
|
||||
- Text generated MUST be in the same language as the original user's question.
|
||||
{% endif %}
|
||||
|
||||
---
|
||||
|
||||
## Examples
|
||||
|
||||
### Example 1
|
||||
**Conversation:**
|
||||
|
||||
USER: What is the name of Donald Trump's father?
|
||||
ASSISTANT: Fred Trump.
|
||||
USER: And his mother?
|
||||
|
||||
**Output:** What's the name of Donald Trump's mother?
|
||||
|
||||
---
|
||||
|
||||
### Example 2
|
||||
**Conversation:**
|
||||
|
||||
USER: What is the name of Donald Trump's father?
|
||||
ASSISTANT: Fred Trump.
|
||||
USER: And his mother?
|
||||
ASSISTANT: Mary Trump.
|
||||
USER: What's her full name?
|
||||
|
||||
**Output:** What's the full name of Donald Trump's mother Mary Trump?
|
||||
|
||||
---
|
||||
|
||||
### Example 3
|
||||
**Conversation:**
|
||||
|
||||
USER: What's the weather today in London?
|
||||
ASSISTANT: Cloudy.
|
||||
USER: What's about tomorrow in Rochester?
|
||||
|
||||
**Output:** What's the weather in Rochester on {{ tomorrow }}?
|
||||
|
||||
---
|
||||
|
||||
## Real Data
|
||||
|
||||
**Conversation:**
|
||||
|
||||
{{ conversation }}
|
||||
|
||||
728
api/app/core/rag/prompts/generator.py
Normal file
728
api/app/core/rag/prompts/generator.py
Normal file
@@ -0,0 +1,728 @@
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from copy import deepcopy
|
||||
from typing import Tuple
|
||||
import jinja2
|
||||
import json_repair
|
||||
import trio
|
||||
from app.core.rag.common.misc_utils import hash_str2int
|
||||
from app.core.rag.nlp import rag_tokenizer
|
||||
from .template import load_prompt
|
||||
from app.core.rag.common.constants import TAG_FLD
|
||||
from app.core.rag.common.token_utils import encoder, num_tokens_from_string
|
||||
|
||||
|
||||
STOP_TOKEN="<|STOP|>"
|
||||
COMPLETE_TASK="complete_task"
|
||||
INPUT_UTILIZATION = 0.5
|
||||
|
||||
def get_value(d, k1, k2):
|
||||
return d.get(k1, d.get(k2))
|
||||
|
||||
|
||||
def chunks_format(reference):
|
||||
|
||||
return [
|
||||
{
|
||||
"id": get_value(chunk, "chunk_id", "id"),
|
||||
"content": get_value(chunk, "content", "content_with_weight"),
|
||||
"document_id": get_value(chunk, "doc_id", "document_id"),
|
||||
"document_name": get_value(chunk, "docnm_kwd", "document_name"),
|
||||
"dataset_id": get_value(chunk, "kb_id", "dataset_id"),
|
||||
"image_id": get_value(chunk, "image_id", "img_id"),
|
||||
"positions": get_value(chunk, "positions", "position_int"),
|
||||
"url": chunk.get("url"),
|
||||
"similarity": chunk.get("similarity"),
|
||||
"vector_similarity": chunk.get("vector_similarity"),
|
||||
"term_similarity": chunk.get("term_similarity"),
|
||||
"doc_type": chunk.get("doc_type_kwd"),
|
||||
}
|
||||
for chunk in reference.get("chunks", [])
|
||||
]
|
||||
|
||||
|
||||
def message_fit_in(msg, max_length=4000):
|
||||
def count():
|
||||
nonlocal msg
|
||||
tks_cnts = []
|
||||
for m in msg:
|
||||
tks_cnts.append({"role": m["role"], "count": num_tokens_from_string(m["content"])})
|
||||
total = 0
|
||||
for m in tks_cnts:
|
||||
total += m["count"]
|
||||
return total
|
||||
|
||||
c = count()
|
||||
if c < max_length:
|
||||
return c, msg
|
||||
|
||||
msg_ = [m for m in msg if m["role"] == "system"]
|
||||
if len(msg) > 1:
|
||||
msg_.append(msg[-1])
|
||||
msg = msg_
|
||||
c = count()
|
||||
if c < max_length:
|
||||
return c, msg
|
||||
|
||||
ll = num_tokens_from_string(msg_[0]["content"])
|
||||
ll2 = num_tokens_from_string(msg_[-1]["content"])
|
||||
if ll / (ll + ll2) > 0.8:
|
||||
m = msg_[0]["content"]
|
||||
m = encoder.decode(encoder.encode(m)[: max_length - ll2])
|
||||
msg[0]["content"] = m
|
||||
return max_length, msg
|
||||
|
||||
m = msg_[-1]["content"]
|
||||
m = encoder.decode(encoder.encode(m)[: max_length - ll2])
|
||||
msg[-1]["content"] = m
|
||||
return max_length, msg
|
||||
|
||||
|
||||
CITATION_PROMPT_TEMPLATE = load_prompt("citation_prompt")
|
||||
CITATION_PLUS_TEMPLATE = load_prompt("citation_plus")
|
||||
CONTENT_TAGGING_PROMPT_TEMPLATE = load_prompt("content_tagging_prompt")
|
||||
CROSS_LANGUAGES_SYS_PROMPT_TEMPLATE = load_prompt("cross_languages_sys_prompt")
|
||||
CROSS_LANGUAGES_USER_PROMPT_TEMPLATE = load_prompt("cross_languages_user_prompt")
|
||||
FULL_QUESTION_PROMPT_TEMPLATE = load_prompt("full_question_prompt")
|
||||
KEYWORD_PROMPT_TEMPLATE = load_prompt("keyword_prompt")
|
||||
QUESTION_PROMPT_TEMPLATE = load_prompt("question_prompt")
|
||||
VISION_LLM_DESCRIBE_PROMPT = load_prompt("vision_llm_describe_prompt")
|
||||
VISION_LLM_FIGURE_DESCRIBE_PROMPT = load_prompt("vision_llm_figure_describe_prompt")
|
||||
STRUCTURED_OUTPUT_PROMPT = load_prompt("structured_output_prompt")
|
||||
|
||||
ANALYZE_TASK_SYSTEM = load_prompt("analyze_task_system")
|
||||
ANALYZE_TASK_USER = load_prompt("analyze_task_user")
|
||||
NEXT_STEP = load_prompt("next_step")
|
||||
REFLECT = load_prompt("reflect")
|
||||
SUMMARY4MEMORY = load_prompt("summary4memory")
|
||||
RANK_MEMORY = load_prompt("rank_memory")
|
||||
META_FILTER = load_prompt("meta_filter")
|
||||
ASK_SUMMARY = load_prompt("ask_summary")
|
||||
|
||||
PROMPT_JINJA_ENV = jinja2.Environment(autoescape=False, trim_blocks=True, lstrip_blocks=True)
|
||||
|
||||
|
||||
def citation_prompt(user_defined_prompts: dict={}) -> str:
|
||||
template = PROMPT_JINJA_ENV.from_string(user_defined_prompts.get("citation_guidelines", CITATION_PROMPT_TEMPLATE))
|
||||
return template.render()
|
||||
|
||||
|
||||
def citation_plus(sources: str) -> str:
|
||||
template = PROMPT_JINJA_ENV.from_string(CITATION_PLUS_TEMPLATE)
|
||||
return template.render(example=citation_prompt(), sources=sources)
|
||||
|
||||
|
||||
def keyword_extraction(chat_mdl, content, topn=3):
|
||||
template = PROMPT_JINJA_ENV.from_string(KEYWORD_PROMPT_TEMPLATE)
|
||||
rendered_prompt = template.render(content=content, topn=topn)
|
||||
|
||||
msg = [{"role": "system", "content": rendered_prompt}, {"role": "user", "content": "Output: "}]
|
||||
_, msg = message_fit_in(msg, chat_mdl.max_length)
|
||||
kwd = chat_mdl.chat(rendered_prompt, msg[1:], {"temperature": 0.2})
|
||||
if isinstance(kwd, tuple):
|
||||
kwd = kwd[0]
|
||||
kwd = re.sub(r"^.*</think>", "", kwd, flags=re.DOTALL)
|
||||
if kwd.find("**ERROR**") >= 0:
|
||||
return ""
|
||||
return kwd
|
||||
|
||||
|
||||
def question_proposal(chat_mdl, content, topn=3):
|
||||
template = PROMPT_JINJA_ENV.from_string(QUESTION_PROMPT_TEMPLATE)
|
||||
rendered_prompt = template.render(content=content, topn=topn)
|
||||
|
||||
msg = [{"role": "system", "content": rendered_prompt}, {"role": "user", "content": "Output: "}]
|
||||
_, msg = message_fit_in(msg, chat_mdl.max_length)
|
||||
kwd = chat_mdl.chat(rendered_prompt, msg[1:], {"temperature": 0.2})
|
||||
if isinstance(kwd, tuple):
|
||||
kwd = kwd[0]
|
||||
kwd = re.sub(r"^.*</think>", "", kwd, flags=re.DOTALL)
|
||||
if kwd.find("**ERROR**") >= 0:
|
||||
return ""
|
||||
return kwd
|
||||
|
||||
|
||||
def full_question(messages=[], language=None, chat_mdl=None):
|
||||
conv = []
|
||||
for m in messages:
|
||||
if m["role"] not in ["user", "assistant"]:
|
||||
continue
|
||||
conv.append("{}: {}".format(m["role"].upper(), m["content"]))
|
||||
conversation = "\n".join(conv)
|
||||
today = datetime.date.today().isoformat()
|
||||
yesterday = (datetime.date.today() - datetime.timedelta(days=1)).isoformat()
|
||||
tomorrow = (datetime.date.today() + datetime.timedelta(days=1)).isoformat()
|
||||
|
||||
template = PROMPT_JINJA_ENV.from_string(FULL_QUESTION_PROMPT_TEMPLATE)
|
||||
rendered_prompt = template.render(
|
||||
today=today,
|
||||
yesterday=yesterday,
|
||||
tomorrow=tomorrow,
|
||||
conversation=conversation,
|
||||
language=language,
|
||||
)
|
||||
|
||||
ans = chat_mdl.chat(rendered_prompt, [{"role": "user", "content": "Output: "}])
|
||||
ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
|
||||
return ans if ans.find("**ERROR**") < 0 else messages[-1]["content"]
|
||||
|
||||
|
||||
def cross_languages(query, languages=[], chat_mdl=None):
|
||||
rendered_sys_prompt = PROMPT_JINJA_ENV.from_string(CROSS_LANGUAGES_SYS_PROMPT_TEMPLATE).render()
|
||||
rendered_user_prompt = PROMPT_JINJA_ENV.from_string(CROSS_LANGUAGES_USER_PROMPT_TEMPLATE).render(query=query, languages=languages)
|
||||
|
||||
ans = chat_mdl.chat(rendered_sys_prompt, [{"role": "user", "content": rendered_user_prompt}], {"temperature": 0.2})
|
||||
ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
|
||||
if ans.find("**ERROR**") >= 0:
|
||||
return query
|
||||
return "\n".join([a for a in re.sub(r"(^Output:|\n+)", "", ans, flags=re.DOTALL).split("===") if a.strip()])
|
||||
|
||||
|
||||
def content_tagging(chat_mdl, content, all_tags, examples, topn=3):
|
||||
template = PROMPT_JINJA_ENV.from_string(CONTENT_TAGGING_PROMPT_TEMPLATE)
|
||||
|
||||
for ex in examples:
|
||||
ex["tags_json"] = json.dumps(ex[TAG_FLD], indent=2, ensure_ascii=False)
|
||||
|
||||
rendered_prompt = template.render(
|
||||
topn=topn,
|
||||
all_tags=all_tags,
|
||||
examples=examples,
|
||||
content=content,
|
||||
)
|
||||
|
||||
msg = [{"role": "system", "content": rendered_prompt}, {"role": "user", "content": "Output: "}]
|
||||
_, msg = message_fit_in(msg, chat_mdl.max_length)
|
||||
kwd = chat_mdl.chat(rendered_prompt, msg[1:], {"temperature": 0.5})
|
||||
if isinstance(kwd, tuple):
|
||||
kwd = kwd[0]
|
||||
kwd = re.sub(r"^.*</think>", "", kwd, flags=re.DOTALL)
|
||||
if kwd.find("**ERROR**") >= 0:
|
||||
raise Exception(kwd)
|
||||
|
||||
try:
|
||||
obj = json_repair.loads(kwd)
|
||||
except json_repair.JSONDecodeError:
|
||||
try:
|
||||
result = kwd.replace(rendered_prompt[:-1], "").replace("user", "").replace("model", "").strip()
|
||||
result = "{" + result.split("{")[1].split("}")[0] + "}"
|
||||
obj = json_repair.loads(result)
|
||||
except Exception as e:
|
||||
logging.exception(f"JSON parsing error: {result} -> {e}")
|
||||
raise e
|
||||
res = {}
|
||||
for k, v in obj.items():
|
||||
try:
|
||||
if int(v) > 0:
|
||||
res[str(k)] = int(v)
|
||||
except Exception:
|
||||
pass
|
||||
return res
|
||||
|
||||
|
||||
def vision_llm_describe_prompt(page=None) -> str:
|
||||
template = PROMPT_JINJA_ENV.from_string(VISION_LLM_DESCRIBE_PROMPT)
|
||||
|
||||
return template.render(page=page)
|
||||
|
||||
|
||||
def vision_llm_figure_describe_prompt() -> str:
|
||||
template = PROMPT_JINJA_ENV.from_string(VISION_LLM_FIGURE_DESCRIBE_PROMPT)
|
||||
return template.render()
|
||||
|
||||
|
||||
def tool_schema(tools_description: list[dict], complete_task=False):
|
||||
if not tools_description:
|
||||
return ""
|
||||
desc = {}
|
||||
if complete_task:
|
||||
desc[COMPLETE_TASK] = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": COMPLETE_TASK,
|
||||
"description": "When you have the final answer and are ready to complete the task, call this function with your answer",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {"answer":{"type":"string", "description": "The final answer to the user's question"}},
|
||||
"required": ["answer"]
|
||||
}
|
||||
}
|
||||
}
|
||||
for tool in tools_description:
|
||||
desc[tool["function"]["name"]] = tool
|
||||
|
||||
return "\n\n".join([f"## {i+1}. {fnm}\n{json.dumps(des, ensure_ascii=False, indent=4)}" for i, (fnm, des) in enumerate(desc.items())])
|
||||
|
||||
|
||||
def form_history(history, limit=-6):
|
||||
context = ""
|
||||
for h in history[limit:]:
|
||||
if h["role"] == "system":
|
||||
continue
|
||||
role = "USER"
|
||||
if h["role"].upper()!= role:
|
||||
role = "AGENT"
|
||||
context += f"\n{role}: {h['content'][:2048] + ('...' if len(h['content'])>2048 else '')}"
|
||||
return context
|
||||
|
||||
|
||||
def analyze_task(chat_mdl, prompt, task_name, tools_description: list[dict], user_defined_prompts: dict={}):
|
||||
tools_desc = tool_schema(tools_description)
|
||||
context = ""
|
||||
|
||||
if user_defined_prompts.get("task_analysis"):
|
||||
template = PROMPT_JINJA_ENV.from_string(user_defined_prompts["task_analysis"])
|
||||
else:
|
||||
template = PROMPT_JINJA_ENV.from_string(ANALYZE_TASK_SYSTEM + "\n\n" + ANALYZE_TASK_USER)
|
||||
context = template.render(task=task_name, context=context, agent_prompt=prompt, tools_desc=tools_desc)
|
||||
kwd = chat_mdl.chat(context, [{"role": "user", "content": "Please analyze it."}])
|
||||
if isinstance(kwd, tuple):
|
||||
kwd = kwd[0]
|
||||
kwd = re.sub(r"^.*</think>", "", kwd, flags=re.DOTALL)
|
||||
if kwd.find("**ERROR**") >= 0:
|
||||
return ""
|
||||
return kwd
|
||||
|
||||
|
||||
def next_step(chat_mdl, history:list, tools_description: list[dict], task_desc, user_defined_prompts: dict={}):
|
||||
if not tools_description:
|
||||
return ""
|
||||
desc = tool_schema(tools_description)
|
||||
template = PROMPT_JINJA_ENV.from_string(user_defined_prompts.get("plan_generation", NEXT_STEP))
|
||||
user_prompt = "\nWhat's the next tool to call? If ready OR IMPOSSIBLE TO BE READY, then call `complete_task`."
|
||||
hist = deepcopy(history)
|
||||
if hist[-1]["role"] == "user":
|
||||
hist[-1]["content"] += user_prompt
|
||||
else:
|
||||
hist.append({"role": "user", "content": user_prompt})
|
||||
json_str = chat_mdl.chat(template.render(task_analysis=task_desc, desc=desc, today=datetime.datetime.now().strftime("%Y-%m-%d")),
|
||||
hist[1:], stop=["<|stop|>"])
|
||||
tk_cnt = num_tokens_from_string(json_str)
|
||||
json_str = re.sub(r"^.*</think>", "", json_str, flags=re.DOTALL)
|
||||
return json_str, tk_cnt
|
||||
|
||||
|
||||
def reflect(chat_mdl, history: list[dict], tool_call_res: list[Tuple], user_defined_prompts: dict={}):
|
||||
tool_calls = [{"name": p[0], "result": p[1]} for p in tool_call_res]
|
||||
goal = history[1]["content"]
|
||||
template = PROMPT_JINJA_ENV.from_string(user_defined_prompts.get("reflection", REFLECT))
|
||||
user_prompt = template.render(goal=goal, tool_calls=tool_calls)
|
||||
hist = deepcopy(history)
|
||||
if hist[-1]["role"] == "user":
|
||||
hist[-1]["content"] += user_prompt
|
||||
else:
|
||||
hist.append({"role": "user", "content": user_prompt})
|
||||
_, msg = message_fit_in(hist, chat_mdl.max_length)
|
||||
ans = chat_mdl.chat(msg[0]["content"], msg[1:])
|
||||
ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
|
||||
return """
|
||||
**Observation**
|
||||
{}
|
||||
|
||||
**Reflection**
|
||||
{}
|
||||
""".format(json.dumps(tool_calls, ensure_ascii=False, indent=2), ans)
|
||||
|
||||
|
||||
def form_message(system_prompt, user_prompt):
|
||||
return [{"role": "system", "content": system_prompt},{"role": "user", "content": user_prompt}]
|
||||
|
||||
|
||||
def structured_output_prompt(schema=None) -> str:
|
||||
template = PROMPT_JINJA_ENV.from_string(STRUCTURED_OUTPUT_PROMPT)
|
||||
return template.render(schema=schema)
|
||||
|
||||
|
||||
def tool_call_summary(chat_mdl, name: str, params: dict, result: str, user_defined_prompts: dict={}) -> str:
|
||||
template = PROMPT_JINJA_ENV.from_string(SUMMARY4MEMORY)
|
||||
system_prompt = template.render(name=name,
|
||||
params=json.dumps(params, ensure_ascii=False, indent=2),
|
||||
result=result)
|
||||
user_prompt = "→ Summary: "
|
||||
_, msg = message_fit_in(form_message(system_prompt, user_prompt), chat_mdl.max_length)
|
||||
ans = chat_mdl.chat(msg[0]["content"], msg[1:])
|
||||
return re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
|
||||
|
||||
|
||||
def rank_memories(chat_mdl, goal:str, sub_goal:str, tool_call_summaries: list[str], user_defined_prompts: dict={}):
|
||||
template = PROMPT_JINJA_ENV.from_string(RANK_MEMORY)
|
||||
system_prompt = template.render(goal=goal, sub_goal=sub_goal, results=[{"i": i, "content": s} for i,s in enumerate(tool_call_summaries)])
|
||||
user_prompt = " → rank: "
|
||||
_, msg = message_fit_in(form_message(system_prompt, user_prompt), chat_mdl.max_length)
|
||||
ans = chat_mdl.chat(msg[0]["content"], msg[1:], stop="<|stop|>")
|
||||
return re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
|
||||
|
||||
|
||||
def gen_meta_filter(chat_mdl, meta_data:dict, query: str) -> list:
|
||||
sys_prompt = PROMPT_JINJA_ENV.from_string(META_FILTER).render(
|
||||
current_date=datetime.datetime.today().strftime('%Y-%m-%d'),
|
||||
metadata_keys=json.dumps(meta_data),
|
||||
user_question=query
|
||||
)
|
||||
user_prompt = "Generate filters:"
|
||||
ans = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_prompt}])
|
||||
ans = re.sub(r"(^.*</think>|```json\n|```\n*$)", "", ans, flags=re.DOTALL)
|
||||
try:
|
||||
ans = json_repair.loads(ans)
|
||||
assert isinstance(ans, list), ans
|
||||
return ans
|
||||
except Exception:
|
||||
logging.exception(f"Loading json failure: {ans}")
|
||||
return []
|
||||
|
||||
|
||||
def gen_json(system_prompt:str, user_prompt:str, chat_mdl, gen_conf = None):
|
||||
from app.core.rag.graphrag.utils import get_llm_cache, set_llm_cache
|
||||
cached = get_llm_cache(chat_mdl.llm_name, system_prompt, user_prompt, gen_conf)
|
||||
if cached:
|
||||
return json_repair.loads(cached)
|
||||
_, msg = message_fit_in(form_message(system_prompt, user_prompt), chat_mdl.max_length)
|
||||
ans = chat_mdl.chat(msg[0]["content"], msg[1:],gen_conf=gen_conf)
|
||||
ans = re.sub(r"(^.*</think>|```json\n|```\n*$)", "", ans, flags=re.DOTALL)
|
||||
try:
|
||||
res = json_repair.loads(ans)
|
||||
set_llm_cache(chat_mdl.llm_name, system_prompt, ans, user_prompt, gen_conf)
|
||||
return res
|
||||
except Exception:
|
||||
logging.exception(f"Loading json failure: {ans}")
|
||||
|
||||
|
||||
TOC_DETECTION = load_prompt("toc_detection")
|
||||
def detect_table_of_contents(page_1024:list[str], chat_mdl):
|
||||
toc_secs = []
|
||||
for i, sec in enumerate(page_1024[:22]):
|
||||
ans = gen_json(PROMPT_JINJA_ENV.from_string(TOC_DETECTION).render(page_txt=sec), "Only JSON please.", chat_mdl)
|
||||
if toc_secs and not ans["exists"]:
|
||||
break
|
||||
toc_secs.append(sec)
|
||||
return toc_secs
|
||||
|
||||
|
||||
TOC_EXTRACTION = load_prompt("toc_extraction")
|
||||
TOC_EXTRACTION_CONTINUE = load_prompt("toc_extraction_continue")
|
||||
def extract_table_of_contents(toc_pages, chat_mdl):
|
||||
if not toc_pages:
|
||||
return []
|
||||
|
||||
return gen_json(PROMPT_JINJA_ENV.from_string(TOC_EXTRACTION).render(toc_page="\n".join(toc_pages)), "Only JSON please.", chat_mdl)
|
||||
|
||||
|
||||
def toc_index_extractor(toc:list[dict], content:str, chat_mdl):
|
||||
tob_extractor_prompt = """
|
||||
You are given a table of contents in a json format and several pages of a document, your job is to add the physical_index to the table of contents in the json format.
|
||||
|
||||
The provided pages contains tags like <physical_index_X> and <physical_index_X> to indicate the physical location of the page X.
|
||||
|
||||
The structure variable is the numeric system which represents the index of the hierarchy section in the table of contents. For example, the first section has structure index 1, the first subsection has structure index 1.1, the second subsection has structure index 1.2, etc.
|
||||
|
||||
The response should be in the following JSON format:
|
||||
[
|
||||
{
|
||||
"structure": <structure index, "x.x.x" or None> (string),
|
||||
"title": <title of the section>,
|
||||
"physical_index": "<physical_index_X>" (keep the format)
|
||||
},
|
||||
...
|
||||
]
|
||||
|
||||
Only add the physical_index to the sections that are in the provided pages.
|
||||
If the title of the section are not in the provided pages, do not add the physical_index to it.
|
||||
Directly return the final JSON structure. Do not output anything else."""
|
||||
|
||||
prompt = tob_extractor_prompt + '\nTable of contents:\n' + json.dumps(toc, ensure_ascii=False, indent=2) + '\nDocument pages:\n' + content
|
||||
return gen_json(prompt, "Only JSON please.", chat_mdl)
|
||||
|
||||
|
||||
TOC_INDEX = load_prompt("toc_index")
|
||||
def table_of_contents_index(toc_arr: list[dict], sections: list[str], chat_mdl):
|
||||
if not toc_arr or not sections:
|
||||
return []
|
||||
|
||||
toc_map = {}
|
||||
for i, it in enumerate(toc_arr):
|
||||
k1 = (it["structure"]+it["title"]).replace(" ", "")
|
||||
k2 = it["title"].strip()
|
||||
if k1 not in toc_map:
|
||||
toc_map[k1] = []
|
||||
if k2 not in toc_map:
|
||||
toc_map[k2] = []
|
||||
toc_map[k1].append(i)
|
||||
toc_map[k2].append(i)
|
||||
|
||||
for it in toc_arr:
|
||||
it["indices"] = []
|
||||
for i, sec in enumerate(sections):
|
||||
sec = sec.strip()
|
||||
if sec.replace(" ", "") in toc_map:
|
||||
for j in toc_map[sec.replace(" ", "")]:
|
||||
toc_arr[j]["indices"].append(i)
|
||||
|
||||
all_pathes = []
|
||||
def dfs(start, path):
|
||||
nonlocal all_pathes
|
||||
if start >= len(toc_arr):
|
||||
if path:
|
||||
all_pathes.append(path)
|
||||
return
|
||||
if not toc_arr[start]["indices"]:
|
||||
dfs(start+1, path)
|
||||
return
|
||||
added = False
|
||||
for j in toc_arr[start]["indices"]:
|
||||
if path and j < path[-1][0]:
|
||||
continue
|
||||
_path = deepcopy(path)
|
||||
_path.append((j, start))
|
||||
added = True
|
||||
dfs(start+1, _path)
|
||||
if not added and path:
|
||||
all_pathes.append(path)
|
||||
|
||||
dfs(0, [])
|
||||
path = max(all_pathes, key=lambda x:len(x))
|
||||
for it in toc_arr:
|
||||
it["indices"] = []
|
||||
for j, i in path:
|
||||
toc_arr[i]["indices"] = [j]
|
||||
print(json.dumps(toc_arr, ensure_ascii=False, indent=2))
|
||||
|
||||
i = 0
|
||||
while i < len(toc_arr):
|
||||
it = toc_arr[i]
|
||||
if it["indices"]:
|
||||
i += 1
|
||||
continue
|
||||
|
||||
if i>0 and toc_arr[i-1]["indices"]:
|
||||
st_i = toc_arr[i-1]["indices"][-1]
|
||||
else:
|
||||
st_i = 0
|
||||
e = i + 1
|
||||
while e <len(toc_arr) and not toc_arr[e]["indices"]:
|
||||
e += 1
|
||||
if e >= len(toc_arr):
|
||||
e = len(sections)
|
||||
else:
|
||||
e = toc_arr[e]["indices"][0]
|
||||
|
||||
for j in range(st_i, min(e+1, len(sections))):
|
||||
ans = gen_json(PROMPT_JINJA_ENV.from_string(TOC_INDEX).render(
|
||||
structure=it["structure"],
|
||||
title=it["title"],
|
||||
text=sections[j]), "Only JSON please.", chat_mdl)
|
||||
if ans["exist"] == "yes":
|
||||
it["indices"].append(j)
|
||||
break
|
||||
|
||||
i += 1
|
||||
|
||||
return toc_arr
|
||||
|
||||
|
||||
def check_if_toc_transformation_is_complete(content, toc, chat_mdl):
|
||||
prompt = """
|
||||
You are given a raw table of contents and a table of contents.
|
||||
Your job is to check if the table of contents is complete.
|
||||
|
||||
Reply format:
|
||||
{{
|
||||
"thinking": <why do you think the cleaned table of contents is complete or not>
|
||||
"completed": "yes" or "no"
|
||||
}}
|
||||
Directly return the final JSON structure. Do not output anything else."""
|
||||
|
||||
prompt = prompt + '\n Raw Table of contents:\n' + content + '\n Cleaned Table of contents:\n' + toc
|
||||
response = gen_json(prompt, "Only JSON please.", chat_mdl)
|
||||
return response['completed']
|
||||
|
||||
|
||||
def toc_transformer(toc_pages, chat_mdl):
|
||||
init_prompt = """
|
||||
You are given a table of contents, You job is to transform the whole table of content into a JSON format included table_of_contents.
|
||||
|
||||
The `structure` is the numeric system which represents the index of the hierarchy section in the table of contents. For example, the first section has structure index 1, the first subsection has structure index 1.1, the second subsection has structure index 1.2, etc.
|
||||
The `title` is a short phrase or a several-words term.
|
||||
|
||||
The response should be in the following JSON format:
|
||||
[
|
||||
{
|
||||
"structure": <structure index, "x.x.x" or None> (string),
|
||||
"title": <title of the section>
|
||||
},
|
||||
...
|
||||
],
|
||||
You should transform the full table of contents in one go.
|
||||
Directly return the final JSON structure, do not output anything else. """
|
||||
|
||||
toc_content = "\n".join(toc_pages)
|
||||
prompt = init_prompt + '\n Given table of contents\n:' + toc_content
|
||||
def clean_toc(arr):
|
||||
for a in arr:
|
||||
a["title"] = re.sub(r"[.·….]{2,}", "", a["title"])
|
||||
last_complete = gen_json(prompt, "Only JSON please.", chat_mdl)
|
||||
if_complete = check_if_toc_transformation_is_complete(toc_content, json.dumps(last_complete, ensure_ascii=False, indent=2), chat_mdl)
|
||||
clean_toc(last_complete)
|
||||
if if_complete == "yes":
|
||||
return last_complete
|
||||
|
||||
while not (if_complete == "yes"):
|
||||
prompt = f"""
|
||||
Your task is to continue the table of contents json structure, directly output the remaining part of the json structure.
|
||||
The response should be in the following JSON format:
|
||||
|
||||
The raw table of contents json structure is:
|
||||
{toc_content}
|
||||
|
||||
The incomplete transformed table of contents json structure is:
|
||||
{json.dumps(last_complete[-24:], ensure_ascii=False, indent=2)}
|
||||
|
||||
Please continue the json structure, directly output the remaining part of the json structure."""
|
||||
new_complete = gen_json(prompt, "Only JSON please.", chat_mdl)
|
||||
if not new_complete or str(last_complete).find(str(new_complete)) >= 0:
|
||||
break
|
||||
clean_toc(new_complete)
|
||||
last_complete.extend(new_complete)
|
||||
if_complete = check_if_toc_transformation_is_complete(toc_content, json.dumps(last_complete, ensure_ascii=False, indent=2), chat_mdl)
|
||||
|
||||
return last_complete
|
||||
|
||||
|
||||
TOC_LEVELS = load_prompt("assign_toc_levels")
|
||||
def assign_toc_levels(toc_secs, chat_mdl, gen_conf = {"temperature": 0.2}):
|
||||
if not toc_secs:
|
||||
return []
|
||||
return gen_json(
|
||||
PROMPT_JINJA_ENV.from_string(TOC_LEVELS).render(),
|
||||
str(toc_secs),
|
||||
chat_mdl,
|
||||
gen_conf
|
||||
)
|
||||
|
||||
|
||||
TOC_FROM_TEXT_SYSTEM = load_prompt("toc_from_text_system")
|
||||
TOC_FROM_TEXT_USER = load_prompt("toc_from_text_user")
|
||||
# Generate TOC from text chunks with text llms
|
||||
async def gen_toc_from_text(txt_info: dict, chat_mdl, callback=None):
|
||||
try:
|
||||
ans = gen_json(
|
||||
PROMPT_JINJA_ENV.from_string(TOC_FROM_TEXT_SYSTEM).render(),
|
||||
PROMPT_JINJA_ENV.from_string(TOC_FROM_TEXT_USER).render(text="\n".join([json.dumps(d, ensure_ascii=False) for d in txt_info["chunks"]])),
|
||||
chat_mdl,
|
||||
gen_conf={"temperature": 0.0, "top_p": 0.9}
|
||||
)
|
||||
txt_info["toc"] = ans if ans and not isinstance(ans, str) else []
|
||||
if callback:
|
||||
callback(msg="")
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
|
||||
|
||||
def split_chunks(chunks, max_length: int):
|
||||
"""
|
||||
Pack chunks into batches according to max_length, returning [{"id": idx, "text": chunk_text}, ...].
|
||||
Do not split a single chunk, even if it exceeds max_length.
|
||||
"""
|
||||
|
||||
result = []
|
||||
batch, batch_tokens = [], 0
|
||||
|
||||
for idx, chunk in enumerate(chunks):
|
||||
t = num_tokens_from_string(chunk)
|
||||
if batch_tokens + t > max_length:
|
||||
result.append(batch)
|
||||
batch, batch_tokens = [], 0
|
||||
batch.append({idx: chunk})
|
||||
batch_tokens += t
|
||||
if batch:
|
||||
result.append(batch)
|
||||
return result
|
||||
|
||||
|
||||
async def run_toc_from_text(chunks, chat_mdl, callback=None):
|
||||
input_budget = int(chat_mdl.max_length * INPUT_UTILIZATION) - num_tokens_from_string(
|
||||
TOC_FROM_TEXT_USER + TOC_FROM_TEXT_SYSTEM
|
||||
)
|
||||
|
||||
input_budget = 1024 if input_budget > 1024 else input_budget
|
||||
chunk_sections = split_chunks(chunks, input_budget)
|
||||
titles = []
|
||||
|
||||
chunks_res = []
|
||||
async with trio.open_nursery() as nursery:
|
||||
for i, chunk in enumerate(chunk_sections):
|
||||
if not chunk:
|
||||
continue
|
||||
chunks_res.append({"chunks": chunk})
|
||||
nursery.start_soon(gen_toc_from_text, chunks_res[-1], chat_mdl, callback)
|
||||
|
||||
for chunk in chunks_res:
|
||||
titles.extend(chunk.get("toc", []))
|
||||
|
||||
# Filter out entries with title == -1
|
||||
prune = len(titles) > 512
|
||||
max_len = 12 if prune else 22
|
||||
filtered = []
|
||||
for x in titles:
|
||||
if not isinstance(x, dict) or not x.get("title") or x["title"] == "-1":
|
||||
continue
|
||||
if len(rag_tokenizer.tokenize(x["title"]).split(" ")) > max_len:
|
||||
continue
|
||||
if re.match(r"[0-9,.()/ -]+$", x["title"]):
|
||||
continue
|
||||
filtered.append(x)
|
||||
|
||||
logging.info(f"\n\nFiltered TOC sections:\n{filtered}")
|
||||
if not filtered:
|
||||
return []
|
||||
|
||||
# Generate initial level (level/title)
|
||||
raw_structure = [x.get("title", "") for x in filtered]
|
||||
|
||||
# Assign hierarchy levels using LLM
|
||||
toc_with_levels = assign_toc_levels(raw_structure, chat_mdl, {"temperature": 0.0, "top_p": 0.9})
|
||||
if not toc_with_levels:
|
||||
return []
|
||||
|
||||
# Merge structure and content (by index)
|
||||
prune = len(toc_with_levels) > 512
|
||||
max_lvl = sorted([t.get("level", "0") for t in toc_with_levels if isinstance(t, dict)])[-1]
|
||||
merged = []
|
||||
for _ , (toc_item, src_item) in enumerate(zip(toc_with_levels, filtered)):
|
||||
if prune and toc_item.get("level", "0") >= max_lvl:
|
||||
continue
|
||||
merged.append({
|
||||
"level": toc_item.get("level", "0"),
|
||||
"title": toc_item.get("title", ""),
|
||||
"chunk_id": src_item.get("chunk_id", ""),
|
||||
})
|
||||
|
||||
return merged
|
||||
|
||||
|
||||
TOC_RELEVANCE_SYSTEM = load_prompt("toc_relevance_system")
|
||||
TOC_RELEVANCE_USER = load_prompt("toc_relevance_user")
|
||||
def relevant_chunks_with_toc(query: str, toc:list[dict], chat_mdl, topn: int=6):
|
||||
import numpy as np
|
||||
try:
|
||||
ans = gen_json(
|
||||
PROMPT_JINJA_ENV.from_string(TOC_RELEVANCE_SYSTEM).render(),
|
||||
PROMPT_JINJA_ENV.from_string(TOC_RELEVANCE_USER).render(query=query, toc_json="[\n%s\n]\n"%"\n".join([json.dumps({"level": d["level"], "title":d["title"]}, ensure_ascii=False) for d in toc])),
|
||||
chat_mdl,
|
||||
gen_conf={"temperature": 0.0, "top_p": 0.9}
|
||||
)
|
||||
id2score = {}
|
||||
for ti, sc in zip(toc, ans):
|
||||
if not isinstance(sc, dict) or sc.get("score", -1) < 1:
|
||||
continue
|
||||
for id in ti.get("ids", []):
|
||||
if id not in id2score:
|
||||
id2score[id] = []
|
||||
id2score[id].append(sc["score"]/5.)
|
||||
for id in id2score.keys():
|
||||
id2score[id] = np.mean(id2score[id])
|
||||
return [(id, sc) for id, sc in list(id2score.items()) if sc>=0.3][:topn]
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
return []
|
||||
16
api/app/core/rag/prompts/keyword_prompt.md
Normal file
16
api/app/core/rag/prompts/keyword_prompt.md
Normal file
@@ -0,0 +1,16 @@
|
||||
## Role
|
||||
You are a text analyzer.
|
||||
|
||||
## Task
|
||||
Extract the most important keywords/phrases of a given piece of text content.
|
||||
|
||||
## Requirements
|
||||
- Summarize the text content, and give the top {{ topn }} important keywords/phrases.
|
||||
- The keywords MUST be in the same language as the given piece of text content.
|
||||
- The keywords are delimited by ENGLISH COMMA.
|
||||
- Output keywords ONLY.
|
||||
|
||||
---
|
||||
|
||||
## Text Content
|
||||
{{ content }}
|
||||
53
api/app/core/rag/prompts/meta_filter.md
Normal file
53
api/app/core/rag/prompts/meta_filter.md
Normal file
@@ -0,0 +1,53 @@
|
||||
You are a metadata filtering condition generator. Analyze the user's question and available document metadata to output a JSON array of filter objects. Follow these rules:
|
||||
|
||||
1. **Metadata Structure**:
|
||||
- Metadata is provided as JSON where keys are attribute names (e.g., "color"), and values are objects mapping attribute values to document IDs.
|
||||
- Example:
|
||||
{
|
||||
"color": {"red": ["doc1"], "blue": ["doc2"]},
|
||||
"listing_date": {"2025-07-11": ["doc1"], "2025-08-01": ["doc2"]}
|
||||
}
|
||||
|
||||
2. **Output Requirements**:
|
||||
- Always output a JSON array of filter objects
|
||||
- Each object must have:
|
||||
"key": (metadata attribute name),
|
||||
"value": (string value to compare),
|
||||
"op": (operator from allowed list)
|
||||
|
||||
3. **Operator Guide**:
|
||||
- Use these operators only: ["contains", "not contains", "start with", "end with", "empty", "not empty", "=", "≠", ">", "<", "≥", "≤"]
|
||||
- Date ranges: Break into two conditions (≥ start_date AND < next_month_start)
|
||||
- Negations: Always use "≠" for exclusion terms ("not", "except", "exclude", "≠")
|
||||
- Implicit logic: Derive unstated filters (e.g., "July" → [≥ YYYY-07-01, < YYYY-08-01])
|
||||
|
||||
4. **Processing Steps**:
|
||||
a) Identify ALL filterable attributes in the query (both explicit and implicit)
|
||||
b) For dates:
|
||||
- Infer missing year from current date if needed
|
||||
- Always format dates as "YYYY-MM-DD"
|
||||
- Convert ranges: [≥ start, < end]
|
||||
c) For values: Match EXACTLY to metadata's value keys
|
||||
d) Skip conditions if:
|
||||
- Attribute doesn't exist in metadata
|
||||
- Value has no match in metadata
|
||||
|
||||
5. **Example**:
|
||||
- User query: "上市日期七月份的有哪些商品,不要蓝色的"
|
||||
- Metadata: { "color": {...}, "listing_date": {...} }
|
||||
- Output:
|
||||
[
|
||||
{"key": "listing_date", "value": "2025-07-01", "op": "≥"},
|
||||
{"key": "listing_date", "value": "2025-08-01", "op": "<"},
|
||||
{"key": "color", "value": "blue", "op": "≠"}
|
||||
]
|
||||
|
||||
6. **Final Output**:
|
||||
- ONLY output valid JSON array
|
||||
- NO additional text/explanations
|
||||
|
||||
**Current Task**:
|
||||
- Today's date: {{current_date}}
|
||||
- Available metadata keys: {{metadata_keys}}
|
||||
- User query: "{{user_question}}"
|
||||
|
||||
92
api/app/core/rag/prompts/next_step.md
Normal file
92
api/app/core/rag/prompts/next_step.md
Normal file
@@ -0,0 +1,92 @@
|
||||
You are an expert Planning Agent tasked with solving problems efficiently through structured plans.
|
||||
Your job is:
|
||||
1. Based on the task analysis, chose some right tools to execute.
|
||||
2. Track progress and adapt plans(tool calls) when necessary.
|
||||
3. Use `complete_task` if no further step you need to take from tools. (All necessary steps done or little hope to be done)
|
||||
|
||||
# ========== TASK ANALYSIS =============
|
||||
{{ task_analysis }}
|
||||
|
||||
# ========== TOOLS (JSON-Schema) ==========
|
||||
You may invoke only the tools listed below.
|
||||
Return a JSON array of objects in which item is with exactly two top-level keys:
|
||||
• "name": the tool to call
|
||||
• "arguments": an object whose keys/values satisfy the schema
|
||||
|
||||
{{ desc }}
|
||||
|
||||
|
||||
# ========== MULTI-STEP EXECUTION ==========
|
||||
When tasks require multiple independent steps, you can execute them in parallel by returning multiple tool calls in a single JSON array.
|
||||
|
||||
• **Data Collection**: Gathering information from multiple sources simultaneously
|
||||
• **Validation**: Cross-checking facts using different tools
|
||||
• **Comprehensive Analysis**: Analyzing different aspects of the same problem
|
||||
• **Efficiency**: Reducing total execution time when steps don't depend on each other
|
||||
|
||||
**Example Scenarios:**
|
||||
- Searching multiple databases for the same query
|
||||
- Checking weather in multiple cities
|
||||
- Validating information through different APIs
|
||||
- Performing calculations on different datasets
|
||||
- Gathering user preferences from multiple sources
|
||||
|
||||
# ========== RESPONSE FORMAT ==========
|
||||
**When you need a tool**
|
||||
Return ONLY the Json (no additional keys, no commentary, end with `<|stop|>`), such as following:
|
||||
[{
|
||||
"name": "<tool_name1>",
|
||||
"arguments": { /* tool arguments matching its schema */ }
|
||||
},{
|
||||
"name": "<tool_name2>",
|
||||
"arguments": { /* tool arguments matching its schema */ }
|
||||
}...]<|stop|>
|
||||
|
||||
**When you need multiple tools:**
|
||||
Return ONLY:
|
||||
[{
|
||||
"name": "<tool_name1>",
|
||||
"arguments": { /* tool arguments matching its schema */ }
|
||||
},{
|
||||
"name": "<tool_name2>",
|
||||
"arguments": { /* tool arguments matching its schema */ }
|
||||
},{
|
||||
"name": "<tool_name3>",
|
||||
"arguments": { /* tool arguments matching its schema */ }
|
||||
}...]<|stop|>
|
||||
|
||||
**When you are certain the task is solved OR no further information can be obtained**
|
||||
Return ONLY:
|
||||
[{
|
||||
"name": "complete_task",
|
||||
"arguments": { "answer": "<final answer text>" }
|
||||
}]<|stop|>
|
||||
|
||||
<verification_steps>
|
||||
Before providing a final answer:
|
||||
1. Double-check all gathered information
|
||||
2. Verify calculations and logic
|
||||
3. Ensure answer matches exactly what was asked
|
||||
4. Confirm answer format meets requirements
|
||||
5. Run additional verification if confidence is not 100%
|
||||
</verification_steps>
|
||||
|
||||
<error_handling>
|
||||
If you encounter issues:
|
||||
1. Try alternative approaches before giving up
|
||||
2. Use different tools or combinations of tools
|
||||
3. Break complex problems into simpler sub-tasks
|
||||
4. Verify intermediate results frequently
|
||||
5. Never return "I cannot answer" without exhausting all options
|
||||
</error_handling>
|
||||
|
||||
⚠️ Any output that is not valid JSON or that contains extra fields will be rejected.
|
||||
|
||||
# ========== REASONING & REFLECTION ==========
|
||||
You may think privately (not shown to the user) before producing each JSON object.
|
||||
Internal guideline:
|
||||
1. **Reason**: Analyse the user question; decide which tools (if any) are needed.
|
||||
2. **Act**: Emit the JSON object to call the tool.
|
||||
|
||||
Today is {{ today }}. Remember that success in answering questions accurately is paramount - take all necessary steps to ensure your answer is correct.
|
||||
|
||||
19
api/app/core/rag/prompts/question_prompt.md
Normal file
19
api/app/core/rag/prompts/question_prompt.md
Normal file
@@ -0,0 +1,19 @@
|
||||
## Role
|
||||
You are a text analyzer.
|
||||
|
||||
## Task
|
||||
Propose {{ topn }} questions about a given piece of text content.
|
||||
|
||||
## Requirements
|
||||
- Understand and summarize the text content, and propose the top {{ topn }} important questions.
|
||||
- The questions SHOULD NOT have overlapping meanings.
|
||||
- The questions SHOULD cover the main content of the text as much as possible.
|
||||
- The questions MUST be in the same language as the given piece of text content.
|
||||
- One question per line.
|
||||
- Output questions ONLY.
|
||||
|
||||
---
|
||||
|
||||
## Text Content
|
||||
{{ content }}
|
||||
|
||||
30
api/app/core/rag/prompts/rank_memory.md
Normal file
30
api/app/core/rag/prompts/rank_memory.md
Normal file
@@ -0,0 +1,30 @@
|
||||
**Task**: Sort the tool call results based on relevance to the overall goal and current sub-goal. Return ONLY a sorted list of indices (0-indexed).
|
||||
|
||||
**Rules**:
|
||||
1. Analyze each result's contribution to both:
|
||||
- The overall goal (primary priority)
|
||||
- The current sub-goal (secondary priority)
|
||||
2. Sort from MOST relevant (highest impact) to LEAST relevant
|
||||
3. Output format: Strictly a Python-style list of integers. Example: [2, 0, 1]
|
||||
|
||||
🔹 Overall Goal: {{ goal }}
|
||||
🔹 Sub-goal: {{ sub_goal }}
|
||||
|
||||
**Examples**:
|
||||
🔹 Tool Response:
|
||||
- index: 0
|
||||
> Tokyo temperature is 78°F.
|
||||
- index: 1
|
||||
> Error: Authentication failed (expired API key).
|
||||
- index: 2
|
||||
> Available: 12 widgets in stock (max 5 per customer).
|
||||
|
||||
→ rank: [1,2,0]<|stop|>
|
||||
|
||||
|
||||
**Your Turn**:
|
||||
🔹 Tool Response:
|
||||
{% for f in results %}
|
||||
- index: f.i
|
||||
> f.content
|
||||
{% endfor %}
|
||||
75
api/app/core/rag/prompts/reflect.md
Normal file
75
api/app/core/rag/prompts/reflect.md
Normal file
@@ -0,0 +1,75 @@
|
||||
**Context**:
|
||||
- To achieve the goal: {{ goal }}.
|
||||
- You have executed following tool calls:
|
||||
{% for call in tool_calls %}
|
||||
Tool call: `{{ call.name }}`
|
||||
Results: {{ call.result }}
|
||||
{% endfor %}
|
||||
|
||||
## Task Complexity Analysis & Reflection Scope
|
||||
|
||||
**First, analyze the task complexity using these dimensions:**
|
||||
|
||||
### Complexity Assessment Matrix
|
||||
- **Scope Breadth**: Single-step (1) | Multi-step (2) | Multi-domain (3)
|
||||
- **Data Dependency**: Self-contained (1) | External inputs (2) | Multiple sources (3)
|
||||
- **Decision Points**: Linear (1) | Few branches (2) | Complex logic (3)
|
||||
- **Risk Level**: Low (1) | Medium (2) | High (3)
|
||||
|
||||
**Complexity Score**: Sum all dimensions (4-12 points)
|
||||
|
||||
---
|
||||
|
||||
## Task Transmission Assessment
|
||||
**Note**: This section is not subject to word count limitations when transmission is needed, as it serves critical handoff functions.
|
||||
**Evaluate if task transmission information is needed:**
|
||||
- **Is this an initial step?** If yes, skip this section
|
||||
- **Are there downstream agents/steps?** If no, provide minimal transmission
|
||||
- **Is there critical state/context to preserve?** If yes, include full transmission
|
||||
|
||||
### If Task Transmission is Needed:
|
||||
- **Current State Summary**: [1-2 sentences on where we are]
|
||||
- **Key Data/Results**: [Critical findings that must carry forward]
|
||||
- **Context Dependencies**: [Essential context for next agent/step]
|
||||
- **Unresolved Items**: [Issues requiring continuation]
|
||||
- **Status for User**: [Clear status update in user terms]
|
||||
- **Technical State**: [System state for technical handoffs]
|
||||
|
||||
---
|
||||
|
||||
## Situational Reflection (Adjust Length Based on Complexity Score)
|
||||
|
||||
### Reflection Guidelines:
|
||||
- **Simple Tasks (4-5 points)**: ~50-100 words, focus on completion status and immediate next step
|
||||
- **Moderate Tasks (6-8 points)**: ~100-200 words, include core details and main risks
|
||||
- **Complex Tasks (9-12 points)**: ~200-300 words, provide full analysis and alternatives
|
||||
|
||||
### 1. Goal Achievement Status
|
||||
- Does the current outcome align with the original purpose of this task phase?
|
||||
- If not, what critical gaps exist?
|
||||
|
||||
### 2. Step Completion Check
|
||||
- Which planned steps were completed? (List verified items)
|
||||
- Which steps are pending/incomplete? (Specify exactly what's missing)
|
||||
|
||||
### 3. Information Adequacy
|
||||
- Is the collected data sufficient to proceed?
|
||||
- What key information is still needed? (e.g., metrics, user input, external data)
|
||||
|
||||
### 4. Critical Observations
|
||||
- Unexpected outcomes: [Flag anomalies/errors]
|
||||
- Risks/blockers: [Identify immediate obstacles]
|
||||
- Accuracy concerns: [Highlight unreliable results]
|
||||
|
||||
### 5. Next-Step Recommendations
|
||||
- Proposed immediate action: [Concrete next step]
|
||||
- Alternative strategies if blocked: [Workaround solution]
|
||||
- Tools/inputs required for next phase: [Specify resources]
|
||||
|
||||
---
|
||||
|
||||
**Output Instructions:**
|
||||
1. First determine your complexity score
|
||||
2. Assess if task transmission section is needed using the evaluation questions
|
||||
3. Provide situational reflection with length appropriate to complexity
|
||||
4. Use clear headers for easy parsing by downstream systems
|
||||
55
api/app/core/rag/prompts/related_question.md
Normal file
55
api/app/core/rag/prompts/related_question.md
Normal file
@@ -0,0 +1,55 @@
|
||||
# Role
|
||||
You are an AI language model assistant tasked with generating **5-10 related questions** based on a user’s original query.
|
||||
These questions should help **expand the search query scope** and **improve search relevance**.
|
||||
|
||||
---
|
||||
|
||||
## Instructions
|
||||
|
||||
**Input:**
|
||||
You are provided with a **user’s question**.
|
||||
|
||||
**Output:**
|
||||
Generate **5-10 alternative questions** that are **related** to the original user question.
|
||||
These alternatives should help retrieve a **broader range of relevant documents** from a vector database.
|
||||
|
||||
**Context:**
|
||||
Focus on **rephrasing** the original question in different ways, ensuring the alternative questions are **diverse but still connected** to the topic of the original query.
|
||||
Do **not** create overly obscure, irrelevant, or unrelated questions.
|
||||
|
||||
**Fallback:**
|
||||
If you cannot generate any relevant alternatives, do **not** return any questions.
|
||||
|
||||
---
|
||||
|
||||
## Guidance
|
||||
|
||||
1. Each alternative should be **unique** but still **relevant** to the original query.
|
||||
2. Keep the phrasing **clear, concise, and easy to understand**.
|
||||
3. Avoid overly technical jargon or specialized terms **unless directly relevant**.
|
||||
4. Ensure that each question **broadens** the search angle, **not narrows** it.
|
||||
|
||||
---
|
||||
|
||||
## Example
|
||||
|
||||
**Original Question:**
|
||||
> What are the benefits of electric vehicles?
|
||||
|
||||
**Alternative Questions:**
|
||||
1. How do electric vehicles impact the environment?
|
||||
2. What are the advantages of owning an electric car?
|
||||
3. What is the cost-effectiveness of electric vehicles?
|
||||
4. How do electric vehicles compare to traditional cars in terms of fuel efficiency?
|
||||
5. What are the environmental benefits of switching to electric cars?
|
||||
6. How do electric vehicles help reduce carbon emissions?
|
||||
7. Why are electric vehicles becoming more popular?
|
||||
8. What are the long-term savings of using electric vehicles?
|
||||
9. How do electric vehicles contribute to sustainability?
|
||||
10. What are the key benefits of electric vehicles for consumers?
|
||||
|
||||
---
|
||||
|
||||
## Reason
|
||||
Rephrasing the original query into multiple alternative questions helps the user explore **different aspects** of their search topic, improving the **quality of search results**.
|
||||
These questions guide the search engine to provide a **more comprehensive set** of relevant documents.
|
||||
16
api/app/core/rag/prompts/structured_output_prompt.md
Normal file
16
api/app/core/rag/prompts/structured_output_prompt.md
Normal file
@@ -0,0 +1,16 @@
|
||||
You’re a helpful AI assistant. You could answer questions and output in JSON format.
|
||||
constraints:
|
||||
- You must output in JSON format.
|
||||
- Do not output boolean value, use string type instead.
|
||||
- Do not output integer or float value, use number type instead.
|
||||
eg:
|
||||
Here is the JSON schema:
|
||||
{"properties": {"age": {"type": "number","description": ""},"name": {"type": "string","description": ""}},"required": ["age","name"],"type": "Object Array String Number Boolean","value": ""}
|
||||
|
||||
Here is the user's question:
|
||||
My name is John Doe and I am 30 years old.
|
||||
|
||||
output:
|
||||
{"name": "John Doe", "age": 30}
|
||||
Here is the JSON schema:
|
||||
{{ schema }}
|
||||
35
api/app/core/rag/prompts/summary4memory.md
Normal file
35
api/app/core/rag/prompts/summary4memory.md
Normal file
@@ -0,0 +1,35 @@
|
||||
**Role**: AI Assistant
|
||||
**Task**: Summarize tool call responses
|
||||
**Rules**:
|
||||
1. Context: You've executed a tool (API/function) and received a response.
|
||||
2. Condense the response into 1-2 short sentences.
|
||||
3. Never omit:
|
||||
- Success/error status
|
||||
- Core results (e.g., data points, decisions)
|
||||
- Critical constraints (e.g., limits, conditions)
|
||||
4. Exclude technical details like timestamps/request IDs unless crucial.
|
||||
5. Use language as the same as main content of the tool response.
|
||||
|
||||
**Response Template**:
|
||||
"[Status] + [Key Outcome] + [Critical Constraints]"
|
||||
|
||||
**Examples**:
|
||||
🔹 Tool Response:
|
||||
{"status": "success", "temperature": 78.2, "unit": "F", "location": "Tokyo", "timestamp": 16923456}
|
||||
→ Summary: "Success: Tokyo temperature is 78°F."
|
||||
|
||||
🔹 Tool Response:
|
||||
{"error": "invalid_api_key", "message": "Authentication failed: expired key"}
|
||||
→ Summary: "Error: Authentication failed (expired API key)."
|
||||
|
||||
🔹 Tool Response:
|
||||
{"available": true, "inventory": 12, "product": "widget", "limit": "max 5 per customer"}
|
||||
→ Summary: "Available: 12 widgets in stock (max 5 per customer)."
|
||||
|
||||
**Your Turn**:
|
||||
- Tool call: {{ name }}
|
||||
- Tool inputs as following:
|
||||
{{ params }}
|
||||
|
||||
- Tool Response:
|
||||
{{ result }}
|
||||
20
api/app/core/rag/prompts/template.py
Normal file
20
api/app/core/rag/prompts/template.py
Normal file
@@ -0,0 +1,20 @@
|
||||
import os
|
||||
|
||||
|
||||
PROMPT_DIR = os.path.dirname(__file__)
|
||||
|
||||
_loaded_prompts = {}
|
||||
|
||||
|
||||
def load_prompt(name: str) -> str:
|
||||
if name in _loaded_prompts:
|
||||
return _loaded_prompts[name]
|
||||
|
||||
path = os.path.join(PROMPT_DIR, f"{name}.md")
|
||||
if not os.path.isfile(path):
|
||||
raise FileNotFoundError(f"Prompt file '{name}.md' not found in prompts/ directory.")
|
||||
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
content = f.read().strip()
|
||||
_loaded_prompts[name] = content
|
||||
return content
|
||||
29
api/app/core/rag/prompts/toc_detection.md
Normal file
29
api/app/core/rag/prompts/toc_detection.md
Normal file
@@ -0,0 +1,29 @@
|
||||
You are an AI assistant designed to analyze text content and detect whether a table of contents (TOC) list exists on the given page. Follow these steps:
|
||||
|
||||
1. **Analyze the Input**: Carefully review the provided text content.
|
||||
2. **Identify Key Features**: Look for common indicators of a TOC, such as:
|
||||
- Section titles or headings paired with page numbers.
|
||||
- Patterns like repeated formatting (e.g., bold/italicized text, dots/dashes between titles and numbers).
|
||||
- Phrases like "Table of Contents," "Contents," or similar headings.
|
||||
- Logical grouping of topics/subtopics with sequential page references.
|
||||
3. **Discern Negative Features**:
|
||||
- The text contains no numbers, or the numbers present are clearly not page references (e.g., dates, statistical figures, phone numbers, version numbers).
|
||||
- The text consists of full, descriptive sentences and paragraphs that form a narrative, present arguments, or explain concepts, rather than succinctly listing topics.
|
||||
- Contains citations with authors, publication years, journal titles, and page ranges (e.g., "Smith, J. (2020). Journal Title, 10(2), 45-67.").
|
||||
- Lists keywords or terms followed by multiple page numbers, often in alphabetical order.
|
||||
- Comprises terms followed by their definitions or explanations.
|
||||
- Labeled with headers like "Appendix A," "Appendix B," etc.
|
||||
- Contains expressive language thanking individuals or organizations for their support or contributions.
|
||||
4. **Evaluate Evidence**: Weigh the presence/absence of these features to determine if the content resembles a TOC.
|
||||
5. **Output Format**: Provide your response in the following JSON structure:
|
||||
```json
|
||||
{
|
||||
"reasoning": "Step-by-step explanation of your analysis based on the features identified." ,
|
||||
"exists": true/false
|
||||
}
|
||||
```
|
||||
6. **DO NOT** output anything else except JSON structure.
|
||||
|
||||
**Input text Content ( Text-Only Extraction ):**
|
||||
{{ page_txt }}
|
||||
|
||||
53
api/app/core/rag/prompts/toc_extraction.md
Normal file
53
api/app/core/rag/prompts/toc_extraction.md
Normal file
@@ -0,0 +1,53 @@
|
||||
You are an expert parser and data formatter. Your task is to analyze the provided table of contents (TOC) text and convert it into a valid JSON array of objects.
|
||||
|
||||
**Instructions:**
|
||||
1. Analyze each line of the input TOC.
|
||||
2. For each line, extract the following three pieces of information:
|
||||
* `structure`: The hierarchical index/numbering (e.g., "1", "2.1", "3.2.5", "A.1"). If a line has no visible numbering or structure indicator (like a main "Chapter" title), use `null`.
|
||||
* `title`: The textual title of the section or chapter. This should be the main descriptive text, clean and without the page number.
|
||||
3. Output **only** a valid JSON array. Do not include any other text, explanations, or markdown code block fences (like ```json) in your response.
|
||||
|
||||
**JSON Format:**
|
||||
The output must be a list of objects following this exact schema:
|
||||
```json
|
||||
[
|
||||
{
|
||||
"structure": <structure index, "x.x.x" or None> (string),
|
||||
"title": <title of the section>
|
||||
},
|
||||
...
|
||||
]
|
||||
```
|
||||
|
||||
**Input Example:**
|
||||
```
|
||||
Contents
|
||||
1 Introduction to the System ... 1
|
||||
1.1 Overview .... 2
|
||||
1.2 Key Features .... 5
|
||||
2 Installation Guide ....8
|
||||
2.1 Prerequisites ........ 9
|
||||
2.2 Step-by-Step Process ........ 12
|
||||
Appendix A: Specifications ..... 45
|
||||
References ... 47
|
||||
```
|
||||
|
||||
**Expected Output For The Example:**
|
||||
```json
|
||||
[
|
||||
{"structure": null, "title": "Contents"},
|
||||
{"structure": "1", "title": "Introduction to the System"},
|
||||
{"structure": "1.1", "title": "Overview"},
|
||||
{"structure": "1.2", "title": "Key Features"},
|
||||
{"structure": "2", "title": "Installation Guide"},
|
||||
{"structure": "2.1", "title": "Prerequisites"},
|
||||
{"structure": "2.2", "title": "Step-by-Step Process"},
|
||||
{"structure": "A", "title": "Specifications"},
|
||||
{"structure": null, "title": "References"}
|
||||
]
|
||||
```
|
||||
|
||||
**Now, process the following TOC input:**
|
||||
```
|
||||
{{ toc_page }}
|
||||
```
|
||||
60
api/app/core/rag/prompts/toc_extraction_continue.md
Normal file
60
api/app/core/rag/prompts/toc_extraction_continue.md
Normal file
@@ -0,0 +1,60 @@
|
||||
You are an expert parser and data formatter, currently in the process of building a JSON array from a multi-page table of contents (TOC). Your task is to analyze the new page of content and **append** the new entries to the existing JSON array.
|
||||
|
||||
**Instructions:**
|
||||
1. You will be given two inputs:
|
||||
* `current_page_text`: The text content from the new page of the TOC.
|
||||
* `existing_json`: The valid JSON array you have generated from the previous pages.
|
||||
2. Analyze each line of the `current_page_text` input.
|
||||
3. For each new line, extract the following three pieces of information:
|
||||
* `structure`: The hierarchical index/numbering (e.g., "1", "2.1", "3.2.5"). Use `null` if none exists.
|
||||
* `title`: The clean textual title of the section or chapter.
|
||||
* `page`: The page number on which the section starts. Extract only the number. Use `null` if not present.
|
||||
4. **Append these new entries** to the `existing_json` array. Do not modify, reorder, or delete any of the existing entries.
|
||||
5. Output **only** the complete, updated JSON array. Do not include any other text, explanations, or markdown code block fences (like ```json).
|
||||
|
||||
**JSON Format:**
|
||||
The output must be a valid JSON array following this schema:
|
||||
```json
|
||||
[
|
||||
{
|
||||
"structure": <string or null>,
|
||||
"title": <string>,
|
||||
"page": <number or null>
|
||||
},
|
||||
...
|
||||
]
|
||||
```
|
||||
|
||||
**Input Example:**
|
||||
`current_page_text`:
|
||||
```
|
||||
3.2 Advanced Configuration ........... 25
|
||||
3.3 Troubleshooting .................. 28
|
||||
4 User Management .................... 30
|
||||
```
|
||||
|
||||
`existing_json`:
|
||||
```json
|
||||
[
|
||||
{"structure": "1", "title": "Introduction", "page": 1},
|
||||
{"structure": "2", "title": "Installation", "page": 5},
|
||||
{"structure": "3", "title": "Configuration", "page": 12},
|
||||
{"structure": "3.1", "title": "Basic Setup", "page": 15}
|
||||
]
|
||||
```
|
||||
|
||||
**Expected Output For The Example:**
|
||||
```json
|
||||
[
|
||||
{"structure": "3.2", "title": "Advanced Configuration", "page": 25},
|
||||
{"structure": "3.3", "title": "Troubleshooting", "page": 28},
|
||||
{"structure": "4", "title": "User Management", "page": 30}
|
||||
]
|
||||
```
|
||||
|
||||
**Now, process the following inputs:**
|
||||
`current_page_text`:
|
||||
{{ toc_page }}
|
||||
|
||||
`existing_json`:
|
||||
{{ toc_json }}
|
||||
119
api/app/core/rag/prompts/toc_from_text_system.md
Normal file
119
api/app/core/rag/prompts/toc_from_text_system.md
Normal file
@@ -0,0 +1,119 @@
|
||||
You are a robust Table-of-Contents (TOC) extractor.
|
||||
|
||||
GOAL
|
||||
Given a dictionary of chunks {"<chunk_ID>": chunk_text}, extract TOC-like headings and return a strict JSON array of objects:
|
||||
[
|
||||
{"title": "", "chunk_id": ""},
|
||||
...
|
||||
]
|
||||
|
||||
FIELDS
|
||||
- "title": the heading text (clean, no page numbers or leader dots).
|
||||
- If any part of a chunk has no valid heading, output that part as {"title":"-1", ...}.
|
||||
- "chunk_id": the chunk ID (string).
|
||||
- One chunk can yield multiple JSON objects in order (unmatched text + one or more headings).
|
||||
|
||||
RULES
|
||||
1) Preserve input chunk order strictly.
|
||||
2) If a chunk contains multiple headings, expand them in order:
|
||||
- Pre-heading narrative → {"title":"-1","chunk_id":"<chunk_ID>"}
|
||||
- Then each heading → {"title":"...","chunk_id":"<chunk_ID>"}
|
||||
3) Do not merge outputs across chunks; each object refers to exactly one chunk ID.
|
||||
4) "title" must be non-empty (or exactly "-1"). "chunk_id" must be a string (chunk ID).
|
||||
5) When ambiguous, prefer "-1" unless the text strongly looks like a heading.
|
||||
|
||||
HEADING DETECTION (cues, not hard rules)
|
||||
- Appears near line start, short isolated phrase, often followed by content.
|
||||
- May contain separators: — —— - : : · •
|
||||
- Numbering styles:
|
||||
• 第[一二三四五六七八九十百]+(篇|章|节|条)
|
||||
• [((]?[一二三四五六七八九十]+[))]?
|
||||
• [((]?[①②③④⑤⑥⑦⑧⑨⑩][))]?
|
||||
• ^\d+(\.\d+)*[)..]?\s*
|
||||
• ^[IVXLCDM]+[).]
|
||||
• ^[A-Z][).]
|
||||
- Canonical section cues (general only):
|
||||
Common heading indicators include words such as:
|
||||
"Overview", "Introduction", "Background", "Purpose", "Scope", "Definition",
|
||||
"Method", "Procedure", "Result", "Discussion", "Summary", "Conclusion",
|
||||
"Appendix", "Reference", "Annex", "Acknowledgment", "Disclaimer".
|
||||
These are soft cues, not strict requirements.
|
||||
- Length restriction:
|
||||
• Chinese heading: ≤25 characters
|
||||
• English heading: ≤80 characters
|
||||
- Exclude long narrative sentences, continuous prose, or bullet-style lists → output as "-1".
|
||||
|
||||
OUTPUT FORMAT
|
||||
- Return ONLY a valid JSON array of {"title","content"} objects.
|
||||
- No reasoning or commentary.
|
||||
|
||||
EXAMPLES
|
||||
|
||||
Example 1 — No heading
|
||||
Input:
|
||||
[{"0": "Copyright page · Publication info (ISBN 123-456). All rights reserved."}, ...]
|
||||
Output:
|
||||
[
|
||||
{"title":"-1","chunk_id":"0"},
|
||||
...
|
||||
]
|
||||
|
||||
Example 2 — One heading
|
||||
Input:
|
||||
[{"1": "Chapter 1: General Provisions This chapter defines the overall rules…"}, ...]
|
||||
Output:
|
||||
[
|
||||
{"title":"Chapter 1: General Provisions","chunk_id":"1"},
|
||||
...
|
||||
]
|
||||
|
||||
Example 3 — Narrative + heading
|
||||
Input:
|
||||
[{"2": "This paragraph introduces the background and goals. Section 2: Definitions Key terms are explained…"}, ...]
|
||||
Output:
|
||||
[
|
||||
{"title":"Section 2: Definitions","chunk_id":"2"},
|
||||
...
|
||||
]
|
||||
|
||||
Example 4 — Multiple headings in one chunk
|
||||
Input:
|
||||
[{"3": "Declarations and Commitments (I) Party B commits… (II) Party C commits… Appendix A Data Specification"}, ...]
|
||||
Output:
|
||||
[
|
||||
{"title":"Declarations and Commitments","chunk_id":"3"},
|
||||
{"title":"(I) Party B commits","chunk_id":"3"},
|
||||
{"title":"(II) Party C commits","chunk_id":"3"},
|
||||
{"title":"Appendix A Data Specification","chunk_id":"3"},
|
||||
...
|
||||
]
|
||||
|
||||
Example 5 — Numbering styles
|
||||
Input:
|
||||
[{"4": "1. Scope: Defines boundaries. 2) Definitions: Terms used. III) Methods Overview."}, ...]
|
||||
Output:
|
||||
[
|
||||
{"title":"1. Scope","chunk_id":"4"},
|
||||
{"title":"2) Definitions","chunk_id":"4"},
|
||||
{"title":"III) Methods Overview","chunk_id":"4"},
|
||||
...
|
||||
]
|
||||
|
||||
Example 6 — Long list (NOT headings)
|
||||
Input:
|
||||
{"5": "Item list: apples, bananas, strawberries, blueberries, mangos, peaches"}, ...]
|
||||
Output:
|
||||
[
|
||||
{"title":"-1","chunk_id":"5"},
|
||||
...
|
||||
]
|
||||
|
||||
Example 7 — Mixed Chinese/English
|
||||
Input:
|
||||
{"6": "(出版信息略)This standard follows industry practices. Chapter 1: Overview 摘要… 第2节:术语与缩略语"}, ...]
|
||||
Output:
|
||||
[
|
||||
{"title":"Chapter 1: Overview","chunk_id":"6"},
|
||||
{"title":"第2节:术语与缩略语","chunk_id":"6"},
|
||||
...
|
||||
]
|
||||
8
api/app/core/rag/prompts/toc_from_text_user.md
Normal file
8
api/app/core/rag/prompts/toc_from_text_user.md
Normal file
@@ -0,0 +1,8 @@
|
||||
OUTPUT FORMAT
|
||||
- Return ONLY the JSON array.
|
||||
- Use double quotes.
|
||||
- No extra commentary.
|
||||
- Keep language of "title" the same as the input.
|
||||
|
||||
INPUT
|
||||
{{text}}
|
||||
20
api/app/core/rag/prompts/toc_index.md
Normal file
20
api/app/core/rag/prompts/toc_index.md
Normal file
@@ -0,0 +1,20 @@
|
||||
You are an expert analyst tasked with matching text content to the title.
|
||||
|
||||
**Instructions:**
|
||||
1. Analyze the given title with its numeric structure index and the provided text.
|
||||
2. Determine whether the title is mentioned as a section tile in the given text.
|
||||
3. Provide a concise, step-by-step reasoning for your decision.
|
||||
4. Output **only** the complete JSON object. Do not include any other text, explanations, or markdown code block fences (like ```json).
|
||||
|
||||
**Output Format:**
|
||||
Your output must be a valid JSON object with the following keys:
|
||||
{
|
||||
"reasoning": "Step-by-step explanation of your analysis.",
|
||||
"exist": "<yes or no>",
|
||||
}
|
||||
|
||||
** The title: **
|
||||
{{ structure }} {{ title }}
|
||||
|
||||
** Given text: **
|
||||
{{ text }}
|
||||
118
api/app/core/rag/prompts/toc_relevance_system.md
Normal file
118
api/app/core/rag/prompts/toc_relevance_system.md
Normal file
@@ -0,0 +1,118 @@
|
||||
# System Prompt: TOC Relevance Evaluation
|
||||
|
||||
You are an expert logical reasoning assistant specializing in hierarchical Table of Contents (TOC) relevance evaluation.
|
||||
|
||||
## GOAL
|
||||
You will receive:
|
||||
1. A JSON list of TOC items, each with fields:
|
||||
```json
|
||||
{
|
||||
"level": <integer>, // e.g., 1, 2, 3
|
||||
"title": <string> // section title
|
||||
}
|
||||
```
|
||||
2. A user query (natural language question).
|
||||
|
||||
You must assign a **relevance score** (integer) to every TOC entry, based on how related its `title` is to the `query`.
|
||||
|
||||
---
|
||||
|
||||
## RULES
|
||||
|
||||
### Scoring System
|
||||
- 5 → highly relevant (directly answers or matches the query intent)
|
||||
- 3 → somewhat related (same topic or partially overlaps)
|
||||
- 1 → weakly related (vague or tangential)
|
||||
- 0 → no clear relation
|
||||
- -1 → explicitly irrelevant or contradictory
|
||||
|
||||
### Hierarchy Traversal
|
||||
- The TOC is hierarchical: smaller `level` = higher layer (e.g., level 1 is top-level, level 2 is a subsection).
|
||||
- You must traverse in **hierarchical order** — interpret the structure based on levels (1 > 2 > 3).
|
||||
- If a high-level item (level 1) is strongly related (score 5), its child items (level 2, 3) are likely relevant too.
|
||||
- If a high-level item is unrelated (-1 or 0), its deeper children are usually less relevant unless the titles clearly match the query.
|
||||
- Lower (deeper) levels provide more specific content; prefer assigning higher scores if they directly match the query.
|
||||
|
||||
### Output Format
|
||||
Return a **JSON array**, preserving the input order but adding a new key `"score"`:
|
||||
|
||||
```json
|
||||
[
|
||||
{"level": 1, "title": "Introduction", "score": 0},
|
||||
{"level": 2, "title": "Definition of Sustainability", "score": 5}
|
||||
]
|
||||
```
|
||||
|
||||
### Constraints
|
||||
- Output **only the JSON array** — no explanations or reasoning text.
|
||||
|
||||
### EXAMPLES
|
||||
|
||||
#### Example 1
|
||||
Input TOC:
|
||||
[
|
||||
{"level": 1, "title": "Machine Learning Overview"},
|
||||
{"level": 2, "title": "Supervised Learning"},
|
||||
{"level": 2, "title": "Unsupervised Learning"},
|
||||
{"level": 3, "title": "Applications of Deep Learning"}
|
||||
]
|
||||
|
||||
Query:
|
||||
"How is deep learning used in image classification?"
|
||||
|
||||
Output:
|
||||
[
|
||||
{"level": 1, "title": "Machine Learning Overview", "score": 3},
|
||||
{"level": 2, "title": "Supervised Learning", "score": 3},
|
||||
{"level": 2, "title": "Unsupervised Learning", "score": 0},
|
||||
{"level": 3, "title": "Applications of Deep Learning", "score": 5}
|
||||
]
|
||||
|
||||
---
|
||||
|
||||
#### Example 2
|
||||
Input TOC:
|
||||
[
|
||||
{"level": 1, "title": "Marketing Basics"},
|
||||
{"level": 2, "title": "Consumer Behavior"},
|
||||
{"level": 2, "title": "Digital Marketing"},
|
||||
{"level": 3, "title": "Social Media Campaigns"},
|
||||
{"level": 3, "title": "SEO Optimization"}
|
||||
]
|
||||
|
||||
Query:
|
||||
"What are the best online marketing methods?"
|
||||
|
||||
Output:
|
||||
[
|
||||
{"level": 1, "title": "Marketing Basics", "score": 3},
|
||||
{"level": 2, "title": "Consumer Behavior", "score": 1},
|
||||
{"level": 2, "title": "Digital Marketing", "score": 5},
|
||||
{"level": 3, "title": "Social Media Campaigns", "score": 5},
|
||||
{"level": 3, "title": "SEO Optimization", "score": 5}
|
||||
]
|
||||
|
||||
---
|
||||
|
||||
#### Example 3
|
||||
Input TOC:
|
||||
[
|
||||
{"level": 1, "title": "Physics Overview"},
|
||||
{"level": 2, "title": "Classical Mechanics"},
|
||||
{"level": 3, "title": "Newton’s Laws"},
|
||||
{"level": 2, "title": "Thermodynamics"},
|
||||
{"level": 3, "title": "Entropy and Heat Transfer"}
|
||||
]
|
||||
|
||||
Query:
|
||||
"What is entropy?"
|
||||
|
||||
Output:
|
||||
[
|
||||
{"level": 1, "title": "Physics Overview", "score": 3},
|
||||
{"level": 2, "title": "Classical Mechanics", "score": 0},
|
||||
{"level": 3, "title": "Newton’s Laws", "score": -1},
|
||||
{"level": 2, "title": "Thermodynamics", "score": 5},
|
||||
{"level": 3, "title": "Entropy and Heat Transfer", "score": 5}
|
||||
]
|
||||
|
||||
17
api/app/core/rag/prompts/toc_relevance_user.md
Normal file
17
api/app/core/rag/prompts/toc_relevance_user.md
Normal file
@@ -0,0 +1,17 @@
|
||||
# User Prompt: TOC Relevance Evaluation
|
||||
|
||||
You will now receive:
|
||||
1. A JSON list of TOC items (each with `level` and `title`)
|
||||
2. A user query string.
|
||||
|
||||
Traverse the TOC hierarchically based on level numbers and assign scores (5,3,1,0,-1) according to the rules in the system prompt.
|
||||
Output **only** the JSON array with the added `"score"` field.
|
||||
|
||||
---
|
||||
|
||||
**Input TOC:**
|
||||
{{ toc_json }}
|
||||
|
||||
**Query:**
|
||||
{{ query }}
|
||||
|
||||
19
api/app/core/rag/prompts/tool_call_summary.md
Normal file
19
api/app/core/rag/prompts/tool_call_summary.md
Normal file
@@ -0,0 +1,19 @@
|
||||
**Task Instruction:**
|
||||
|
||||
You are tasked with reading and analyzing tool call result based on the following inputs: **Inputs for current call**, and **Results**. Your objective is to extract relevant and helpful information for **Inputs for current call** from the **Results** and seamlessly integrate this information into the previous steps to continue reasoning for the original question.
|
||||
|
||||
**Guidelines:**
|
||||
|
||||
1. **Analyze the Results:**
|
||||
- Carefully review the content of each results of tool call.
|
||||
- Identify factual information that is relevant to the **Inputs for current call** and can aid in the reasoning process for the original question.
|
||||
|
||||
2. **Extract Relevant Information:**
|
||||
- Select the information from the Searched Web Pages that directly contributes to advancing the previous reasoning steps.
|
||||
- Ensure that the extracted information is accurate and relevant.
|
||||
|
||||
- **Inputs for current call:**
|
||||
{{ inputs }}
|
||||
|
||||
- **Results:**
|
||||
{{ results }}
|
||||
23
api/app/core/rag/prompts/vision_llm_describe_prompt.md
Normal file
23
api/app/core/rag/prompts/vision_llm_describe_prompt.md
Normal file
@@ -0,0 +1,23 @@
|
||||
## INSTRUCTION
|
||||
Transcribe the content from the provided PDF page image into clean Markdown format.
|
||||
|
||||
- Only output the content transcribed from the image.
|
||||
- Do NOT output this instruction or any other explanation.
|
||||
- If the content is missing or you do not understand the input, return an empty string.
|
||||
|
||||
## RULES
|
||||
1. Do NOT generate examples, demonstrations, or templates.
|
||||
2. Do NOT output any extra text such as 'Example', 'Example Output', or similar.
|
||||
3. Do NOT generate any tables, headings, or content that is not explicitly present in the image.
|
||||
4. Transcribe content word-for-word. Do NOT modify, translate, or omit any content.
|
||||
5. Do NOT explain Markdown or mention that you are using Markdown.
|
||||
6. Do NOT wrap the output in ```markdown or ``` blocks.
|
||||
7. Only apply Markdown structure to headings, paragraphs, lists, and tables, strictly based on the layout of the image. Do NOT create tables unless an actual table exists in the image.
|
||||
8. Preserve the original language, information, and order exactly as shown in the image.
|
||||
|
||||
{% if page %}
|
||||
At the end of the transcription, add the page divider: `--- Page {{ page }} ---`.
|
||||
{% endif %}
|
||||
|
||||
> If you do not detect valid content in the image, return an empty string.
|
||||
|
||||
@@ -0,0 +1,24 @@
|
||||
## ROLE
|
||||
You are an expert visual data analyst.
|
||||
|
||||
## GOAL
|
||||
Analyze the image and provide a comprehensive description of its content. Focus on identifying the type of visual data representation (e.g., bar chart, pie chart, line graph, table, flowchart), its structure, and any text captions or labels included in the image.
|
||||
|
||||
## TASKS
|
||||
1. Describe the overall structure of the visual representation. Specify if it is a chart, graph, table, or diagram.
|
||||
2. Identify and extract any axes, legends, titles, or labels present in the image. Provide the exact text where available.
|
||||
3. Extract the data points from the visual elements (e.g., bar heights, line graph coordinates, pie chart segments, table rows and columns).
|
||||
4. Analyze and explain any trends, comparisons, or patterns shown in the data.
|
||||
5. Capture any annotations, captions, or footnotes, and explain their relevance to the image.
|
||||
6. Only include details that are explicitly present in the image. If an element (e.g., axis, legend, or caption) does not exist or is not visible, do not mention it.
|
||||
|
||||
## OUTPUT FORMAT (Include only sections relevant to the image content)
|
||||
- Visual Type: [Type]
|
||||
- Title: [Title text, if available]
|
||||
- Axes / Legends / Labels: [Details, if available]
|
||||
- Data Points: [Extracted data]
|
||||
- Trends / Insights: [Analysis and interpretation]
|
||||
- Captions / Annotations: [Text and relevance, if available]
|
||||
|
||||
> Ensure high accuracy, clarity, and completeness in your analysis, and include only the information present in the image. Avoid unnecessary statements about missing elements.
|
||||
|
||||
0
api/app/core/rag/utils/__init__.py
Normal file
0
api/app/core/rag/utils/__init__.py
Normal file
255
api/app/core/rag/utils/doc_store_conn.py
Normal file
255
api/app/core/rag/utils/doc_store_conn.py
Normal file
@@ -0,0 +1,255 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
import numpy as np
|
||||
|
||||
DEFAULT_MATCH_VECTOR_TOPN = 10
|
||||
DEFAULT_MATCH_SPARSE_TOPN = 10
|
||||
VEC = list | np.ndarray
|
||||
|
||||
|
||||
@dataclass
|
||||
class SparseVector:
|
||||
indices: list[int]
|
||||
values: list[float] | list[int] | None = None
|
||||
|
||||
def __post_init__(self):
|
||||
assert (self.values is None) or (len(self.indices) == len(self.values))
|
||||
|
||||
def to_dict_old(self):
|
||||
d = {"indices": self.indices}
|
||||
if self.values is not None:
|
||||
d["values"] = self.values
|
||||
return d
|
||||
|
||||
def to_dict(self):
|
||||
if self.values is None:
|
||||
raise ValueError("SparseVector.values is None")
|
||||
result = {}
|
||||
for i, v in zip(self.indices, self.values):
|
||||
result[str(i)] = v
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def from_dict(d):
|
||||
return SparseVector(d["indices"], d.get("values"))
|
||||
|
||||
def __str__(self):
|
||||
return f"SparseVector(indices={self.indices}{'' if self.values is None else f', values={self.values}'})"
|
||||
|
||||
def __repr__(self):
|
||||
return str(self)
|
||||
|
||||
|
||||
class MatchTextExpr(ABC):
|
||||
def __init__(
|
||||
self,
|
||||
fields: list[str],
|
||||
matching_text: str,
|
||||
topn: int,
|
||||
extra_options: dict = dict(),
|
||||
):
|
||||
self.fields = fields
|
||||
self.matching_text = matching_text
|
||||
self.topn = topn
|
||||
self.extra_options = extra_options
|
||||
|
||||
|
||||
class MatchDenseExpr(ABC):
|
||||
def __init__(
|
||||
self,
|
||||
vector_column_name: str,
|
||||
embedding_data: VEC,
|
||||
embedding_data_type: str,
|
||||
distance_type: str,
|
||||
topn: int = DEFAULT_MATCH_VECTOR_TOPN,
|
||||
extra_options: dict = dict(),
|
||||
):
|
||||
self.vector_column_name = vector_column_name
|
||||
self.embedding_data = embedding_data
|
||||
self.embedding_data_type = embedding_data_type
|
||||
self.distance_type = distance_type
|
||||
self.topn = topn
|
||||
self.extra_options = extra_options
|
||||
|
||||
|
||||
class MatchSparseExpr(ABC):
|
||||
def __init__(
|
||||
self,
|
||||
vector_column_name: str,
|
||||
sparse_data: SparseVector | dict,
|
||||
distance_type: str,
|
||||
topn: int,
|
||||
opt_params: dict | None = None,
|
||||
):
|
||||
self.vector_column_name = vector_column_name
|
||||
self.sparse_data = sparse_data
|
||||
self.distance_type = distance_type
|
||||
self.topn = topn
|
||||
self.opt_params = opt_params
|
||||
|
||||
|
||||
class MatchTensorExpr(ABC):
|
||||
def __init__(
|
||||
self,
|
||||
column_name: str,
|
||||
query_data: VEC,
|
||||
query_data_type: str,
|
||||
topn: int,
|
||||
extra_option: dict | None = None,
|
||||
):
|
||||
self.column_name = column_name
|
||||
self.query_data = query_data
|
||||
self.query_data_type = query_data_type
|
||||
self.topn = topn
|
||||
self.extra_option = extra_option
|
||||
|
||||
|
||||
class FusionExpr(ABC):
|
||||
def __init__(self, method: str, topn: int, fusion_params: dict | None = None):
|
||||
self.method = method
|
||||
self.topn = topn
|
||||
self.fusion_params = fusion_params
|
||||
|
||||
|
||||
MatchExpr = MatchTextExpr | MatchDenseExpr | MatchSparseExpr | MatchTensorExpr | FusionExpr
|
||||
|
||||
class OrderByExpr(ABC):
|
||||
def __init__(self):
|
||||
self.fields = list()
|
||||
def asc(self, field: str):
|
||||
self.fields.append((field, 0))
|
||||
return self
|
||||
def desc(self, field: str):
|
||||
self.fields.append((field, 1))
|
||||
return self
|
||||
def fields(self):
|
||||
return self.fields
|
||||
|
||||
class DocStoreConnection(ABC):
|
||||
"""
|
||||
Database operations
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def dbType(self) -> str:
|
||||
"""
|
||||
Return the type of the database.
|
||||
"""
|
||||
raise NotImplementedError("Not implemented")
|
||||
|
||||
@abstractmethod
|
||||
def health(self) -> dict:
|
||||
"""
|
||||
Return the health status of the database.
|
||||
"""
|
||||
raise NotImplementedError("Not implemented")
|
||||
|
||||
"""
|
||||
Table operations
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def createIdx(self, indexName: str, knowledgebaseId: str, vectorSize: int):
|
||||
"""
|
||||
Create an index with given name
|
||||
"""
|
||||
raise NotImplementedError("Not implemented")
|
||||
|
||||
@abstractmethod
|
||||
def deleteIdx(self, indexName: str, knowledgebaseId: str):
|
||||
"""
|
||||
Delete an index with given name
|
||||
"""
|
||||
raise NotImplementedError("Not implemented")
|
||||
|
||||
@abstractmethod
|
||||
def indexExist(self, indexName: str, knowledgebaseId: str) -> bool:
|
||||
"""
|
||||
Check if an index with given name exists
|
||||
"""
|
||||
raise NotImplementedError("Not implemented")
|
||||
|
||||
"""
|
||||
CRUD operations
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def search(
|
||||
self, selectFields: list[str],
|
||||
highlightFields: list[str],
|
||||
condition: dict,
|
||||
matchExprs: list[MatchExpr],
|
||||
orderBy: OrderByExpr,
|
||||
offset: int,
|
||||
limit: int,
|
||||
indexNames: str|list[str],
|
||||
knowledgebaseIds: list[str],
|
||||
aggFields: list[str] = [],
|
||||
rank_feature: dict | None = None
|
||||
):
|
||||
"""
|
||||
Search with given conjunctive equivalent filtering condition and return all fields of matched documents
|
||||
"""
|
||||
raise NotImplementedError("Not implemented")
|
||||
|
||||
@abstractmethod
|
||||
def get(self, chunkId: str, indexName: str, knowledgebaseIds: list[str]) -> dict | None:
|
||||
"""
|
||||
Get single chunk with given id
|
||||
"""
|
||||
raise NotImplementedError("Not implemented")
|
||||
|
||||
@abstractmethod
|
||||
def insert(self, rows: list[dict], indexName: str, knowledgebaseId: str = None) -> list[str]:
|
||||
"""
|
||||
Update or insert a bulk of rows
|
||||
"""
|
||||
raise NotImplementedError("Not implemented")
|
||||
|
||||
@abstractmethod
|
||||
def update(self, condition: dict, newValue: dict, indexName: str, knowledgebaseId: str) -> bool:
|
||||
"""
|
||||
Update rows with given conjunctive equivalent filtering condition
|
||||
"""
|
||||
raise NotImplementedError("Not implemented")
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, condition: dict, indexName: str, knowledgebaseId: str) -> int:
|
||||
"""
|
||||
Delete rows with given conjunctive equivalent filtering condition
|
||||
"""
|
||||
raise NotImplementedError("Not implemented")
|
||||
|
||||
"""
|
||||
Helper functions for search result
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def getTotal(self, res):
|
||||
raise NotImplementedError("Not implemented")
|
||||
|
||||
@abstractmethod
|
||||
def getChunkIds(self, res):
|
||||
raise NotImplementedError("Not implemented")
|
||||
|
||||
@abstractmethod
|
||||
def getFields(self, res, fields: list[str]) -> dict[str, dict]:
|
||||
raise NotImplementedError("Not implemented")
|
||||
|
||||
@abstractmethod
|
||||
def getHighlight(self, res, keywords: list[str], fieldnm: str):
|
||||
raise NotImplementedError("Not implemented")
|
||||
|
||||
@abstractmethod
|
||||
def getAggregation(self, res, fieldnm: str):
|
||||
raise NotImplementedError("Not implemented")
|
||||
|
||||
"""
|
||||
SQL
|
||||
"""
|
||||
@abstractmethod
|
||||
def sql(sql: str, fetch_size: int, format: str):
|
||||
"""
|
||||
Run the sql generated by text-to-sql
|
||||
"""
|
||||
raise NotImplementedError("Not implemented")
|
||||
247
api/app/core/rag/utils/file_utils.py
Normal file
247
api/app/core/rag/utils/file_utils.py
Normal file
@@ -0,0 +1,247 @@
|
||||
import io
|
||||
import hashlib
|
||||
import zipfile
|
||||
import requests
|
||||
from requests.exceptions import Timeout, RequestException
|
||||
from io import BytesIO
|
||||
from typing import List, Union, Tuple, Optional, Dict
|
||||
import PyPDF2
|
||||
from docx import Document
|
||||
import olefile
|
||||
|
||||
def _is_zip(h: bytes) -> bool:
|
||||
return h.startswith(b"PK\x03\x04") or h.startswith(b"PK\x05\x06") or h.startswith(b"PK\x07\x08")
|
||||
|
||||
def _is_pdf(h: bytes) -> bool:
|
||||
return h.startswith(b"%PDF-")
|
||||
|
||||
def _is_ole(h: bytes) -> bool:
|
||||
return h.startswith(b"\xD0\xCF\x11\xE0\xA1\xB1\x1A\xE1")
|
||||
|
||||
def _sha10(b: bytes) -> str:
|
||||
return hashlib.sha256(b).hexdigest()[:10]
|
||||
|
||||
def _guess_ext(b: bytes) -> str:
|
||||
h = b[:8]
|
||||
if _is_zip(h):
|
||||
try:
|
||||
with zipfile.ZipFile(io.BytesIO(b), "r") as z:
|
||||
names = [n.lower() for n in z.namelist()]
|
||||
if any(n.startswith("word/") for n in names):
|
||||
return ".docx"
|
||||
if any(n.startswith("ppt/") for n in names):
|
||||
return ".pptx"
|
||||
if any(n.startswith("xl/") for n in names):
|
||||
return ".xlsx"
|
||||
except Exception:
|
||||
pass
|
||||
return ".zip"
|
||||
if _is_pdf(h):
|
||||
return ".pdf"
|
||||
if _is_ole(h):
|
||||
return ".doc"
|
||||
return ".bin"
|
||||
|
||||
# Try to extract the real embedded payload from OLE's Ole10Native
|
||||
def _extract_ole10native_payload(data: bytes) -> bytes:
|
||||
try:
|
||||
pos = 0
|
||||
if len(data) < 4:
|
||||
return data
|
||||
_ = int.from_bytes(data[pos:pos+4], "little")
|
||||
pos += 4
|
||||
# filename/src/tmp (NUL-terminated ANSI)
|
||||
for _ in range(3):
|
||||
z = data.index(b"\x00", pos)
|
||||
pos = z + 1
|
||||
# skip unknown 4 bytes
|
||||
pos += 4
|
||||
if pos + 4 > len(data):
|
||||
return data
|
||||
size = int.from_bytes(data[pos:pos+4], "little")
|
||||
pos += 4
|
||||
if pos + size <= len(data):
|
||||
return data[pos:pos+size]
|
||||
except Exception:
|
||||
pass
|
||||
return data
|
||||
|
||||
def extract_embed_file(target: Union[bytes, bytearray]) -> List[Tuple[str, bytes]]:
|
||||
"""
|
||||
Only extract the 'first layer' of embedding, returning raw (filename, bytes).
|
||||
"""
|
||||
top = bytes(target)
|
||||
head = top[:8]
|
||||
out: List[Tuple[str, bytes]] = []
|
||||
seen = set()
|
||||
|
||||
def push(b: bytes, name_hint: str = ""):
|
||||
h10 = _sha10(b)
|
||||
if h10 in seen:
|
||||
return
|
||||
seen.add(h10)
|
||||
ext = _guess_ext(b)
|
||||
# If name_hint has an extension use its basename; else fallback to guessed ext
|
||||
if "." in name_hint:
|
||||
fname = name_hint.split("/")[-1]
|
||||
else:
|
||||
fname = f"{h10}{ext}"
|
||||
out.append((fname, b))
|
||||
|
||||
# OOXML/ZIP container (docx/xlsx/pptx)
|
||||
if _is_zip(head):
|
||||
try:
|
||||
with zipfile.ZipFile(io.BytesIO(top), "r") as z:
|
||||
embed_dirs = (
|
||||
"word/embeddings/", "word/objects/", "word/activex/",
|
||||
"xl/embeddings/", "ppt/embeddings/"
|
||||
)
|
||||
for name in z.namelist():
|
||||
low = name.lower()
|
||||
if any(low.startswith(d) for d in embed_dirs):
|
||||
try:
|
||||
b = z.read(name)
|
||||
push(b, name)
|
||||
except Exception:
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
return out
|
||||
|
||||
# OLE container (doc/ppt/xls)
|
||||
if _is_ole(head):
|
||||
try:
|
||||
with olefile.OleFileIO(io.BytesIO(top)) as ole:
|
||||
for entry in ole.listdir():
|
||||
p = "/".join(entry)
|
||||
try:
|
||||
data = ole.openstream(entry).read()
|
||||
except Exception:
|
||||
continue
|
||||
if not data:
|
||||
continue
|
||||
if "Ole10Native" in p or "ole10native" in p.lower():
|
||||
data = _extract_ole10native_payload(data)
|
||||
push(data, p)
|
||||
except Exception:
|
||||
pass
|
||||
return out
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def extract_links_from_docx(docx_bytes: bytes):
|
||||
"""
|
||||
Extract all hyperlinks from a Word (.docx) document binary stream.
|
||||
|
||||
Args:
|
||||
docx_bytes (bytes): Raw bytes of a .docx file.
|
||||
|
||||
Returns:
|
||||
set[str]: A set of unique hyperlink URLs.
|
||||
"""
|
||||
links = set()
|
||||
with BytesIO(docx_bytes) as bio:
|
||||
document = Document(bio)
|
||||
|
||||
# Each relationship may represent a hyperlink, image, footer, etc.
|
||||
for rel in document.part.rels.values():
|
||||
if rel.reltype == (
|
||||
"http://schemas.openxmlformats.org/officeDocument/2006/relationships/hyperlink"
|
||||
):
|
||||
links.add(rel.target_ref)
|
||||
|
||||
return links
|
||||
|
||||
|
||||
def extract_links_from_pdf(pdf_bytes: bytes):
|
||||
"""
|
||||
Extract all clickable hyperlinks from a PDF binary stream.
|
||||
|
||||
Args:
|
||||
pdf_bytes (bytes): Raw bytes of a PDF file.
|
||||
|
||||
Returns:
|
||||
set[str]: A set of unique hyperlink URLs (unordered).
|
||||
"""
|
||||
links = set()
|
||||
with BytesIO(pdf_bytes) as bio:
|
||||
pdf = PyPDF2.PdfReader(bio)
|
||||
|
||||
for page in pdf.pages:
|
||||
annots = page.get("/Annots")
|
||||
if not annots or isinstance(annots, PyPDF2.generic.IndirectObject):
|
||||
continue
|
||||
for annot in annots:
|
||||
obj = annot.get_object()
|
||||
a = obj.get("/A")
|
||||
if a and a.get("/URI"):
|
||||
links.add(a["/URI"])
|
||||
|
||||
return links
|
||||
|
||||
|
||||
_GLOBAL_SESSION: Optional[requests.Session] = None
|
||||
def _get_session(headers: Optional[Dict[str, str]] = None) -> requests.Session:
|
||||
"""Get or create a global reusable session."""
|
||||
global _GLOBAL_SESSION
|
||||
if _GLOBAL_SESSION is None:
|
||||
_GLOBAL_SESSION = requests.Session()
|
||||
_GLOBAL_SESSION.headers.update({
|
||||
"User-Agent": (
|
||||
"Mozilla/5.0 (X11; Linux x86_64) "
|
||||
"AppleWebKit/537.36 (KHTML, like Gecko) "
|
||||
"Chrome/121.0 Safari/537.36"
|
||||
)
|
||||
})
|
||||
if headers:
|
||||
_GLOBAL_SESSION.headers.update(headers)
|
||||
return _GLOBAL_SESSION
|
||||
|
||||
|
||||
def extract_html(
|
||||
url: str,
|
||||
timeout: float = 60.0,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
max_retries: int = 2,
|
||||
) -> Tuple[Optional[bytes], Dict[str, str]]:
|
||||
"""
|
||||
Extract the full HTML page as raw bytes from a given URL.
|
||||
Automatically reuses a persistent HTTP session and applies robust timeout & retry logic.
|
||||
|
||||
Args:
|
||||
url (str): Target webpage URL.
|
||||
timeout (float): Request timeout in seconds (applies to connect + read).
|
||||
headers (dict, optional): Extra HTTP headers.
|
||||
max_retries (int): Number of retries on timeout or transient errors.
|
||||
|
||||
Returns:
|
||||
tuple(bytes|None, dict):
|
||||
- html_bytes: Raw HTML content (or None if failed)
|
||||
- metadata: HTTP info (status_code, content_type, final_url, error if any)
|
||||
"""
|
||||
session = _get_session(headers=headers)
|
||||
metadata = {"final_url": url, "status_code": "", "content_type": "", "error": ""}
|
||||
|
||||
for attempt in range(1, max_retries + 1):
|
||||
try:
|
||||
resp = session.get(url, timeout=timeout)
|
||||
resp.raise_for_status()
|
||||
|
||||
html_bytes = resp.content
|
||||
metadata.update({
|
||||
"final_url": resp.url,
|
||||
"status_code": str(resp.status_code),
|
||||
"content_type": resp.headers.get("Content-Type", ""),
|
||||
})
|
||||
return html_bytes, metadata
|
||||
|
||||
except Timeout:
|
||||
metadata["error"] = f"Timeout after {timeout}s (attempt {attempt}/{max_retries})"
|
||||
if attempt >= max_retries:
|
||||
continue
|
||||
except RequestException as e:
|
||||
metadata["error"] = f"Request failed: {e}"
|
||||
continue
|
||||
|
||||
return None, metadata
|
||||
0
api/app/core/rag/vdb/__init__.py
Normal file
0
api/app/core/rag/vdb/__init__.py
Normal file
0
api/app/core/rag/vdb/elasticsearch/__init__.py
Normal file
0
api/app/core/rag/vdb/elasticsearch/__init__.py
Normal file
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user