feat: Add base project structure with API and web components
This commit is contained in:
857
api/app/core/rag/nlp/__init__.py
Normal file
857
api/app/core/rag/nlp/__init__.py
Normal file
@@ -0,0 +1,857 @@
|
||||
import logging
|
||||
import random
|
||||
from collections import Counter
|
||||
|
||||
from app.core.rag.common.token_utils import num_tokens_from_string
|
||||
from . import rag_tokenizer
|
||||
import re
|
||||
import copy
|
||||
import roman_numbers as r
|
||||
from word2number import w2n
|
||||
from cn2an import cn2an
|
||||
from PIL import Image
|
||||
|
||||
import chardet
|
||||
|
||||
all_codecs = [
|
||||
'utf-8', 'gb2312', 'gbk', 'utf_16', 'ascii', 'big5', 'big5hkscs',
|
||||
'cp037', 'cp273', 'cp424', 'cp437',
|
||||
'cp500', 'cp720', 'cp737', 'cp775', 'cp850', 'cp852', 'cp855', 'cp856', 'cp857',
|
||||
'cp858', 'cp860', 'cp861', 'cp862', 'cp863', 'cp864', 'cp865', 'cp866', 'cp869',
|
||||
'cp874', 'cp875', 'cp932', 'cp949', 'cp950', 'cp1006', 'cp1026', 'cp1125',
|
||||
'cp1140', 'cp1250', 'cp1251', 'cp1252', 'cp1253', 'cp1254', 'cp1255', 'cp1256',
|
||||
'cp1257', 'cp1258', 'euc_jp', 'euc_jis_2004', 'euc_jisx0213', 'euc_kr',
|
||||
'gb18030', 'hz', 'iso2022_jp', 'iso2022_jp_1', 'iso2022_jp_2',
|
||||
'iso2022_jp_2004', 'iso2022_jp_3', 'iso2022_jp_ext', 'iso2022_kr', 'latin_1',
|
||||
'iso8859_2', 'iso8859_3', 'iso8859_4', 'iso8859_5', 'iso8859_6', 'iso8859_7',
|
||||
'iso8859_8', 'iso8859_9', 'iso8859_10', 'iso8859_11', 'iso8859_13',
|
||||
'iso8859_14', 'iso8859_15', 'iso8859_16', 'johab', 'koi8_r', 'koi8_t', 'koi8_u',
|
||||
'kz1048', 'mac_cyrillic', 'mac_greek', 'mac_iceland', 'mac_latin2', 'mac_roman',
|
||||
'mac_turkish', 'ptcp154', 'shift_jis', 'shift_jis_2004', 'shift_jisx0213',
|
||||
'utf_32', 'utf_32_be', 'utf_32_le', 'utf_16_be', 'utf_16_le', 'utf_7', 'windows-1250', 'windows-1251',
|
||||
'windows-1252', 'windows-1253', 'windows-1254', 'windows-1255', 'windows-1256',
|
||||
'windows-1257', 'windows-1258', 'latin-2'
|
||||
]
|
||||
|
||||
|
||||
def find_codec(blob):
|
||||
detected = chardet.detect(blob[:1024])
|
||||
if detected['confidence'] > 0.5:
|
||||
if detected['encoding'] == "ascii":
|
||||
return "utf-8"
|
||||
|
||||
for c in all_codecs:
|
||||
try:
|
||||
blob[:1024].decode(c)
|
||||
return c
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
blob.decode(c)
|
||||
return c
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return "utf-8"
|
||||
|
||||
|
||||
QUESTION_PATTERN = [
|
||||
r"第([零一二三四五六七八九十百0-9]+)问",
|
||||
r"第([零一二三四五六七八九十百0-9]+)条",
|
||||
r"[\((]([零一二三四五六七八九十百]+)[\))]",
|
||||
r"第([0-9]+)问",
|
||||
r"第([0-9]+)条",
|
||||
r"([0-9]{1,2})[\. 、]",
|
||||
r"([零一二三四五六七八九十百]+)[ 、]",
|
||||
r"[\((]([0-9]{1,2})[\))]",
|
||||
r"QUESTION (ONE|TWO|THREE|FOUR|FIVE|SIX|SEVEN|EIGHT|NINE|TEN)",
|
||||
r"QUESTION (I+V?|VI*|XI|IX|X)",
|
||||
r"QUESTION ([0-9]+)",
|
||||
]
|
||||
|
||||
|
||||
def has_qbullet(reg, box, last_box, last_index, last_bull, bull_x0_list):
|
||||
section, last_section = box['text'], last_box['text']
|
||||
q_reg = r'(\w|\W)*?(?:?|\?|\n|$)+'
|
||||
full_reg = reg + q_reg
|
||||
has_bull = re.match(full_reg, section)
|
||||
index_str = None
|
||||
if has_bull:
|
||||
if 'x0' not in last_box:
|
||||
last_box['x0'] = box['x0']
|
||||
if 'top' not in last_box:
|
||||
last_box['top'] = box['top']
|
||||
if last_bull and box['x0'] - last_box['x0'] > 10:
|
||||
return None, last_index
|
||||
if not last_bull and box['x0'] >= last_box['x0'] and box['top'] - last_box['top'] < 20:
|
||||
return None, last_index
|
||||
avg_bull_x0 = 0
|
||||
if bull_x0_list:
|
||||
avg_bull_x0 = sum(bull_x0_list) / len(bull_x0_list)
|
||||
else:
|
||||
avg_bull_x0 = box['x0']
|
||||
if box['x0'] - avg_bull_x0 > 10:
|
||||
return None, last_index
|
||||
index_str = has_bull.group(1)
|
||||
index = index_int(index_str)
|
||||
if last_section[-1] == ':' or last_section[-1] == ':':
|
||||
return None, last_index
|
||||
if not last_index or index >= last_index:
|
||||
bull_x0_list.append(box['x0'])
|
||||
return has_bull, index
|
||||
if section[-1] == '?' or section[-1] == '?':
|
||||
bull_x0_list.append(box['x0'])
|
||||
return has_bull, index
|
||||
if box['layout_type'] == 'title':
|
||||
bull_x0_list.append(box['x0'])
|
||||
return has_bull, index
|
||||
pure_section = section.lstrip(re.match(reg, section).group()).lower()
|
||||
ask_reg = r'(what|when|where|how|why|which|who|whose|为什么|为啥|哪)'
|
||||
if re.match(ask_reg, pure_section):
|
||||
bull_x0_list.append(box['x0'])
|
||||
return has_bull, index
|
||||
return None, last_index
|
||||
|
||||
|
||||
def index_int(index_str):
|
||||
res = -1
|
||||
try:
|
||||
res = int(index_str)
|
||||
except ValueError:
|
||||
try:
|
||||
res = w2n.word_to_num(index_str)
|
||||
except ValueError:
|
||||
try:
|
||||
res = cn2an(index_str)
|
||||
except ValueError:
|
||||
try:
|
||||
res = r.number(index_str)
|
||||
except ValueError:
|
||||
return -1
|
||||
return res
|
||||
|
||||
|
||||
def qbullets_category(sections):
|
||||
global QUESTION_PATTERN
|
||||
hits = [0] * len(QUESTION_PATTERN)
|
||||
for i, pro in enumerate(QUESTION_PATTERN):
|
||||
for sec in sections:
|
||||
if re.match(pro, sec) and not not_bullet(sec):
|
||||
hits[i] += 1
|
||||
break
|
||||
maxium = 0
|
||||
res = -1
|
||||
for i, h in enumerate(hits):
|
||||
if h <= maxium:
|
||||
continue
|
||||
res = i
|
||||
maxium = h
|
||||
return res, QUESTION_PATTERN[res]
|
||||
|
||||
|
||||
BULLET_PATTERN = [[
|
||||
r"第[零一二三四五六七八九十百0-9]+(分?编|部分)",
|
||||
r"第[零一二三四五六七八九十百0-9]+章",
|
||||
r"第[零一二三四五六七八九十百0-9]+节",
|
||||
r"第[零一二三四五六七八九十百0-9]+条",
|
||||
r"[\((][零一二三四五六七八九十百]+[\))]",
|
||||
], [
|
||||
r"第[0-9]+章",
|
||||
r"第[0-9]+节",
|
||||
r"[0-9]{,2}[\. 、]",
|
||||
r"[0-9]{,2}\.[0-9]{,2}[^a-zA-Z/%~-]",
|
||||
r"[0-9]{,2}\.[0-9]{,2}\.[0-9]{,2}",
|
||||
r"[0-9]{,2}\.[0-9]{,2}\.[0-9]{,2}\.[0-9]{,2}",
|
||||
], [
|
||||
r"第[零一二三四五六七八九十百0-9]+章",
|
||||
r"第[零一二三四五六七八九十百0-9]+节",
|
||||
r"[零一二三四五六七八九十百]+[ 、]",
|
||||
r"[\((][零一二三四五六七八九十百]+[\))]",
|
||||
r"[\((][0-9]{,2}[\))]",
|
||||
], [
|
||||
r"PART (ONE|TWO|THREE|FOUR|FIVE|SIX|SEVEN|EIGHT|NINE|TEN)",
|
||||
r"Chapter (I+V?|VI*|XI|IX|X)",
|
||||
r"Section [0-9]+",
|
||||
r"Article [0-9]+"
|
||||
], [
|
||||
r"^#[^#]",
|
||||
r"^##[^#]",
|
||||
r"^###.*",
|
||||
r"^####.*",
|
||||
r"^#####.*",
|
||||
r"^######.*",
|
||||
]
|
||||
]
|
||||
|
||||
|
||||
def random_choices(arr, k):
|
||||
k = min(len(arr), k)
|
||||
return random.choices(arr, k=k)
|
||||
|
||||
|
||||
def not_bullet(line):
|
||||
patt = [
|
||||
r"0", r"[0-9]+ +[0-9~个只-]", r"[0-9]+\.{2,}"
|
||||
]
|
||||
return any([re.match(r, line) for r in patt])
|
||||
|
||||
|
||||
def bullets_category(sections):
|
||||
global BULLET_PATTERN
|
||||
hits = [0] * len(BULLET_PATTERN)
|
||||
for i, pro in enumerate(BULLET_PATTERN):
|
||||
for sec in sections:
|
||||
sec = sec.strip()
|
||||
for p in pro:
|
||||
if re.match(p, sec) and not not_bullet(sec):
|
||||
hits[i] += 1
|
||||
break
|
||||
maxium = 0
|
||||
res = -1
|
||||
for i, h in enumerate(hits):
|
||||
if h <= maxium:
|
||||
continue
|
||||
res = i
|
||||
maxium = h
|
||||
return res
|
||||
|
||||
|
||||
def is_english(texts):
|
||||
if not texts:
|
||||
return False
|
||||
|
||||
pattern = re.compile(r"[`a-zA-Z0-9\s.,':;/\"?<>!\(\)\-]")
|
||||
|
||||
if isinstance(texts, str):
|
||||
texts = list(texts)
|
||||
elif isinstance(texts, list):
|
||||
texts = [t for t in texts if isinstance(t, str) and t.strip()]
|
||||
else:
|
||||
return False
|
||||
|
||||
if not texts:
|
||||
return False
|
||||
|
||||
eng = sum(1 for t in texts if pattern.fullmatch(t.strip()))
|
||||
return (eng / len(texts)) > 0.8
|
||||
|
||||
|
||||
def is_chinese(text):
|
||||
if not text:
|
||||
return False
|
||||
chinese = 0
|
||||
for ch in text:
|
||||
if '\u4e00' <= ch <= '\u9fff':
|
||||
chinese += 1
|
||||
if chinese / len(text) > 0.2:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def tokenize(d, t, eng):
|
||||
d["content_with_weight"] = t
|
||||
t = re.sub(r"</?(table|td|caption|tr|th)( [^<>]{0,12})?>", " ", t)
|
||||
d["content_ltks"] = rag_tokenizer.tokenize(t)
|
||||
d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
|
||||
|
||||
|
||||
def tokenize_chunks(chunks, doc, eng, pdf_parser=None):
|
||||
res = []
|
||||
# wrap up as es documents
|
||||
for ii, ck in enumerate(chunks):
|
||||
if len(ck.strip()) == 0:
|
||||
continue
|
||||
logging.debug("-- {}".format(ck))
|
||||
d = copy.deepcopy(doc)
|
||||
if pdf_parser:
|
||||
try:
|
||||
d["image"], poss = pdf_parser.crop(ck, need_position=True)
|
||||
add_positions(d, poss)
|
||||
ck = pdf_parser.remove_tag(ck)
|
||||
except NotImplementedError:
|
||||
pass
|
||||
else:
|
||||
add_positions(d, [[ii]*5])
|
||||
tokenize(d, ck, eng)
|
||||
res.append(d)
|
||||
return res
|
||||
|
||||
|
||||
def tokenize_chunks_with_images(chunks, doc, eng, images):
|
||||
res = []
|
||||
# wrap up as es documents
|
||||
for ii, (ck, image) in enumerate(zip(chunks, images)):
|
||||
if len(ck.strip()) == 0:
|
||||
continue
|
||||
logging.debug("-- {}".format(ck))
|
||||
d = copy.deepcopy(doc)
|
||||
d["image"] = image
|
||||
add_positions(d, [[ii]*5])
|
||||
tokenize(d, ck, eng)
|
||||
res.append(d)
|
||||
return res
|
||||
|
||||
|
||||
def tokenize_table(tbls, doc, eng, batch_size=10):
|
||||
res = []
|
||||
# add tables
|
||||
for (img, rows), poss in tbls:
|
||||
if not rows:
|
||||
continue
|
||||
if isinstance(rows, str):
|
||||
d = copy.deepcopy(doc)
|
||||
tokenize(d, rows, eng)
|
||||
d["content_with_weight"] = rows
|
||||
if img:
|
||||
d["image"] = img
|
||||
d["doc_type_kwd"] = "image"
|
||||
if poss:
|
||||
add_positions(d, poss)
|
||||
res.append(d)
|
||||
continue
|
||||
de = "; " if eng else "; "
|
||||
for i in range(0, len(rows), batch_size):
|
||||
d = copy.deepcopy(doc)
|
||||
r = de.join(rows[i:i + batch_size])
|
||||
tokenize(d, r, eng)
|
||||
if img:
|
||||
d["image"] = img
|
||||
d["doc_type_kwd"] = "image"
|
||||
add_positions(d, poss)
|
||||
res.append(d)
|
||||
return res
|
||||
|
||||
|
||||
def add_positions(d, poss):
|
||||
if not poss:
|
||||
return
|
||||
page_num_int = []
|
||||
position_int = []
|
||||
top_int = []
|
||||
for pn, left, right, top, bottom in poss:
|
||||
page_num_int.append(int(pn + 1))
|
||||
top_int.append(int(top))
|
||||
position_int.append((int(pn + 1), int(left), int(right), int(top), int(bottom)))
|
||||
d["page_num_int"] = page_num_int
|
||||
d["position_int"] = position_int
|
||||
d["top_int"] = top_int
|
||||
|
||||
|
||||
def remove_contents_table(sections, eng=False):
|
||||
i = 0
|
||||
while i < len(sections):
|
||||
def get(i):
|
||||
nonlocal sections
|
||||
return (sections[i] if isinstance(sections[i],
|
||||
type("")) else sections[i][0]).strip()
|
||||
|
||||
if not re.match(r"(contents|目录|目次|table of contents|致谢|acknowledge)$",
|
||||
re.sub(r"( | |\u3000)+", "", get(i).split("@@")[0], flags=re.IGNORECASE)):
|
||||
i += 1
|
||||
continue
|
||||
sections.pop(i)
|
||||
if i >= len(sections):
|
||||
break
|
||||
prefix = get(i)[:3] if not eng else " ".join(get(i).split()[:2])
|
||||
while not prefix:
|
||||
sections.pop(i)
|
||||
if i >= len(sections):
|
||||
break
|
||||
prefix = get(i)[:3] if not eng else " ".join(get(i).split()[:2])
|
||||
sections.pop(i)
|
||||
if i >= len(sections) or not prefix:
|
||||
break
|
||||
for j in range(i, min(i + 128, len(sections))):
|
||||
if not re.match(prefix, get(j)):
|
||||
continue
|
||||
for _ in range(i, j):
|
||||
sections.pop(i)
|
||||
break
|
||||
|
||||
|
||||
def make_colon_as_title(sections):
|
||||
if not sections:
|
||||
return []
|
||||
if isinstance(sections[0], type("")):
|
||||
return sections
|
||||
i = 0
|
||||
while i < len(sections):
|
||||
txt, layout = sections[i]
|
||||
i += 1
|
||||
txt = txt.split("@")[0].strip()
|
||||
if not txt:
|
||||
continue
|
||||
if txt[-1] not in "::":
|
||||
continue
|
||||
txt = txt[::-1]
|
||||
arr = re.split(r"([。?!!?;;]| \.)", txt)
|
||||
if len(arr) < 2 or len(arr[1]) < 32:
|
||||
continue
|
||||
sections.insert(i - 1, (arr[0][::-1], "title"))
|
||||
i += 1
|
||||
|
||||
|
||||
def title_frequency(bull, sections):
|
||||
bullets_size = len(BULLET_PATTERN[bull])
|
||||
levels = [bullets_size + 1 for _ in range(len(sections))]
|
||||
if not sections or bull < 0:
|
||||
return bullets_size + 1, levels
|
||||
|
||||
for i, (txt, layout) in enumerate(sections):
|
||||
for j, p in enumerate(BULLET_PATTERN[bull]):
|
||||
if re.match(p, txt.strip()) and not not_bullet(txt):
|
||||
levels[i] = j
|
||||
break
|
||||
else:
|
||||
if re.search(r"(title|head)", layout) and not not_title(txt.split("@")[0]):
|
||||
levels[i] = bullets_size
|
||||
most_level = bullets_size + 1
|
||||
for level, c in sorted(Counter(levels).items(), key=lambda x: x[1] * -1):
|
||||
if level <= bullets_size:
|
||||
most_level = level
|
||||
break
|
||||
return most_level, levels
|
||||
|
||||
|
||||
def not_title(txt):
|
||||
if re.match(r"第[零一二三四五六七八九十百0-9]+条", txt):
|
||||
return False
|
||||
if len(txt.split()) > 12 or (txt.find(" ") < 0 and len(txt) >= 32):
|
||||
return True
|
||||
return re.search(r"[,;,。;!!]", txt)
|
||||
|
||||
def tree_merge(bull, sections, depth):
|
||||
|
||||
if not sections or bull < 0:
|
||||
return sections
|
||||
if isinstance(sections[0], type("")):
|
||||
sections = [(s, "") for s in sections]
|
||||
|
||||
# filter out position information in pdf sections
|
||||
sections = [(t, o) for t, o in sections if
|
||||
t and len(t.split("@")[0].strip()) > 1 and not re.match(r"[0-9]+$", t.split("@")[0].strip())]
|
||||
|
||||
def get_level(bull, section):
|
||||
text, layout = section
|
||||
text = re.sub(r"\u3000", " ", text).strip()
|
||||
|
||||
for i, title in enumerate(BULLET_PATTERN[bull]):
|
||||
if re.match(title, text.strip()):
|
||||
return i+1, text
|
||||
else:
|
||||
if re.search(r"(title|head)", layout) and not not_title(text):
|
||||
return len(BULLET_PATTERN[bull])+1, text
|
||||
else:
|
||||
return len(BULLET_PATTERN[bull])+2, text
|
||||
level_set = set()
|
||||
lines = []
|
||||
for section in sections:
|
||||
level, text = get_level(bull, section)
|
||||
if not text.strip("\n"):
|
||||
continue
|
||||
|
||||
lines.append((level, text))
|
||||
level_set.add(level)
|
||||
|
||||
sorted_levels = sorted(list(level_set))
|
||||
|
||||
if depth <= len(sorted_levels):
|
||||
target_level = sorted_levels[depth - 1]
|
||||
else:
|
||||
target_level = sorted_levels[-1]
|
||||
|
||||
if target_level == len(BULLET_PATTERN[bull]) + 2:
|
||||
target_level = sorted_levels[-2] if len(sorted_levels) > 1 else sorted_levels[0]
|
||||
|
||||
root = Node(level=0, depth=target_level, texts=[])
|
||||
root.build_tree(lines)
|
||||
|
||||
return [("\n").join(element) for element in root.get_tree() if element]
|
||||
|
||||
def hierarchical_merge(bull, sections, depth):
|
||||
|
||||
if not sections or bull < 0:
|
||||
return []
|
||||
if isinstance(sections[0], type("")):
|
||||
sections = [(s, "") for s in sections]
|
||||
sections = [(t, o) for t, o in sections if
|
||||
t and len(t.split("@")[0].strip()) > 1 and not re.match(r"[0-9]+$", t.split("@")[0].strip())]
|
||||
bullets_size = len(BULLET_PATTERN[bull])
|
||||
levels = [[] for _ in range(bullets_size + 2)]
|
||||
|
||||
for i, (txt, layout) in enumerate(sections):
|
||||
for j, p in enumerate(BULLET_PATTERN[bull]):
|
||||
if re.match(p, txt.strip()):
|
||||
levels[j].append(i)
|
||||
break
|
||||
else:
|
||||
if re.search(r"(title|head)", layout) and not not_title(txt):
|
||||
levels[bullets_size].append(i)
|
||||
else:
|
||||
levels[bullets_size + 1].append(i)
|
||||
sections = [t for t, _ in sections]
|
||||
|
||||
# for s in sections: print("--", s)
|
||||
|
||||
def binary_search(arr, target):
|
||||
if not arr:
|
||||
return -1
|
||||
if target > arr[-1]:
|
||||
return len(arr) - 1
|
||||
if target < arr[0]:
|
||||
return -1
|
||||
s, e = 0, len(arr)
|
||||
while e - s > 1:
|
||||
i = (e + s) // 2
|
||||
if target > arr[i]:
|
||||
s = i
|
||||
continue
|
||||
elif target < arr[i]:
|
||||
e = i
|
||||
continue
|
||||
else:
|
||||
assert False
|
||||
return s
|
||||
|
||||
cks = []
|
||||
readed = [False] * len(sections)
|
||||
levels = levels[::-1]
|
||||
for i, arr in enumerate(levels[:depth]):
|
||||
for j in arr:
|
||||
if readed[j]:
|
||||
continue
|
||||
readed[j] = True
|
||||
cks.append([j])
|
||||
if i + 1 == len(levels) - 1:
|
||||
continue
|
||||
for ii in range(i + 1, len(levels)):
|
||||
jj = binary_search(levels[ii], j)
|
||||
if jj < 0:
|
||||
continue
|
||||
if levels[ii][jj] > cks[-1][-1]:
|
||||
cks[-1].pop(-1)
|
||||
cks[-1].append(levels[ii][jj])
|
||||
for ii in cks[-1]:
|
||||
readed[ii] = True
|
||||
|
||||
if not cks:
|
||||
return cks
|
||||
|
||||
for i in range(len(cks)):
|
||||
cks[i] = [sections[j] for j in cks[i][::-1]]
|
||||
logging.debug("\n* ".join(cks[i]))
|
||||
|
||||
res = [[]]
|
||||
num = [0]
|
||||
for ck in cks:
|
||||
if len(ck) == 1:
|
||||
n = num_tokens_from_string(re.sub(r"@@[0-9]+.*", "", ck[0]))
|
||||
if n + num[-1] < 218:
|
||||
res[-1].append(ck[0])
|
||||
num[-1] += n
|
||||
continue
|
||||
res.append(ck)
|
||||
num.append(n)
|
||||
continue
|
||||
res.append(ck)
|
||||
num.append(218)
|
||||
|
||||
return res
|
||||
|
||||
|
||||
def naive_merge(sections: str | list, chunk_token_num=128, delimiter="\n。;!?", overlapped_percent=0):
|
||||
from app.core.rag.deepdoc.parser.pdf_parser import RAGPdfParser
|
||||
if not sections:
|
||||
return []
|
||||
if isinstance(sections, str):
|
||||
sections = [sections]
|
||||
if isinstance(sections[0], str):
|
||||
sections = [(s, "") for s in sections]
|
||||
cks = [""]
|
||||
tk_nums = [0]
|
||||
|
||||
def add_chunk(t, pos):
|
||||
nonlocal cks, tk_nums, delimiter
|
||||
tnum = num_tokens_from_string(t)
|
||||
if not pos:
|
||||
pos = ""
|
||||
if tnum < 8:
|
||||
pos = ""
|
||||
# Ensure that the length of the merged chunk does not exceed chunk_token_num
|
||||
if cks[-1] == "" or tk_nums[-1] > chunk_token_num * (100 - overlapped_percent)/100.:
|
||||
if cks:
|
||||
overlapped = RAGPdfParser.remove_tag(cks[-1])
|
||||
t = overlapped[int(len(overlapped)*(100-overlapped_percent)/100.):] + t
|
||||
if t.find(pos) < 0:
|
||||
t += pos
|
||||
cks.append(t)
|
||||
tk_nums.append(tnum)
|
||||
else:
|
||||
if cks[-1].find(pos) < 0:
|
||||
t += pos
|
||||
cks[-1] += t
|
||||
tk_nums[-1] += tnum
|
||||
|
||||
dels = get_delimiters(delimiter)
|
||||
for sec, pos in sections:
|
||||
if num_tokens_from_string(sec) < chunk_token_num:
|
||||
add_chunk("\n"+sec, pos)
|
||||
continue
|
||||
split_sec = re.split(r"(%s)" % dels, sec, flags=re.DOTALL)
|
||||
for sub_sec in split_sec:
|
||||
if re.match(f"^{dels}$", sub_sec):
|
||||
continue
|
||||
add_chunk("\n"+sub_sec, pos)
|
||||
|
||||
return cks
|
||||
|
||||
|
||||
def naive_merge_with_images(texts, images, chunk_token_num=128, delimiter="\n。;!?", overlapped_percent=0):
|
||||
from app.core.rag.deepdoc.parser.pdf_parser import RAGPdfParser
|
||||
if not texts or len(texts) != len(images):
|
||||
return [], []
|
||||
cks = [""]
|
||||
result_images = [None]
|
||||
tk_nums = [0]
|
||||
|
||||
def add_chunk(t, image, pos=""):
|
||||
nonlocal cks, result_images, tk_nums, delimiter
|
||||
tnum = num_tokens_from_string(t)
|
||||
if not pos:
|
||||
pos = ""
|
||||
if tnum < 8:
|
||||
pos = ""
|
||||
# Ensure that the length of the merged chunk does not exceed chunk_token_num
|
||||
if cks[-1] == "" or tk_nums[-1] > chunk_token_num * (100 - overlapped_percent)/100.:
|
||||
if cks:
|
||||
overlapped = RAGPdfParser.remove_tag(cks[-1])
|
||||
t = overlapped[int(len(overlapped)*(100-overlapped_percent)/100.):] + t
|
||||
if t.find(pos) < 0:
|
||||
t += pos
|
||||
cks.append(t)
|
||||
result_images.append(image)
|
||||
tk_nums.append(tnum)
|
||||
else:
|
||||
if cks[-1].find(pos) < 0:
|
||||
t += pos
|
||||
cks[-1] += t
|
||||
if result_images[-1] is None:
|
||||
result_images[-1] = image
|
||||
else:
|
||||
result_images[-1] = concat_img(result_images[-1], image)
|
||||
tk_nums[-1] += tnum
|
||||
|
||||
dels = get_delimiters(delimiter)
|
||||
for text, image in zip(texts, images):
|
||||
# if text is tuple, unpack it
|
||||
if isinstance(text, tuple):
|
||||
text_str = text[0]
|
||||
text_pos = text[1] if len(text) > 1 else ""
|
||||
split_sec = re.split(r"(%s)" % dels, text_str)
|
||||
for sub_sec in split_sec:
|
||||
if re.match(f"^{dels}$", sub_sec):
|
||||
continue
|
||||
add_chunk("\n"+sub_sec, image, text_pos)
|
||||
else:
|
||||
split_sec = re.split(r"(%s)" % dels, text)
|
||||
for sub_sec in split_sec:
|
||||
if re.match(f"^{dels}$", sub_sec):
|
||||
continue
|
||||
add_chunk("\n"+sub_sec, image)
|
||||
|
||||
return cks, result_images
|
||||
|
||||
def docx_question_level(p, bull=-1):
|
||||
txt = re.sub(r"\u3000", " ", p.text).strip()
|
||||
if p.style.name.startswith('Heading'):
|
||||
return int(p.style.name.split(' ')[-1]), txt
|
||||
else:
|
||||
if bull < 0:
|
||||
return 0, txt
|
||||
for j, title in enumerate(BULLET_PATTERN[bull]):
|
||||
if re.match(title, txt):
|
||||
return j + 1, txt
|
||||
return len(BULLET_PATTERN[bull])+1, txt
|
||||
|
||||
|
||||
def concat_img(img1, img2):
|
||||
if img1 and not img2:
|
||||
return img1
|
||||
if not img1 and img2:
|
||||
return img2
|
||||
if not img1 and not img2:
|
||||
return None
|
||||
|
||||
if img1 is img2:
|
||||
return img1
|
||||
|
||||
if isinstance(img1, Image.Image) and isinstance(img2, Image.Image):
|
||||
pixel_data1 = img1.tobytes()
|
||||
pixel_data2 = img2.tobytes()
|
||||
if pixel_data1 == pixel_data2:
|
||||
return img1
|
||||
|
||||
width1, height1 = img1.size
|
||||
width2, height2 = img2.size
|
||||
|
||||
new_width = max(width1, width2)
|
||||
new_height = height1 + height2
|
||||
new_image = Image.new('RGB', (new_width, new_height))
|
||||
|
||||
new_image.paste(img1, (0, 0))
|
||||
new_image.paste(img2, (0, height1))
|
||||
return new_image
|
||||
|
||||
|
||||
def naive_merge_docx(sections, chunk_token_num=128, delimiter="\n。;!?"):
|
||||
if not sections:
|
||||
return [], []
|
||||
|
||||
cks = [""]
|
||||
images = [None]
|
||||
tk_nums = [0]
|
||||
|
||||
def add_chunk(t, image, pos=""):
|
||||
nonlocal cks, tk_nums, delimiter
|
||||
tnum = num_tokens_from_string(t)
|
||||
if tnum < 8:
|
||||
pos = ""
|
||||
if cks[-1] == "" or tk_nums[-1] > chunk_token_num:
|
||||
if t.find(pos) < 0:
|
||||
t += pos
|
||||
cks.append(t)
|
||||
images.append(image)
|
||||
tk_nums.append(tnum)
|
||||
else:
|
||||
if cks[-1].find(pos) < 0:
|
||||
t += pos
|
||||
cks[-1] += t
|
||||
images[-1] = concat_img(images[-1], image)
|
||||
tk_nums[-1] += tnum
|
||||
|
||||
dels = get_delimiters(delimiter)
|
||||
line = ""
|
||||
for sec, image in sections:
|
||||
if not image:
|
||||
line += sec + "\n"
|
||||
continue
|
||||
split_sec = re.split(r"(%s)" % dels, line + sec)
|
||||
for sub_sec in split_sec:
|
||||
if re.match(f"^{dels}$", sub_sec):
|
||||
continue
|
||||
add_chunk("\n"+sub_sec, image,"")
|
||||
line = ""
|
||||
|
||||
if line:
|
||||
split_sec = re.split(r"(%s)" % dels, line)
|
||||
for sub_sec in split_sec:
|
||||
if re.match(f"^{dels}$", sub_sec):
|
||||
continue
|
||||
add_chunk("\n"+sub_sec, image,"")
|
||||
|
||||
return cks, images
|
||||
|
||||
|
||||
def extract_between(text: str, start_tag: str, end_tag: str) -> list[str]:
|
||||
pattern = re.escape(start_tag) + r"(.*?)" + re.escape(end_tag)
|
||||
return re.findall(pattern, text, flags=re.DOTALL)
|
||||
|
||||
|
||||
def get_delimiters(delimiters: str):
|
||||
dels = []
|
||||
s = 0
|
||||
for m in re.finditer(r"`([^`]+)`", delimiters, re.I):
|
||||
f, t = m.span()
|
||||
dels.append(m.group(1))
|
||||
dels.extend(list(delimiters[s: f]))
|
||||
s = t
|
||||
if s < len(delimiters):
|
||||
dels.extend(list(delimiters[s:]))
|
||||
|
||||
dels.sort(key=lambda x: -len(x))
|
||||
dels = [re.escape(d) for d in dels if d]
|
||||
dels = [d for d in dels if d]
|
||||
dels_pattern = "|".join(dels)
|
||||
|
||||
return dels_pattern
|
||||
|
||||
class Node:
|
||||
def __init__(self, level, depth=-1, texts=None):
|
||||
self.level = level
|
||||
self.depth = depth
|
||||
self.texts = texts or []
|
||||
self.children = []
|
||||
|
||||
def add_child(self, child_node):
|
||||
self.children.append(child_node)
|
||||
|
||||
def get_children(self):
|
||||
return self.children
|
||||
|
||||
def get_level(self):
|
||||
return self.level
|
||||
|
||||
def get_texts(self):
|
||||
return self.texts
|
||||
|
||||
def set_texts(self, texts):
|
||||
self.texts = texts
|
||||
|
||||
def add_text(self, text):
|
||||
self.texts.append(text)
|
||||
|
||||
def clear_text(self):
|
||||
self.texts = []
|
||||
|
||||
def __repr__(self):
|
||||
return f"Node(level={self.level}, texts={self.texts}, children={len(self.children)})"
|
||||
|
||||
def build_tree(self, lines):
|
||||
stack = [self]
|
||||
for level, text in lines:
|
||||
if self.depth != -1 and level > self.depth:
|
||||
# Beyond target depth: merge content into the current leaf instead of creating deeper nodes
|
||||
stack[-1].add_text(text)
|
||||
continue
|
||||
|
||||
# Move up until we find the proper parent whose level is strictly smaller than current
|
||||
while len(stack) > 1 and level <= stack[-1].get_level():
|
||||
stack.pop()
|
||||
|
||||
node = Node(level=level, texts=[text])
|
||||
# Attach as child of current parent and descend
|
||||
stack[-1].add_child(node)
|
||||
stack.append(node)
|
||||
|
||||
return self
|
||||
|
||||
def get_tree(self):
|
||||
tree_list = []
|
||||
self._dfs(self, tree_list, [])
|
||||
return tree_list
|
||||
|
||||
def _dfs(self, node, tree_list, titles):
|
||||
level = node.get_level()
|
||||
texts = node.get_texts()
|
||||
child = node.get_children()
|
||||
|
||||
if level == 0 and texts:
|
||||
tree_list.append("\n".join(titles+texts))
|
||||
|
||||
# Titles within configured depth are accumulated into the current path
|
||||
if 1 <= level <= self.depth:
|
||||
path_titles = titles + texts
|
||||
else:
|
||||
path_titles = titles
|
||||
|
||||
# Body outside the depth limit becomes its own chunk under the current title path
|
||||
if level > self.depth and texts:
|
||||
tree_list.append("\n".join(path_titles + texts))
|
||||
|
||||
# A leaf title within depth emits its title path as a chunk (header-only section)
|
||||
elif not child and (1 <= level <= self.depth):
|
||||
tree_list.append("\n".join(path_titles))
|
||||
|
||||
# Recurse into children with the updated title path
|
||||
for c in child:
|
||||
self._dfs(c, tree_list, path_titles)
|
||||
261
api/app/core/rag/nlp/query.py
Normal file
261
api/app/core/rag/nlp/query.py
Normal file
@@ -0,0 +1,261 @@
|
||||
import logging
|
||||
import json
|
||||
import re
|
||||
from collections import defaultdict
|
||||
|
||||
from app.core.rag.utils.doc_store_conn import MatchTextExpr
|
||||
from . import rag_tokenizer, term_weight, synonym
|
||||
|
||||
|
||||
class FulltextQueryer:
|
||||
def __init__(self):
|
||||
self.tw = term_weight.Dealer()
|
||||
self.syn = synonym.Dealer()
|
||||
self.query_fields = [
|
||||
"title_tks^10",
|
||||
"title_sm_tks^5",
|
||||
"important_kwd^30",
|
||||
"important_tks^20",
|
||||
"question_tks^20",
|
||||
"content_ltks^2",
|
||||
"content_sm_ltks",
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def subSpecialChar(line):
|
||||
return re.sub(r"([:\{\}/\[\]\-\*\"\(\)\|\+~\^])", r"\\\1", line).strip()
|
||||
|
||||
@staticmethod
|
||||
def isChinese(line):
|
||||
arr = re.split(r"[ \t]+", line)
|
||||
if len(arr) <= 3:
|
||||
return True
|
||||
e = 0
|
||||
for t in arr:
|
||||
if not re.match(r"[a-zA-Z]+$", t):
|
||||
e += 1
|
||||
return e * 1.0 / len(arr) >= 0.7
|
||||
|
||||
@staticmethod
|
||||
def rmWWW(txt):
|
||||
patts = [
|
||||
(
|
||||
r"是*(怎么办|什么样的|哪家|一下|那家|请问|啥样|咋样了|什么时候|何时|何地|何人|是否|是不是|多少|哪里|怎么|哪儿|怎么样|如何|哪些|是啥|啥是|啊|吗|呢|吧|咋|什么|有没有|呀|谁|哪位|哪个)是*",
|
||||
"",
|
||||
),
|
||||
(r"(^| )(what|who|how|which|where|why)('re|'s)? ", " "),
|
||||
(
|
||||
r"(^| )('s|'re|is|are|were|was|do|does|did|don't|doesn't|didn't|has|have|be|there|you|me|your|my|mine|just|please|may|i|should|would|wouldn't|will|won't|done|go|for|with|so|the|a|an|by|i'm|it's|he's|she's|they|they're|you're|as|by|on|in|at|up|out|down|of|to|or|and|if) ",
|
||||
" ")
|
||||
]
|
||||
otxt = txt
|
||||
for r, p in patts:
|
||||
txt = re.sub(r, p, txt, flags=re.IGNORECASE)
|
||||
if not txt:
|
||||
txt = otxt
|
||||
return txt
|
||||
|
||||
@staticmethod
|
||||
def add_space_between_eng_zh(txt):
|
||||
# (ENG/ENG+NUM) + ZH
|
||||
txt = re.sub(r'([A-Za-z]+[0-9]+)([\u4e00-\u9fa5]+)', r'\1 \2', txt)
|
||||
# ENG + ZH
|
||||
txt = re.sub(r'([A-Za-z])([\u4e00-\u9fa5]+)', r'\1 \2', txt)
|
||||
# ZH + (ENG/ENG+NUM)
|
||||
txt = re.sub(r'([\u4e00-\u9fa5]+)([A-Za-z]+[0-9]+)', r'\1 \2', txt)
|
||||
txt = re.sub(r'([\u4e00-\u9fa5]+)([A-Za-z])', r'\1 \2', txt)
|
||||
return txt
|
||||
|
||||
def question(self, txt, tbl="qa", min_match: float = 0.6):
|
||||
txt = FulltextQueryer.add_space_between_eng_zh(txt)
|
||||
txt = re.sub(
|
||||
r"[ :|\r\n\t,,。??/`!!&^%%()\[\]{}<>]+",
|
||||
" ",
|
||||
rag_tokenizer.tradi2simp(rag_tokenizer.strQ2B(txt.lower())),
|
||||
).strip()
|
||||
otxt = txt
|
||||
txt = FulltextQueryer.rmWWW(txt)
|
||||
|
||||
if not self.isChinese(txt):
|
||||
txt = FulltextQueryer.rmWWW(txt)
|
||||
tks = rag_tokenizer.tokenize(txt).split()
|
||||
keywords = [t for t in tks if t]
|
||||
tks_w = self.tw.weights(tks, preprocess=False)
|
||||
tks_w = [(re.sub(r"[ \\\"'^]", "", tk), w) for tk, w in tks_w]
|
||||
tks_w = [(re.sub(r"^[a-z0-9]$", "", tk), w) for tk, w in tks_w if tk]
|
||||
tks_w = [(re.sub(r"^[\+-]", "", tk), w) for tk, w in tks_w if tk]
|
||||
tks_w = [(tk.strip(), w) for tk, w in tks_w if tk.strip()]
|
||||
syns = []
|
||||
for tk, w in tks_w[:256]:
|
||||
syn = self.syn.lookup(tk)
|
||||
syn = rag_tokenizer.tokenize(" ".join(syn)).split()
|
||||
keywords.extend(syn)
|
||||
syn = ["\"{}\"^{:.4f}".format(s, w / 4.) for s in syn if s.strip()]
|
||||
syns.append(" ".join(syn))
|
||||
|
||||
q = ["({}^{:.4f}".format(tk, w) + " {})".format(syn) for (tk, w), syn in zip(tks_w, syns) if
|
||||
tk and not re.match(r"[.^+\(\)-]", tk)]
|
||||
for i in range(1, len(tks_w)):
|
||||
left, right = tks_w[i - 1][0].strip(), tks_w[i][0].strip()
|
||||
if not left or not right:
|
||||
continue
|
||||
q.append(
|
||||
'"%s %s"^%.4f'
|
||||
% (
|
||||
tks_w[i - 1][0],
|
||||
tks_w[i][0],
|
||||
max(tks_w[i - 1][1], tks_w[i][1]) * 2,
|
||||
)
|
||||
)
|
||||
if not q:
|
||||
q.append(txt)
|
||||
query = " ".join(q)
|
||||
return MatchTextExpr(
|
||||
self.query_fields, query, 100
|
||||
), keywords
|
||||
|
||||
def need_fine_grained_tokenize(tk):
|
||||
if len(tk) < 3:
|
||||
return False
|
||||
if re.match(r"[0-9a-z\.\+#_\*-]+$", tk):
|
||||
return False
|
||||
return True
|
||||
|
||||
txt = FulltextQueryer.rmWWW(txt)
|
||||
qs, keywords = [], []
|
||||
for tt in self.tw.split(txt)[:256]: # .split():
|
||||
if not tt:
|
||||
continue
|
||||
keywords.append(tt)
|
||||
twts = self.tw.weights([tt])
|
||||
syns = self.syn.lookup(tt)
|
||||
if syns and len(keywords) < 32:
|
||||
keywords.extend(syns)
|
||||
logging.debug(json.dumps(twts, ensure_ascii=False))
|
||||
tms = []
|
||||
for tk, w in sorted(twts, key=lambda x: x[1] * -1):
|
||||
sm = (
|
||||
rag_tokenizer.fine_grained_tokenize(tk).split()
|
||||
if need_fine_grained_tokenize(tk)
|
||||
else []
|
||||
)
|
||||
sm = [
|
||||
re.sub(
|
||||
r"[ ,\./;'\[\]\\`~!@#$%\^&\*\(\)=\+_<>\?:\"\{\}\|,。;‘’【】、!¥……()——《》?:“”-]+",
|
||||
"",
|
||||
m,
|
||||
)
|
||||
for m in sm
|
||||
]
|
||||
sm = [FulltextQueryer.subSpecialChar(m) for m in sm if len(m) > 1]
|
||||
sm = [m for m in sm if len(m) > 1]
|
||||
|
||||
if len(keywords) < 32:
|
||||
keywords.append(re.sub(r"[ \\\"']+", "", tk))
|
||||
keywords.extend(sm)
|
||||
|
||||
tk_syns = self.syn.lookup(tk)
|
||||
tk_syns = [FulltextQueryer.subSpecialChar(s) for s in tk_syns]
|
||||
if len(keywords) < 32:
|
||||
keywords.extend([s for s in tk_syns if s])
|
||||
tk_syns = [rag_tokenizer.fine_grained_tokenize(s) for s in tk_syns if s]
|
||||
tk_syns = [f"\"{s}\"" if s.find(" ") > 0 else s for s in tk_syns]
|
||||
|
||||
if len(keywords) >= 32:
|
||||
break
|
||||
|
||||
tk = FulltextQueryer.subSpecialChar(tk)
|
||||
if tk.find(" ") > 0:
|
||||
tk = '"%s"' % tk
|
||||
if tk_syns:
|
||||
tk = f"({tk} OR (%s)^0.2)" % " ".join(tk_syns)
|
||||
if sm:
|
||||
tk = f'{tk} OR "%s" OR ("%s"~2)^0.5' % (" ".join(sm), " ".join(sm))
|
||||
if tk.strip():
|
||||
tms.append((tk, w))
|
||||
|
||||
tms = " ".join([f"({t})^{w}" for t, w in tms])
|
||||
|
||||
if len(twts) > 1:
|
||||
tms += ' ("%s"~2)^1.5' % rag_tokenizer.tokenize(tt)
|
||||
|
||||
syns = " OR ".join(
|
||||
[
|
||||
'"%s"'
|
||||
% rag_tokenizer.tokenize(FulltextQueryer.subSpecialChar(s))
|
||||
for s in syns
|
||||
]
|
||||
)
|
||||
if syns and tms:
|
||||
tms = f"({tms})^5 OR ({syns})^0.7"
|
||||
|
||||
qs.append(tms)
|
||||
|
||||
if qs:
|
||||
query = " OR ".join([f"({t})" for t in qs if t])
|
||||
if not query:
|
||||
query = otxt
|
||||
return MatchTextExpr(
|
||||
self.query_fields, query, 100, {"minimum_should_match": min_match}
|
||||
), keywords
|
||||
return None, keywords
|
||||
|
||||
def hybrid_similarity(self, avec, bvecs, atks, btkss, tkweight=0.3, vtweight=0.7):
|
||||
from sklearn.metrics.pairwise import cosine_similarity as CosineSimilarity
|
||||
import numpy as np
|
||||
|
||||
sims = CosineSimilarity([avec], bvecs)
|
||||
tksim = self.token_similarity(atks, btkss)
|
||||
if np.sum(sims[0]) == 0:
|
||||
return np.array(tksim), tksim, sims[0]
|
||||
return np.array(sims[0]) * vtweight + np.array(tksim) * tkweight, tksim, sims[0]
|
||||
|
||||
def token_similarity(self, atks, btkss):
|
||||
def toDict(tks):
|
||||
if isinstance(tks, str):
|
||||
tks = tks.split()
|
||||
d = defaultdict(int)
|
||||
wts = self.tw.weights(tks, preprocess=False)
|
||||
for i, (t, c) in enumerate(wts):
|
||||
d[t] += c
|
||||
return d
|
||||
|
||||
atks = toDict(atks)
|
||||
btkss = [toDict(tks) for tks in btkss]
|
||||
return [self.similarity(atks, btks) for btks in btkss]
|
||||
|
||||
def similarity(self, qtwt, dtwt):
|
||||
if isinstance(dtwt, type("")):
|
||||
dtwt = {t: w for t, w in self.tw.weights(self.tw.split(dtwt), preprocess=False)}
|
||||
if isinstance(qtwt, type("")):
|
||||
qtwt = {t: w for t, w in self.tw.weights(self.tw.split(qtwt), preprocess=False)}
|
||||
s = 1e-9
|
||||
for k, v in qtwt.items():
|
||||
if k in dtwt:
|
||||
s += v #* dtwt[k]
|
||||
q = 1e-9
|
||||
for k, v in qtwt.items():
|
||||
q += v #* v
|
||||
return s/q #math.sqrt(3. * (s / q / math.log10( len(dtwt.keys()) + 512 )))
|
||||
|
||||
def paragraph(self, content_tks: str, keywords: list = [], keywords_topn=30):
|
||||
if isinstance(content_tks, str):
|
||||
content_tks = [c.strip() for c in content_tks.strip() if c.strip()]
|
||||
tks_w = self.tw.weights(content_tks, preprocess=False)
|
||||
|
||||
keywords = [f'"{k.strip()}"' for k in keywords]
|
||||
for tk, w in sorted(tks_w, key=lambda x: x[1] * -1)[:keywords_topn]:
|
||||
tk_syns = self.syn.lookup(tk)
|
||||
tk_syns = [FulltextQueryer.subSpecialChar(s) for s in tk_syns]
|
||||
tk_syns = [rag_tokenizer.fine_grained_tokenize(s) for s in tk_syns if s]
|
||||
tk_syns = [f"\"{s}\"" if s.find(" ") > 0 else s for s in tk_syns]
|
||||
tk = FulltextQueryer.subSpecialChar(tk)
|
||||
if tk.find(" ") > 0:
|
||||
tk = '"%s"' % tk
|
||||
if tk_syns:
|
||||
tk = f"({tk} OR (%s)^0.2)" % " ".join(tk_syns)
|
||||
if tk:
|
||||
keywords.append(f"{tk}^{w}")
|
||||
|
||||
return MatchTextExpr(self.query_fields, " ".join(keywords), 100,
|
||||
{"minimum_should_match": min(3, len(keywords) // 10)})
|
||||
499
api/app/core/rag/nlp/rag_tokenizer.py
Normal file
499
api/app/core/rag/nlp/rag_tokenizer.py
Normal file
@@ -0,0 +1,499 @@
|
||||
import logging
|
||||
import copy
|
||||
import datrie
|
||||
import math
|
||||
import os
|
||||
import re
|
||||
import string
|
||||
import sys
|
||||
from hanziconv import HanziConv
|
||||
from nltk import word_tokenize
|
||||
from nltk.stem import PorterStemmer, WordNetLemmatizer
|
||||
from app.core.rag.common.file_utils import get_project_base_directory
|
||||
|
||||
|
||||
class RagTokenizer:
|
||||
def key_(self, line):
|
||||
return str(line.lower().encode("utf-8"))[2:-1]
|
||||
|
||||
def rkey_(self, line):
|
||||
return str(("DD" + (line[::-1].lower())).encode("utf-8"))[2:-1]
|
||||
|
||||
def loadDict_(self, fnm):
|
||||
logging.info(f"[HUQIE]:Build trie from {fnm}")
|
||||
try:
|
||||
of = open(fnm, "r", encoding='utf-8')
|
||||
while True:
|
||||
line = of.readline()
|
||||
if not line:
|
||||
break
|
||||
line = re.sub(r"[\r\n]+", "", line)
|
||||
line = re.split(r"[ \t]", line)
|
||||
k = self.key_(line[0])
|
||||
F = int(math.log(float(line[1]) / self.DENOMINATOR) + .5)
|
||||
if k not in self.trie_ or self.trie_[k][0] < F:
|
||||
self.trie_[self.key_(line[0])] = (F, line[2])
|
||||
self.trie_[self.rkey_(line[0])] = 1
|
||||
|
||||
trie_file_name = os.path.join(get_project_base_directory(), "res", "huqie") + ".txt.trie"
|
||||
logging.info(f"[HUQIE]:Build trie cache to {trie_file_name}")
|
||||
self.trie_.save(trie_file_name)
|
||||
of.close()
|
||||
except Exception:
|
||||
logging.exception(f"[HUQIE]:Build trie {fnm} failed")
|
||||
|
||||
def __init__(self, debug=False):
|
||||
self.DEBUG = debug
|
||||
self.DENOMINATOR = 1000000
|
||||
|
||||
self.stemmer = PorterStemmer()
|
||||
self.lemmatizer = WordNetLemmatizer()
|
||||
|
||||
self.SPLIT_CHAR = r"([ ,\.<>/?;:'\[\]\\`!@#$%^&*\(\)\{\}\|_+=《》,。?、;‘’:“”【】~!¥%……()——-]+|[a-zA-Z0-9,\.-]+)"
|
||||
|
||||
trie_file_name = os.path.join(get_project_base_directory(), "res", "huqie") + ".txt.trie"
|
||||
# check if trie file existence
|
||||
if os.path.exists(trie_file_name):
|
||||
try:
|
||||
# load trie from file
|
||||
self.trie_ = datrie.Trie.load(trie_file_name)
|
||||
return
|
||||
except Exception:
|
||||
# fail to load trie from file, build default trie
|
||||
logging.exception(f"[HUQIE]:Fail to load trie file {trie_file_name}, build the default trie file")
|
||||
self.trie_ = datrie.Trie(string.printable)
|
||||
else:
|
||||
# file not exist, build default trie
|
||||
logging.info(f"[HUQIE]:Trie file {trie_file_name} not found, build the default trie file")
|
||||
self.trie_ = datrie.Trie(string.printable)
|
||||
|
||||
# load data from dict file and save to trie file
|
||||
self.loadDict_(os.path.join(get_project_base_directory(), "app/core/rag/res", "huqie") + ".txt")
|
||||
|
||||
def loadUserDict(self, fnm):
|
||||
try:
|
||||
self.trie_ = datrie.Trie.load(fnm + ".trie")
|
||||
return
|
||||
except Exception:
|
||||
self.trie_ = datrie.Trie(string.printable)
|
||||
self.loadDict_(fnm)
|
||||
|
||||
def addUserDict(self, fnm):
|
||||
self.loadDict_(fnm)
|
||||
|
||||
def _strQ2B(self, ustring):
|
||||
"""Convert full-width characters to half-width characters"""
|
||||
rstring = ""
|
||||
for uchar in ustring:
|
||||
inside_code = ord(uchar)
|
||||
if inside_code == 0x3000:
|
||||
inside_code = 0x0020
|
||||
else:
|
||||
inside_code -= 0xfee0
|
||||
if inside_code < 0x0020 or inside_code > 0x7e: # After the conversion, if it's not a half-width character, return the original character.
|
||||
rstring += uchar
|
||||
else:
|
||||
rstring += chr(inside_code)
|
||||
return rstring
|
||||
|
||||
def _tradi2simp(self, line):
|
||||
return HanziConv.toSimplified(line)
|
||||
|
||||
def dfs_(self, chars, s, preTks, tkslist, _depth=0, _memo=None):
|
||||
if _memo is None:
|
||||
_memo = {}
|
||||
MAX_DEPTH = 10
|
||||
if _depth > MAX_DEPTH:
|
||||
if s < len(chars):
|
||||
copy_pretks = copy.deepcopy(preTks)
|
||||
remaining = "".join(chars[s:])
|
||||
copy_pretks.append((remaining, (-12, '')))
|
||||
tkslist.append(copy_pretks)
|
||||
return s
|
||||
|
||||
state_key = (s, tuple(tk[0] for tk in preTks)) if preTks else (s, None)
|
||||
if state_key in _memo:
|
||||
return _memo[state_key]
|
||||
|
||||
res = s
|
||||
if s >= len(chars):
|
||||
tkslist.append(preTks)
|
||||
_memo[state_key] = s
|
||||
return s
|
||||
if s < len(chars) - 4:
|
||||
is_repetitive = True
|
||||
char_to_check = chars[s]
|
||||
for i in range(1, 5):
|
||||
if s + i >= len(chars) or chars[s + i] != char_to_check:
|
||||
is_repetitive = False
|
||||
break
|
||||
if is_repetitive:
|
||||
end = s
|
||||
while end < len(chars) and chars[end] == char_to_check:
|
||||
end += 1
|
||||
mid = s + min(10, end - s)
|
||||
t = "".join(chars[s:mid])
|
||||
k = self.key_(t)
|
||||
copy_pretks = copy.deepcopy(preTks)
|
||||
if k in self.trie_:
|
||||
copy_pretks.append((t, self.trie_[k]))
|
||||
else:
|
||||
copy_pretks.append((t, (-12, '')))
|
||||
next_res = self.dfs_(chars, mid, copy_pretks, tkslist, _depth + 1, _memo)
|
||||
res = max(res, next_res)
|
||||
_memo[state_key] = res
|
||||
return res
|
||||
|
||||
S = s + 1
|
||||
if s + 2 <= len(chars):
|
||||
t1 = "".join(chars[s:s + 1])
|
||||
t2 = "".join(chars[s:s + 2])
|
||||
if self.trie_.has_keys_with_prefix(self.key_(t1)) and not self.trie_.has_keys_with_prefix(self.key_(t2)):
|
||||
S = s + 2
|
||||
if len(preTks) > 2 and len(preTks[-1][0]) == 1 and len(preTks[-2][0]) == 1 and len(preTks[-3][0]) == 1:
|
||||
t1 = preTks[-1][0] + "".join(chars[s:s + 1])
|
||||
if self.trie_.has_keys_with_prefix(self.key_(t1)):
|
||||
S = s + 2
|
||||
|
||||
for e in range(S, len(chars) + 1):
|
||||
t = "".join(chars[s:e])
|
||||
k = self.key_(t)
|
||||
if e > s + 1 and not self.trie_.has_keys_with_prefix(k):
|
||||
break
|
||||
if k in self.trie_:
|
||||
pretks = copy.deepcopy(preTks)
|
||||
pretks.append((t, self.trie_[k]))
|
||||
res = max(res, self.dfs_(chars, e, pretks, tkslist, _depth + 1, _memo))
|
||||
|
||||
if res > s:
|
||||
_memo[state_key] = res
|
||||
return res
|
||||
|
||||
t = "".join(chars[s:s + 1])
|
||||
k = self.key_(t)
|
||||
copy_pretks = copy.deepcopy(preTks)
|
||||
if k in self.trie_:
|
||||
copy_pretks.append((t, self.trie_[k]))
|
||||
else:
|
||||
copy_pretks.append((t, (-12, '')))
|
||||
result = self.dfs_(chars, s + 1, copy_pretks, tkslist, _depth + 1, _memo)
|
||||
_memo[state_key] = result
|
||||
return result
|
||||
|
||||
def freq(self, tk):
|
||||
k = self.key_(tk)
|
||||
if k not in self.trie_:
|
||||
return 0
|
||||
return int(math.exp(self.trie_[k][0]) * self.DENOMINATOR + 0.5)
|
||||
|
||||
def tag(self, tk):
|
||||
k = self.key_(tk)
|
||||
if k not in self.trie_:
|
||||
return ""
|
||||
return self.trie_[k][1]
|
||||
|
||||
def score_(self, tfts):
|
||||
B = 30
|
||||
F, L, tks = 0, 0, []
|
||||
for tk, (freq, tag) in tfts:
|
||||
F += freq
|
||||
L += 0 if len(tk) < 2 else 1
|
||||
tks.append(tk)
|
||||
#F /= len(tks)
|
||||
L /= len(tks)
|
||||
logging.debug("[SC] {} {} {} {} {}".format(tks, len(tks), L, F, B / len(tks) + L + F))
|
||||
return tks, B / len(tks) + L + F
|
||||
|
||||
def sortTks_(self, tkslist):
|
||||
res = []
|
||||
for tfts in tkslist:
|
||||
tks, s = self.score_(tfts)
|
||||
res.append((tks, s))
|
||||
return sorted(res, key=lambda x: x[1], reverse=True)
|
||||
|
||||
def merge_(self, tks):
|
||||
# if split chars is part of token
|
||||
res = []
|
||||
tks = re.sub(r"[ ]+", " ", tks).split()
|
||||
s = 0
|
||||
while True:
|
||||
if s >= len(tks):
|
||||
break
|
||||
E = s + 1
|
||||
for e in range(s + 2, min(len(tks) + 2, s + 6)):
|
||||
tk = "".join(tks[s:e])
|
||||
if re.search(self.SPLIT_CHAR, tk) and self.freq(tk):
|
||||
E = e
|
||||
res.append("".join(tks[s:E]))
|
||||
s = E
|
||||
|
||||
return " ".join(res)
|
||||
|
||||
def maxForward_(self, line):
|
||||
res = []
|
||||
s = 0
|
||||
while s < len(line):
|
||||
e = s + 1
|
||||
t = line[s:e]
|
||||
while e < len(line) and self.trie_.has_keys_with_prefix(
|
||||
self.key_(t)):
|
||||
e += 1
|
||||
t = line[s:e]
|
||||
|
||||
while e - 1 > s and self.key_(t) not in self.trie_:
|
||||
e -= 1
|
||||
t = line[s:e]
|
||||
|
||||
if self.key_(t) in self.trie_:
|
||||
res.append((t, self.trie_[self.key_(t)]))
|
||||
else:
|
||||
res.append((t, (0, '')))
|
||||
|
||||
s = e
|
||||
|
||||
return self.score_(res)
|
||||
|
||||
def maxBackward_(self, line):
|
||||
res = []
|
||||
s = len(line) - 1
|
||||
while s >= 0:
|
||||
e = s + 1
|
||||
t = line[s:e]
|
||||
while s > 0 and self.trie_.has_keys_with_prefix(self.rkey_(t)):
|
||||
s -= 1
|
||||
t = line[s:e]
|
||||
|
||||
while s + 1 < e and self.key_(t) not in self.trie_:
|
||||
s += 1
|
||||
t = line[s:e]
|
||||
|
||||
if self.key_(t) in self.trie_:
|
||||
res.append((t, self.trie_[self.key_(t)]))
|
||||
else:
|
||||
res.append((t, (0, '')))
|
||||
|
||||
s -= 1
|
||||
|
||||
return self.score_(res[::-1])
|
||||
|
||||
def english_normalize_(self, tks):
|
||||
return [self.stemmer.stem(self.lemmatizer.lemmatize(t)) if re.match(r"[a-zA-Z_-]+$", t) else t for t in tks]
|
||||
|
||||
def _split_by_lang(self, line):
|
||||
txt_lang_pairs = []
|
||||
arr = re.split(self.SPLIT_CHAR, line)
|
||||
for a in arr:
|
||||
if not a:
|
||||
continue
|
||||
s = 0
|
||||
e = s + 1
|
||||
zh = is_chinese(a[s])
|
||||
while e < len(a):
|
||||
_zh = is_chinese(a[e])
|
||||
if _zh == zh:
|
||||
e += 1
|
||||
continue
|
||||
txt_lang_pairs.append((a[s: e], zh))
|
||||
s = e
|
||||
e = s + 1
|
||||
zh = _zh
|
||||
if s >= len(a):
|
||||
continue
|
||||
txt_lang_pairs.append((a[s: e], zh))
|
||||
return txt_lang_pairs
|
||||
|
||||
def tokenize(self, line):
|
||||
line = re.sub(r"\W+", " ", line)
|
||||
line = self._strQ2B(line).lower()
|
||||
line = self._tradi2simp(line)
|
||||
|
||||
arr = self._split_by_lang(line)
|
||||
res = []
|
||||
for L,lang in arr:
|
||||
if not lang:
|
||||
res.extend([self.stemmer.stem(self.lemmatizer.lemmatize(t)) for t in word_tokenize(L)])
|
||||
continue
|
||||
if len(L) < 2 or re.match(
|
||||
r"[a-z\.-]+$", L) or re.match(r"[0-9\.-]+$", L):
|
||||
res.append(L)
|
||||
continue
|
||||
|
||||
# use maxforward for the first time
|
||||
tks, s = self.maxForward_(L)
|
||||
tks1, s1 = self.maxBackward_(L)
|
||||
if self.DEBUG:
|
||||
logging.debug("[FW] {} {}".format(tks, s))
|
||||
logging.debug("[BW] {} {}".format(tks1, s1))
|
||||
|
||||
i, j, _i, _j = 0, 0, 0, 0
|
||||
same = 0
|
||||
while i + same < len(tks1) and j + same < len(tks) and tks1[i + same] == tks[j + same]:
|
||||
same += 1
|
||||
if same > 0:
|
||||
res.append(" ".join(tks[j: j + same]))
|
||||
_i = i + same
|
||||
_j = j + same
|
||||
j = _j + 1
|
||||
i = _i + 1
|
||||
|
||||
while i < len(tks1) and j < len(tks):
|
||||
tk1, tk = "".join(tks1[_i:i]), "".join(tks[_j:j])
|
||||
if tk1 != tk:
|
||||
if len(tk1) > len(tk):
|
||||
j += 1
|
||||
else:
|
||||
i += 1
|
||||
continue
|
||||
|
||||
if tks1[i] != tks[j]:
|
||||
i += 1
|
||||
j += 1
|
||||
continue
|
||||
# backward tokens from_i to i are different from forward tokens from _j to j.
|
||||
tkslist = []
|
||||
self.dfs_("".join(tks[_j:j]), 0, [], tkslist)
|
||||
res.append(" ".join(self.sortTks_(tkslist)[0][0]))
|
||||
|
||||
same = 1
|
||||
while i + same < len(tks1) and j + same < len(tks) and tks1[i + same] == tks[j + same]:
|
||||
same += 1
|
||||
res.append(" ".join(tks[j: j + same]))
|
||||
_i = i + same
|
||||
_j = j + same
|
||||
j = _j + 1
|
||||
i = _i + 1
|
||||
|
||||
if _i < len(tks1):
|
||||
assert _j < len(tks)
|
||||
assert "".join(tks1[_i:]) == "".join(tks[_j:])
|
||||
tkslist = []
|
||||
self.dfs_("".join(tks[_j:]), 0, [], tkslist)
|
||||
res.append(" ".join(self.sortTks_(tkslist)[0][0]))
|
||||
|
||||
res = " ".join(res)
|
||||
logging.debug("[TKS] {}".format(self.merge_(res)))
|
||||
return self.merge_(res)
|
||||
|
||||
def fine_grained_tokenize(self, tks):
|
||||
tks = tks.split()
|
||||
zh_num = len([1 for c in tks if c and is_chinese(c[0])])
|
||||
if zh_num < len(tks) * 0.2:
|
||||
res = []
|
||||
for tk in tks:
|
||||
res.extend(tk.split("/"))
|
||||
return " ".join(res)
|
||||
|
||||
res = []
|
||||
for tk in tks:
|
||||
if len(tk) < 3 or re.match(r"[0-9,\.-]+$", tk):
|
||||
res.append(tk)
|
||||
continue
|
||||
tkslist = []
|
||||
if len(tk) > 10:
|
||||
tkslist.append(tk)
|
||||
else:
|
||||
self.dfs_(tk, 0, [], tkslist)
|
||||
if len(tkslist) < 2:
|
||||
res.append(tk)
|
||||
continue
|
||||
stk = self.sortTks_(tkslist)[1][0]
|
||||
if len(stk) == len(tk):
|
||||
stk = tk
|
||||
else:
|
||||
if re.match(r"[a-z\.-]+$", tk):
|
||||
for t in stk:
|
||||
if len(t) < 3:
|
||||
stk = tk
|
||||
break
|
||||
else:
|
||||
stk = " ".join(stk)
|
||||
else:
|
||||
stk = " ".join(stk)
|
||||
|
||||
res.append(stk)
|
||||
|
||||
return " ".join(self.english_normalize_(res))
|
||||
|
||||
|
||||
def is_chinese(s):
|
||||
if s >= u'\u4e00' and s <= u'\u9fa5':
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
def is_number(s):
|
||||
if s >= u'\u0030' and s <= u'\u0039':
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
def is_alphabet(s):
|
||||
if (s >= u'\u0041' and s <= u'\u005a') or (
|
||||
s >= u'\u0061' and s <= u'\u007a'):
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
def naiveQie(txt):
|
||||
tks = []
|
||||
for t in txt.split():
|
||||
if tks and re.match(r".*[a-zA-Z]$", tks[-1]
|
||||
) and re.match(r".*[a-zA-Z]$", t):
|
||||
tks.append(" ")
|
||||
tks.append(t)
|
||||
return tks
|
||||
|
||||
|
||||
tokenizer = RagTokenizer()
|
||||
tokenize = tokenizer.tokenize
|
||||
fine_grained_tokenize = tokenizer.fine_grained_tokenize
|
||||
tag = tokenizer.tag
|
||||
freq = tokenizer.freq
|
||||
loadUserDict = tokenizer.loadUserDict
|
||||
addUserDict = tokenizer.addUserDict
|
||||
tradi2simp = tokenizer._tradi2simp
|
||||
strQ2B = tokenizer._strQ2B
|
||||
|
||||
if __name__ == '__main__':
|
||||
tknzr = RagTokenizer(debug=True)
|
||||
# huqie.addUserDict("/tmp/tmp.new.tks.dict")
|
||||
tks = tknzr.tokenize(
|
||||
"哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈")
|
||||
logging.info(tknzr.fine_grained_tokenize(tks))
|
||||
tks = tknzr.tokenize(
|
||||
"公开征求意见稿提出,境外投资者可使用自有人民币或外汇投资。使用外汇投资的,可通过债券持有人在香港人民币业务清算行及香港地区经批准可进入境内银行间外汇市场进行交易的境外人民币业务参加行(以下统称香港结算行)办理外汇资金兑换。香港结算行由此所产生的头寸可到境内银行间外汇市场平盘。使用外汇投资的,在其投资的债券到期或卖出后,原则上应兑换回外汇。")
|
||||
logging.info(tknzr.fine_grained_tokenize(tks))
|
||||
tks = tknzr.tokenize(
|
||||
"多校划片就是一个小区对应多个小学初中,让买了学区房的家庭也不确定到底能上哪个学校。目的是通过这种方式为学区房降温,把就近入学落到实处。南京市长江大桥")
|
||||
logging.info(tknzr.fine_grained_tokenize(tks))
|
||||
tks = tknzr.tokenize(
|
||||
"实际上当时他们已经将业务中心偏移到安全部门和针对政府企业的部门 Scripts are compiled and cached aaaaaaaaa")
|
||||
logging.info(tknzr.fine_grained_tokenize(tks))
|
||||
tks = tknzr.tokenize("虽然我不怎么玩")
|
||||
logging.info(tknzr.fine_grained_tokenize(tks))
|
||||
tks = tknzr.tokenize("蓝月亮如何在外资夹击中生存,那是全宇宙最有意思的")
|
||||
logging.info(tknzr.fine_grained_tokenize(tks))
|
||||
tks = tknzr.tokenize(
|
||||
"涡轮增压发动机num最大功率,不像别的共享买车锁电子化的手段,我们接过来是否有意义,黄黄爱美食,不过,今天阿奇要讲到的这家农贸市场,说实话,还真蛮有特色的!不仅环境好,还打出了")
|
||||
logging.info(tknzr.fine_grained_tokenize(tks))
|
||||
tks = tknzr.tokenize("这周日你去吗?这周日你有空吗?")
|
||||
logging.info(tknzr.fine_grained_tokenize(tks))
|
||||
tks = tknzr.tokenize("Unity3D开发经验 测试开发工程师 c++双11双11 985 211 ")
|
||||
logging.info(tknzr.fine_grained_tokenize(tks))
|
||||
tks = tknzr.tokenize(
|
||||
"数据分析项目经理|数据分析挖掘|数据分析方向|商品数据分析|搜索数据分析 sql python hive tableau Cocos2d-")
|
||||
logging.info(tknzr.fine_grained_tokenize(tks))
|
||||
if len(sys.argv) < 2:
|
||||
sys.exit()
|
||||
tknzr.DEBUG = False
|
||||
tknzr.loadUserDict(sys.argv[1])
|
||||
of = open(sys.argv[2], "r")
|
||||
while True:
|
||||
line = of.readline()
|
||||
if not line:
|
||||
break
|
||||
logging.info(tknzr.tokenize(line))
|
||||
of.close()
|
||||
192
api/app/core/rag/nlp/search.py
Normal file
192
api/app/core/rag/nlp/search.py
Normal file
@@ -0,0 +1,192 @@
|
||||
import uuid
|
||||
from typing import Dict, List, Any
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from langchain_core.documents import Document
|
||||
from app.db import get_db
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.core.models import RedBearLLM, RedBearRerank
|
||||
from app.models.models_model import ModelApiKey
|
||||
from app.models import knowledge_model
|
||||
from app.core.rag.models.chunk import DocumentChunk
|
||||
from app.repositories import knowledge_repository, knowledgeshare_repository
|
||||
from app.services.model_service import ModelConfigService
|
||||
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
|
||||
|
||||
|
||||
def knowledge_retrieval(
|
||||
query: str,
|
||||
config: Dict[str, Any],
|
||||
user_ids: List[str] = None,
|
||||
) -> list[DocumentChunk]:
|
||||
"""
|
||||
Knowledge retrieval with multiple knowledge bases and reranking
|
||||
|
||||
Args:
|
||||
query: Search query string
|
||||
config: Configuration dictionary containing:
|
||||
- knowledge_bases: List of knowledge base configs with:
|
||||
- kb_id: Knowledge base ID
|
||||
- similarity_threshold: float
|
||||
- vector_similarity_weight: float
|
||||
- top_k: int
|
||||
- retrieve_type: "participle" or "semantic" or "hybrid"
|
||||
- merge_strategy: "weight" or other strategies
|
||||
- reranker_id: UUID of the reranker to use
|
||||
- reranker_top_k: int
|
||||
|
||||
Returns:
|
||||
Rearranged document block list (in descending order of relevance)
|
||||
"""
|
||||
db = next(get_db()) # Manually call the generator
|
||||
try:
|
||||
# parse configuration
|
||||
knowledge_bases = config.get("knowledge_bases", [])
|
||||
merge_strategy = config.get("merge_strategy", "weight")
|
||||
reranker_id = config.get("reranker_id")
|
||||
reranker_top_k = config.get("reranker_top_k", 1024)
|
||||
|
||||
file_names_filter=[]
|
||||
if user_ids:
|
||||
file_names_filter.extend([f"{user_id}.txt" for user_id in user_ids])
|
||||
|
||||
if not knowledge_bases:
|
||||
return []
|
||||
|
||||
all_results = []
|
||||
# Search each knowledge base
|
||||
for kb_config in knowledge_bases:
|
||||
kb_id = kb_config["kb_id"]
|
||||
try:
|
||||
# Check whether the knowledge base exists and is available
|
||||
db_knowledge = knowledge_repository.get_knowledge_by_id(db, knowledge_id=kb_id)
|
||||
if db_knowledge and db_knowledge.chunk_num > 0 and db_knowledge.status == 1:
|
||||
# Process shared knowledge base
|
||||
if db_knowledge.permission_id.lower() == knowledge_model.PermissionType.Share:
|
||||
knowledgeshare = knowledgeshare_repository.get_knowledgeshare_by_id(db=db,
|
||||
knowledgeshare_id=db_knowledge.id)
|
||||
if knowledgeshare:
|
||||
db_knowledge = knowledge_repository.get_knowledge_by_id(db,
|
||||
knowledge_id=knowledgeshare.source_kb_id)
|
||||
if not (db_knowledge and db_knowledge.chunk_num > 0 and db_knowledge.status == 1):
|
||||
continue
|
||||
else:
|
||||
continue
|
||||
|
||||
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
|
||||
# Retrieve according to the configured retrieval type
|
||||
match kb_config["retrieve_type"]:
|
||||
case "participle":
|
||||
rs = vector_service.search_by_full_text(
|
||||
query=query,
|
||||
top_k=kb_config["top_k"],
|
||||
score_threshold=kb_config["similarity_threshold"],
|
||||
file_names_filter=file_names_filter
|
||||
)
|
||||
case "semantic":
|
||||
rs = vector_service.search_by_vector(
|
||||
query=query,
|
||||
top_k=kb_config["top_k"],
|
||||
score_threshold=kb_config["vector_similarity_weight"],
|
||||
file_names_filter=file_names_filter
|
||||
)
|
||||
case _: # hybrid
|
||||
rs1 = vector_service.search_by_vector(
|
||||
query=query,
|
||||
top_k=kb_config["top_k"],
|
||||
score_threshold=kb_config["vector_similarity_weight"],
|
||||
file_names_filter=file_names_filter
|
||||
)
|
||||
rs2 = vector_service.search_by_full_text(
|
||||
query=query,
|
||||
top_k=kb_config["top_k"],
|
||||
score_threshold=kb_config["similarity_threshold"],
|
||||
file_names_filter=file_names_filter
|
||||
)
|
||||
|
||||
# Deduplication of merge results
|
||||
seen_ids = set()
|
||||
unique_rs = []
|
||||
for doc in rs1 + rs2:
|
||||
if doc.metadata["doc_id"] not in seen_ids:
|
||||
seen_ids.add(doc.metadata["doc_id"])
|
||||
unique_rs.append(doc)
|
||||
rs = unique_rs
|
||||
|
||||
all_results.extend(rs)
|
||||
except Exception as e:
|
||||
# Failure of retrieval in a single knowledge base does not affect other knowledge bases
|
||||
print(f"retrieval knowledge({kb_id}) failed: {str(e)}")
|
||||
continue
|
||||
|
||||
# Use the specified reranker for re-ranking
|
||||
if reranker_id:
|
||||
return rerank(db=db, reranker_id=reranker_id, query=query, docs=all_results, top_k=reranker_top_k)
|
||||
return all_results
|
||||
|
||||
except Exception as e:
|
||||
print(f"retrieval knowledge failed: {str(e)}")
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
def rerank(db: Session, reranker_id: uuid, query: str, docs: list[DocumentChunk], top_k: int) -> list[DocumentChunk]:
|
||||
"""
|
||||
Reorder the list of document blocks and return the top_k results most relevant to the query
|
||||
Args:
|
||||
reranker_id: reranker model id
|
||||
query: query string
|
||||
docs: List of document blocks to be rearranged
|
||||
top_k: Number of top-level documents returned
|
||||
|
||||
Returns:
|
||||
Rearranged document block list (in descending order of relevance)
|
||||
|
||||
Raises:
|
||||
ValueError: If the input document list is empty or top_k is invalid
|
||||
"""
|
||||
# 参数校验
|
||||
if not reranker_id:
|
||||
raise ValueError("reranker_id be empty")
|
||||
if not docs:
|
||||
raise ValueError("retrieval chunks be empty")
|
||||
if top_k <= 0:
|
||||
raise ValueError("top_k must be a positive integer")
|
||||
try:
|
||||
# initialize reranker
|
||||
config = ModelConfigService.get_model_by_id(db=db, model_id=reranker_id)
|
||||
apiConfig: ModelApiKey = config.api_keys[0]
|
||||
reranker = RedBearRerank(RedBearModelConfig(
|
||||
model_name=apiConfig.model_name,
|
||||
provider=apiConfig.provider,
|
||||
api_key=apiConfig.api_key,
|
||||
base_url=apiConfig.api_base
|
||||
))
|
||||
# Convert to LangChain Document object
|
||||
documents = [
|
||||
Document(
|
||||
page_content=doc.page_content, # Ensure that DocumentChunk possesses this attribute
|
||||
metadata=doc.metadata or {} # Deal with possible None metadata
|
||||
)
|
||||
for doc in docs
|
||||
]
|
||||
|
||||
# Perform reordering (compress_documents will automatically handle relevance scores and indexing)
|
||||
reranked_docs = list(reranker.compress_documents(documents, query))
|
||||
print(reranked_docs)
|
||||
|
||||
# Sort in descending order based on relevance score
|
||||
reranked_docs.sort(
|
||||
key=lambda x: x.metadata.get("relevance_score", 0),
|
||||
reverse=True
|
||||
)
|
||||
# Convert back to a list of DocumentChunk, and save the relevance_score to metadata["score"]
|
||||
result = []
|
||||
for item in reranked_docs[:top_k]:
|
||||
for doc in docs:
|
||||
if doc.page_content == item.page_content:
|
||||
doc.metadata["score"] = item.metadata["relevance_score"]
|
||||
result.append(doc)
|
||||
return result
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to rerank documents: {str(e)}") from e
|
||||
126
api/app/core/rag/nlp/surname.py
Normal file
126
api/app/core/rag/nlp/surname.py
Normal file
@@ -0,0 +1,126 @@
|
||||
m = set(["赵","钱","孙","李",
|
||||
"周","吴","郑","王",
|
||||
"冯","陈","褚","卫",
|
||||
"蒋","沈","韩","杨",
|
||||
"朱","秦","尤","许",
|
||||
"何","吕","施","张",
|
||||
"孔","曹","严","华",
|
||||
"金","魏","陶","姜",
|
||||
"戚","谢","邹","喻",
|
||||
"柏","水","窦","章",
|
||||
"云","苏","潘","葛",
|
||||
"奚","范","彭","郎",
|
||||
"鲁","韦","昌","马",
|
||||
"苗","凤","花","方",
|
||||
"俞","任","袁","柳",
|
||||
"酆","鲍","史","唐",
|
||||
"费","廉","岑","薛",
|
||||
"雷","贺","倪","汤",
|
||||
"滕","殷","罗","毕",
|
||||
"郝","邬","安","常",
|
||||
"乐","于","时","傅",
|
||||
"皮","卞","齐","康",
|
||||
"伍","余","元","卜",
|
||||
"顾","孟","平","黄",
|
||||
"和","穆","萧","尹",
|
||||
"姚","邵","湛","汪",
|
||||
"祁","毛","禹","狄",
|
||||
"米","贝","明","臧",
|
||||
"计","伏","成","戴",
|
||||
"谈","宋","茅","庞",
|
||||
"熊","纪","舒","屈",
|
||||
"项","祝","董","梁",
|
||||
"杜","阮","蓝","闵",
|
||||
"席","季","麻","强",
|
||||
"贾","路","娄","危",
|
||||
"江","童","颜","郭",
|
||||
"梅","盛","林","刁",
|
||||
"钟","徐","邱","骆",
|
||||
"高","夏","蔡","田",
|
||||
"樊","胡","凌","霍",
|
||||
"虞","万","支","柯",
|
||||
"昝","管","卢","莫",
|
||||
"经","房","裘","缪",
|
||||
"干","解","应","宗",
|
||||
"丁","宣","贲","邓",
|
||||
"郁","单","杭","洪",
|
||||
"包","诸","左","石",
|
||||
"崔","吉","钮","龚",
|
||||
"程","嵇","邢","滑",
|
||||
"裴","陆","荣","翁",
|
||||
"荀","羊","於","惠",
|
||||
"甄","曲","家","封",
|
||||
"芮","羿","储","靳",
|
||||
"汲","邴","糜","松",
|
||||
"井","段","富","巫",
|
||||
"乌","焦","巴","弓",
|
||||
"牧","隗","山","谷",
|
||||
"车","侯","宓","蓬",
|
||||
"全","郗","班","仰",
|
||||
"秋","仲","伊","宫",
|
||||
"宁","仇","栾","暴",
|
||||
"甘","钭","厉","戎",
|
||||
"祖","武","符","刘",
|
||||
"景","詹","束","龙",
|
||||
"叶","幸","司","韶",
|
||||
"郜","黎","蓟","薄",
|
||||
"印","宿","白","怀",
|
||||
"蒲","邰","从","鄂",
|
||||
"索","咸","籍","赖",
|
||||
"卓","蔺","屠","蒙",
|
||||
"池","乔","阴","鬱",
|
||||
"胥","能","苍","双",
|
||||
"闻","莘","党","翟",
|
||||
"谭","贡","劳","逄",
|
||||
"姬","申","扶","堵",
|
||||
"冉","宰","郦","雍",
|
||||
"郤","璩","桑","桂",
|
||||
"濮","牛","寿","通",
|
||||
"边","扈","燕","冀",
|
||||
"郏","浦","尚","农",
|
||||
"温","别","庄","晏",
|
||||
"柴","瞿","阎","充",
|
||||
"慕","连","茹","习",
|
||||
"宦","艾","鱼","容",
|
||||
"向","古","易","慎",
|
||||
"戈","廖","庾","终",
|
||||
"暨","居","衡","步",
|
||||
"都","耿","满","弘",
|
||||
"匡","国","文","寇",
|
||||
"广","禄","阙","东",
|
||||
"欧","殳","沃","利",
|
||||
"蔚","越","夔","隆",
|
||||
"师","巩","厍","聂",
|
||||
"晁","勾","敖","融",
|
||||
"冷","訾","辛","阚",
|
||||
"那","简","饶","空",
|
||||
"曾","母","沙","乜",
|
||||
"养","鞠","须","丰",
|
||||
"巢","关","蒯","相",
|
||||
"查","后","荆","红",
|
||||
"游","竺","权","逯",
|
||||
"盖","益","桓","公",
|
||||
"兰","原","乞","西","阿","肖","丑","位","曽","巨","德","代","圆","尉","仵","纳","仝","脱","丘","但","展","迪","付","覃","晗","特","隋","苑","奥","漆","谌","郄","练","扎","邝","渠","信","门","陳","化","原","密","泮","鹿","赫",
|
||||
"万俟","司马","上官","欧阳",
|
||||
"夏侯","诸葛","闻人","东方",
|
||||
"赫连","皇甫","尉迟","公羊",
|
||||
"澹台","公冶","宗政","濮阳",
|
||||
"淳于","单于","太叔","申屠",
|
||||
"公孙","仲孙","轩辕","令狐",
|
||||
"钟离","宇文","长孙","慕容",
|
||||
"鲜于","闾丘","司徒","司空",
|
||||
"亓官","司寇","仉督","子车",
|
||||
"颛孙","端木","巫马","公西",
|
||||
"漆雕","乐正","壤驷","公良",
|
||||
"拓跋","夹谷","宰父","榖梁",
|
||||
"晋","楚","闫","法","汝","鄢","涂","钦",
|
||||
"段干","百里","东郭","南门",
|
||||
"呼延","归","海","羊舌","微","生",
|
||||
"岳","帅","缑","亢","况","后","有","琴",
|
||||
"梁丘","左丘","东门","西门",
|
||||
"商","牟","佘","佴","伯","赏","南宫",
|
||||
"墨","哈","谯","笪","年","爱","阳","佟",
|
||||
"第五","言","福"])
|
||||
|
||||
def isit(n):return n.strip() in m
|
||||
|
||||
85
api/app/core/rag/nlp/synonym.py
Normal file
85
api/app/core/rag/nlp/synonym.py
Normal file
@@ -0,0 +1,85 @@
|
||||
import logging
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
import re
|
||||
from nltk.corpus import wordnet
|
||||
from app.core.rag.common.file_utils import get_project_base_directory
|
||||
|
||||
|
||||
class Dealer:
|
||||
def __init__(self, redis=None):
|
||||
|
||||
self.lookup_num = 100000000
|
||||
self.load_tm = time.time() - 1000000
|
||||
self.dictionary = None
|
||||
path = os.path.join(get_project_base_directory(), "app/core/rag/res", "synonym.json")
|
||||
try:
|
||||
self.dictionary = json.load(open(path, 'r'))
|
||||
self.dictionary = { (k.lower() if isinstance(k, str) else k): v for k, v in self.dictionary.items() }
|
||||
except Exception:
|
||||
logging.warning("Missing synonym.json")
|
||||
self.dictionary = {}
|
||||
|
||||
if not redis:
|
||||
logging.warning(
|
||||
"Realtime synonym is disabled, since no redis connection.")
|
||||
if not len(self.dictionary.keys()):
|
||||
logging.warning("Fail to load synonym")
|
||||
|
||||
self.redis = redis
|
||||
self.load()
|
||||
|
||||
def load(self):
|
||||
if not self.redis:
|
||||
return
|
||||
|
||||
if self.lookup_num < 100:
|
||||
return
|
||||
tm = time.time()
|
||||
if tm - self.load_tm < 3600:
|
||||
return
|
||||
|
||||
self.load_tm = time.time()
|
||||
self.lookup_num = 0
|
||||
d = self.redis.get("kevin_synonyms")
|
||||
if not d:
|
||||
return
|
||||
try:
|
||||
d = json.loads(d)
|
||||
self.dictionary = d
|
||||
except Exception as e:
|
||||
logging.error("Fail to load synonym!" + str(e))
|
||||
|
||||
|
||||
def lookup(self, tk, topn=8):
|
||||
if not tk or not isinstance(tk, str):
|
||||
return []
|
||||
|
||||
# 1) Check the custom dictionary first (both keys and tk are already lowercase)
|
||||
self.lookup_num += 1
|
||||
self.load()
|
||||
key = re.sub(r"[ \t]+", " ", tk.strip())
|
||||
res = self.dictionary.get(key, [])
|
||||
if isinstance(res, str):
|
||||
res = [res]
|
||||
if res: # Found in dictionary → return directly
|
||||
return res[:topn]
|
||||
|
||||
# 2) If not found and tk is purely alphabetical → fallback to WordNet
|
||||
if re.fullmatch(r"[a-z]+", tk):
|
||||
wn_set = {
|
||||
re.sub("_", " ", syn.name().split(".")[0])
|
||||
for syn in wordnet.synsets(tk)
|
||||
}
|
||||
wn_set.discard(tk) # Remove the original token itself
|
||||
wn_res = [t for t in wn_set if t]
|
||||
return wn_res[:topn]
|
||||
|
||||
# 3) Nothing found in either source
|
||||
return []
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
dl = Dealer()
|
||||
print(dl.dictionary)
|
||||
228
api/app/core/rag/nlp/term_weight.py
Normal file
228
api/app/core/rag/nlp/term_weight.py
Normal file
@@ -0,0 +1,228 @@
|
||||
import logging
|
||||
import math
|
||||
import json
|
||||
import re
|
||||
import os
|
||||
import numpy as np
|
||||
from . import rag_tokenizer
|
||||
from app.core.rag.common.file_utils import get_project_base_directory
|
||||
|
||||
|
||||
class Dealer:
|
||||
def __init__(self):
|
||||
self.stop_words = set(["请问",
|
||||
"您",
|
||||
"你",
|
||||
"我",
|
||||
"他",
|
||||
"是",
|
||||
"的",
|
||||
"就",
|
||||
"有",
|
||||
"于",
|
||||
"及",
|
||||
"即",
|
||||
"在",
|
||||
"为",
|
||||
"最",
|
||||
"有",
|
||||
"从",
|
||||
"以",
|
||||
"了",
|
||||
"将",
|
||||
"与",
|
||||
"吗",
|
||||
"吧",
|
||||
"中",
|
||||
"#",
|
||||
"什么",
|
||||
"怎么",
|
||||
"哪个",
|
||||
"哪些",
|
||||
"啥",
|
||||
"相关"])
|
||||
|
||||
def load_dict(fnm):
|
||||
res = {}
|
||||
f = open(fnm, "r")
|
||||
while True:
|
||||
line = f.readline()
|
||||
if not line:
|
||||
break
|
||||
arr = line.replace("\n", "").split("\t")
|
||||
if len(arr) < 2:
|
||||
res[arr[0]] = 0
|
||||
else:
|
||||
res[arr[0]] = int(arr[1])
|
||||
|
||||
c = 0
|
||||
for _, v in res.items():
|
||||
c += v
|
||||
if c == 0:
|
||||
return set(res.keys())
|
||||
return res
|
||||
|
||||
fnm = os.path.join(get_project_base_directory(), "app/core/rag/res")
|
||||
self.ne, self.df = {}, {}
|
||||
try:
|
||||
self.ne = json.load(open(os.path.join(fnm, "ner.json"), "r"))
|
||||
except Exception:
|
||||
logging.warning("Load ner.json FAIL!")
|
||||
try:
|
||||
self.df = load_dict(os.path.join(fnm, "term.freq"))
|
||||
except Exception:
|
||||
logging.warning("Load term.freq FAIL!")
|
||||
|
||||
def pretoken(self, txt, num=False, stpwd=True):
|
||||
patt = [
|
||||
r"[~—\t @#%!<>,\.\?\":;'\{\}\[\]_=\(\)\|,。?》•●○↓《;‘’:“”【¥ 】…¥!、·()×`&\\/「」\\]"
|
||||
]
|
||||
rewt = [
|
||||
]
|
||||
for p, r in rewt:
|
||||
txt = re.sub(p, r, txt)
|
||||
|
||||
res = []
|
||||
for t in rag_tokenizer.tokenize(txt).split():
|
||||
tk = t
|
||||
if (stpwd and tk in self.stop_words) or (
|
||||
re.match(r"[0-9]$", tk) and not num):
|
||||
continue
|
||||
for p in patt:
|
||||
if re.match(p, t):
|
||||
tk = "#"
|
||||
break
|
||||
#tk = re.sub(r"([\+\\-])", r"\\\1", tk)
|
||||
if tk != "#" and tk:
|
||||
res.append(tk)
|
||||
return res
|
||||
|
||||
def tokenMerge(self, tks):
|
||||
def oneTerm(t): return len(t) == 1 or re.match(r"[0-9a-z]{1,2}$", t)
|
||||
|
||||
res, i = [], 0
|
||||
while i < len(tks):
|
||||
j = i
|
||||
if i == 0 and oneTerm(tks[i]) and len(
|
||||
tks) > 1 and (len(tks[i + 1]) > 1 and not re.match(r"[0-9a-zA-Z]", tks[i + 1])): # 多 工位
|
||||
res.append(" ".join(tks[0:2]))
|
||||
i = 2
|
||||
continue
|
||||
|
||||
while j < len(
|
||||
tks) and tks[j] and tks[j] not in self.stop_words and oneTerm(tks[j]):
|
||||
j += 1
|
||||
if j - i > 1:
|
||||
if j - i < 5:
|
||||
res.append(" ".join(tks[i:j]))
|
||||
i = j
|
||||
else:
|
||||
res.append(" ".join(tks[i:i + 2]))
|
||||
i = i + 2
|
||||
else:
|
||||
if len(tks[i]) > 0:
|
||||
res.append(tks[i])
|
||||
i += 1
|
||||
return [t for t in res if t]
|
||||
|
||||
def ner(self, t):
|
||||
if not self.ne:
|
||||
return ""
|
||||
res = self.ne.get(t, "")
|
||||
if res:
|
||||
return res
|
||||
|
||||
def split(self, txt):
|
||||
tks = []
|
||||
for t in re.sub(r"[ \t]+", " ", txt).split():
|
||||
if tks and re.match(r".*[a-zA-Z]$", tks[-1]) and \
|
||||
re.match(r".*[a-zA-Z]$", t) and tks and \
|
||||
self.ne.get(t, "") != "func" and self.ne.get(tks[-1], "") != "func":
|
||||
tks[-1] = tks[-1] + " " + t
|
||||
else:
|
||||
tks.append(t)
|
||||
return tks
|
||||
|
||||
def weights(self, tks, preprocess=True):
|
||||
num_pattern = re.compile(r"[0-9,.]{2,}$")
|
||||
short_letter_pattern = re.compile(r"[a-z]{1,2}$")
|
||||
num_space_pattern = re.compile(r"[0-9. -]{2,}$")
|
||||
letter_pattern = re.compile(r"[a-z. -]+$")
|
||||
|
||||
def ner(t):
|
||||
if num_pattern.match(t):
|
||||
return 2
|
||||
if short_letter_pattern.match(t):
|
||||
return 0.01
|
||||
if not self.ne or t not in self.ne:
|
||||
return 1
|
||||
m = {"toxic": 2, "func": 1, "corp": 3, "loca": 3, "sch": 3, "stock": 3,
|
||||
"firstnm": 1}
|
||||
return m[self.ne[t]]
|
||||
|
||||
def postag(t):
|
||||
t = rag_tokenizer.tag(t)
|
||||
if t in set(["r", "c", "d"]):
|
||||
return 0.3
|
||||
if t in set(["ns", "nt"]):
|
||||
return 3
|
||||
if t in set(["n"]):
|
||||
return 2
|
||||
if re.match(r"[0-9-]+", t):
|
||||
return 2
|
||||
return 1
|
||||
|
||||
def freq(t):
|
||||
if num_space_pattern.match(t):
|
||||
return 3
|
||||
s = rag_tokenizer.freq(t)
|
||||
if not s and letter_pattern.match(t):
|
||||
return 300
|
||||
if not s:
|
||||
s = 0
|
||||
|
||||
if not s and len(t) >= 4:
|
||||
s = [tt for tt in rag_tokenizer.fine_grained_tokenize(t).split() if len(tt) > 1]
|
||||
if len(s) > 1:
|
||||
s = np.min([freq(tt) for tt in s]) / 6.
|
||||
else:
|
||||
s = 0
|
||||
|
||||
return max(s, 10)
|
||||
|
||||
def df(t):
|
||||
if num_space_pattern.match(t):
|
||||
return 5
|
||||
if t in self.df:
|
||||
return self.df[t] + 3
|
||||
elif letter_pattern.match(t):
|
||||
return 300
|
||||
elif len(t) >= 4:
|
||||
s = [tt for tt in rag_tokenizer.fine_grained_tokenize(t).split() if len(tt) > 1]
|
||||
if len(s) > 1:
|
||||
return max(3, np.min([df(tt) for tt in s]) / 6.)
|
||||
|
||||
return 3
|
||||
|
||||
def idf(s, N): return math.log10(10 + ((N - s + 0.5) / (s + 0.5)))
|
||||
|
||||
tw = []
|
||||
if not preprocess:
|
||||
idf1 = np.array([idf(freq(t), 10000000) for t in tks])
|
||||
idf2 = np.array([idf(df(t), 1000000000) for t in tks])
|
||||
wts = (0.3 * idf1 + 0.7 * idf2) * \
|
||||
np.array([ner(t) * postag(t) for t in tks])
|
||||
wts = [s for s in wts]
|
||||
tw = list(zip(tks, wts))
|
||||
else:
|
||||
for tk in tks:
|
||||
tt = self.tokenMerge(self.pretoken(tk, True))
|
||||
idf1 = np.array([idf(freq(t), 10000000) for t in tt])
|
||||
idf2 = np.array([idf(df(t), 1000000000) for t in tt])
|
||||
wts = (0.3 * idf1 + 0.7 * idf2) * \
|
||||
np.array([ner(t) * postag(t) for t in tt])
|
||||
wts = [s for s in wts]
|
||||
tw.extend(zip(tt, wts))
|
||||
|
||||
S = np.sum([s for _, s in tw])
|
||||
return [(t, s / S) for t, s in tw]
|
||||
Reference in New Issue
Block a user