#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import copy
import random
import re
from collections import Counter

import roman_numbers as r
from cn2an import cn2an
from PIL import Image
from word2number import w2n

from rag.utils import num_tokens_from_string

from . import rag_tokenizer
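
# Candidate encodings, tried in order by find_codec() below.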
all_codecs = [
    'utf-8', 'gb2312', 'gbk', 'utf_16', 'ascii', 'big5', 'big5hkscs',
    'cp037', 'cp273', 'cp424', 'cp437',
    'cp500', 'cp720', 'cp737', 'cp775', 'cp850', 'cp852', 'cp855', 'cp856', 'cp857',
    'cp858', 'cp860', 'cp861', 'cp862', 'cp863', 'cp864', 'cp865', 'cp866', 'cp869',
    'cp874', 'cp875', 'cp932', 'cp949', 'cp950', 'cp1006', 'cp1026', 'cp1125',
    'cp1140', 'cp1250', 'cp1251', 'cp1252', 'cp1253', 'cp1254', 'cp1255', 'cp1256',
    'cp1257', 'cp1258', 'euc_jp', 'euc_jis_2004', 'euc_jisx0213', 'euc_kr',
    'gb2312', 'gb18030', 'hz', 'iso2022_jp', 'iso2022_jp_1', 'iso2022_jp_2',
    'iso2022_jp_2004', 'iso2022_jp_3', 'iso2022_jp_ext', 'iso2022_kr', 'latin_1',
    'iso8859_2', 'iso8859_3', 'iso8859_4', 'iso8859_5', 'iso8859_6', 'iso8859_7',
    'iso8859_8', 'iso8859_9', 'iso8859_10', 'iso8859_11', 'iso8859_13',
    'iso8859_14', 'iso8859_15', 'iso8859_16', 'johab', 'koi8_r', 'koi8_t', 'koi8_u',
    'kz1048', 'mac_cyrillic', 'mac_greek', 'mac_iceland', 'mac_latin2', 'mac_roman',
    'mac_turkish', 'ptcp154', 'shift_jis', 'shift_jis_2004', 'shift_jisx0213',
    'utf_32', 'utf_32_be', 'utf_32_le', 'utf_16_be', 'utf_16_le', 'utf_7'
]
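
# Brute-force encoding detection: return the first codec that decodes the
# blob (or its first 1KB) without error, falling back to utf-8.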
def find_codec(blob):
    global all_codecs
    for c in all_codecs:
        try:
            # a 1KB prefix is usually enough to rule a codec in or out cheaply
            blob[:1024].decode(c)
            return c
        except Exception:
            pass
        try:
            blob.decode(c)
            return c
        except Exception:
            pass

    return "utf-8"
QUESTION_PATTERN = [
    r"第([零一二三四五六七八九十百0-9]+)问",
    r"第([零一二三四五六七八九十百0-9]+)条",
    r"[\(（]([零一二三四五六七八九十百]+)[\)）]",
    r"第([0-9]+)问",
    r"第([0-9]+)条",
    r"([0-9]{1,2})[\. 、]",
    r"([零一二三四五六七八九十百]+)[ 、]",
    r"[\(（]([0-9]{1,2})[\)）]",
    r"QUESTION (ONE|TWO|THREE|FOUR|FIVE|SIX|SEVEN|EIGHT|NINE|TEN)",
    r"QUESTION (I+V?|VI*|XI|IX|X)",
    r"QUESTION ([0-9]+)",
]
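
# Decide whether a PDF text box opens a new question bullet.  The matched
# numbering is cross-checked against indentation (x0) statistics, the previous
# box, layout type and question-word heuristics to filter false positives.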
def has_qbullet(reg, box, last_box, last_index, last_bull, bull_x0_list):
    section, last_section = box['text'], last_box['text']
    q_reg = r'(\w|\W)*?(?:？|\?|\n|$)+'
    full_reg = reg + q_reg
    has_bull = re.match(full_reg, section)
    index_str = None
    if has_bull:
        if 'x0' not in last_box:
            last_box['x0'] = box['x0']
        if 'top' not in last_box:
            last_box['top'] = box['top']
        if last_bull and box['x0'] - last_box['x0'] > 10:
            return None, last_index
        if not last_bull and box['x0'] >= last_box['x0'] and box['top'] - last_box['top'] < 20:
            return None, last_index
        if bull_x0_list:
            avg_bull_x0 = sum(bull_x0_list) / len(bull_x0_list)
        else:
            avg_bull_x0 = box['x0']
        if box['x0'] - avg_bull_x0 > 10:
            return None, last_index
        index_str = has_bull.group(1)
        index = index_int(index_str)
        if last_section[-1] == ':' or last_section[-1] == '：':
            return None, last_index
        if not last_index or index >= last_index:
            bull_x0_list.append(box['x0'])
            return has_bull, index
        if section[-1] == '?' or section[-1] == '？':
            bull_x0_list.append(box['x0'])
            return has_bull, index
        if box['layout_type'] == 'title':
            bull_x0_list.append(box['x0'])
            return has_bull, index
        # strip the matched bullet prefix from the section text (str.lstrip
        # would strip a character set rather than the prefix)
        pure_section = section[re.match(reg, section).end():].lower()
        ask_reg = r'(what|when|where|how|why|which|who|whose|为什么|为啥|哪)'
        if re.match(ask_reg, pure_section):
            bull_x0_list.append(box['x0'])
            return has_bull, index
    return None, last_index
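
# Parse a bullet index that may be an Arabic numeral, an English number word,
# a Chinese numeral or a Roman numeral; returns -1 if nothing parses.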
def index_int(index_str):
    res = -1
    try:
        res = int(index_str)
    except ValueError:
        try:
            res = w2n.word_to_num(index_str)
        except ValueError:
            try:
                res = cn2an(index_str)
            except ValueError:
                try:
                    res = r.number(index_str)
                except ValueError:
                    return -1
    return res
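
# Choose the question-numbering style that fits the sections best and return
# its index together with the winning regex.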
def qbullets_category(sections):
    global QUESTION_PATTERN
    hits = [0] * len(QUESTION_PATTERN)
    for i, pro in enumerate(QUESTION_PATTERN):
        for sec in sections:
            if re.match(pro, sec) and not not_bullet(sec):
                hits[i] += 1
                break
    maximum = 0
    res = -1
    for i, h in enumerate(hits):
        if h <= maximum:
            continue
        res = i
        maximum = h
    # NOTE: when nothing matches, res stays -1 and QUESTION_PATTERN[-1] is returned
    return res, QUESTION_PATTERN[res]
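
# Outline numbering styles, grouped by convention: Chinese legal style
# (编/章/节/条), numeric (1., 1.1, ...), plain Chinese ordinals, and English
# (PART/Chapter/Section/Article).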
BULLET_PATTERN = [[
    r"第[零一二三四五六七八九十百0-9]+(分?编|部分)",
    r"第[零一二三四五六七八九十百0-9]+章",
    r"第[零一二三四五六七八九十百0-9]+节",
    r"第[零一二三四五六七八九十百0-9]+条",
    r"[\(（][零一二三四五六七八九十百]+[\)）]",
], [
    r"第[0-9]+章",
    r"第[0-9]+节",
    r"[0-9]{,2}[\. 、]",
    r"[0-9]{,2}\.[0-9]{,2}[^a-zA-Z/%~-]",
    r"[0-9]{,2}\.[0-9]{,2}\.[0-9]{,2}",
    r"[0-9]{,2}\.[0-9]{,2}\.[0-9]{,2}\.[0-9]{,2}",
], [
    r"第[零一二三四五六七八九十百0-9]+章",
    r"第[零一二三四五六七八九十百0-9]+节",
    r"[零一二三四五六七八九十百]+[ 、]",
    r"[\(（][零一二三四五六七八九十百]+[\)）]",
    r"[\(（][0-9]{,2}[\)）]",
], [
    r"PART (ONE|TWO|THREE|FOUR|FIVE|SIX|SEVEN|EIGHT|NINE|TEN)",
    r"Chapter (I+V?|VI*|XI|IX|X)",
    r"Section [0-9]+",
    r"Article [0-9]+"
]
]

def random_choices(arr, k):
    k = min(len(arr), k)
    return random.choices(arr, k=k)


def not_bullet(line):
    patt = [
        r"0", r"[0-9]+ +[0-9~个只-]", r"[0-9]+\.{2,}"
    ]
    return any(re.match(p, line) for p in patt)
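
# Choose the outline style group from BULLET_PATTERN that matches the most
# sections; returns -1 when nothing matches.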
def bullets_category(sections):
    global BULLET_PATTERN
    hits = [0] * len(BULLET_PATTERN)
    for i, pro in enumerate(BULLET_PATTERN):
        for sec in sections:
            for p in pro:
                if re.match(p, sec) and not not_bullet(sec):
                    hits[i] += 1
                    break
    maximum = 0
    res = -1
    for i, h in enumerate(hits):
        if h <= maximum:
            continue
        res = i
        maximum = h
    return res

def is_english(texts):
    if not texts:
        return False
    eng = 0
    for t in texts:
        if re.match(r"[a-zA-Z]{2,}", t.strip()):
            eng += 1
    if eng / len(texts) > 0.8:
        return True
    return False
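
# Fill a chunk dict with the raw text plus the coarse- and fine-grained token
# fields used for indexing; table markup tags are stripped before tokenizing.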
def tokenize(d, t, eng):
    d["content_with_weight"] = t
    t = re.sub(r"</?(table|td|caption|tr|th)( [^<>]{0,12})?>", " ", t)
    d["content_ltks"] = rag_tokenizer.tokenize(t)
    d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
def tokenize_chunks(chunks, doc, eng, pdf_parser=None):
    res = []
    # wrap up as es documents
    for ck in chunks:
        if not ck.strip():
            continue
        print("--", ck)
        d = copy.deepcopy(doc)
        if pdf_parser:
            try:
                d["image"], poss = pdf_parser.crop(ck, need_position=True)
                add_positions(d, poss)
                ck = pdf_parser.remove_tag(ck)
            except NotImplementedError:
                pass
        tokenize(d, ck, eng)
        res.append(d)
    return res
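
# Same as tokenize_chunks, but for docx input where each chunk already comes
# paired with its image.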
def tokenize_chunks_docx(chunks, doc, eng, images):
    res = []
    # wrap up as es documents
    for ck, image in zip(chunks, images):
        if not ck.strip():
            continue
        print("--", ck)
        d = copy.deepcopy(doc)
        d["image"] = image
        tokenize(d, ck, eng)
        res.append(d)
    return res
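
# Turn parsed tables into indexable documents: a table already rendered as one
# string becomes a single document; row lists are joined into batches of
# batch_size rows, each batch sharing the table image and positions.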
def tokenize_table(tbls, doc, eng, batch_size=10):
    res = []
    # add tables
    for (img, rows), poss in tbls:
        if not rows:
            continue
        if isinstance(rows, str):
            d = copy.deepcopy(doc)
            tokenize(d, rows, eng)
            d["content_with_weight"] = rows
            if img:
                d["image"] = img
            if poss:
                add_positions(d, poss)
            res.append(d)
            continue
        de = "; " if eng else "； "
        for i in range(0, len(rows), batch_size):
            d = copy.deepcopy(doc)
            r = de.join(rows[i:i + batch_size])
            tokenize(d, r, eng)
            d["image"] = img
            add_positions(d, poss)
            res.append(d)
    return res
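
# Record 1-based page numbers, top offsets and full bounding boxes of a chunk
# on its document dict.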
def add_positions(d, poss):
    if not poss:
        return
    d["page_num_int"] = []
    d["position_int"] = []
    d["top_int"] = []
    for pn, left, right, top, bottom in poss:
        d["page_num_int"].append(int(pn + 1))
        d["top_int"].append(int(top))
        d["position_int"].append((int(pn + 1), int(left), int(right), int(top), int(bottom)))
def remove_contents_table(sections, eng=False):
    i = 0
    while i < len(sections):
        def get(i):
            nonlocal sections
            return (sections[i] if isinstance(sections[i], str) else sections[i][0]).strip()

        # strip spaces (including no-break and ideographic spaces) before
        # matching; re.IGNORECASE belongs to re.match, not re.sub's count
        if not re.match(r"(contents|目录|目次|table of contents|致谢|acknowledge)$",
                        re.sub(r"( |\u00a0|\u3000)+", "", get(i).split("@@")[0]),
                        flags=re.IGNORECASE):
            i += 1
            continue
        sections.pop(i)
        if i >= len(sections):
            break
        prefix = get(i)[:3] if not eng else " ".join(get(i).split(" ")[:2])
        while not prefix:
            sections.pop(i)
            if i >= len(sections):
                break
            prefix = get(i)[:3] if not eng else " ".join(get(i).split(" ")[:2])
        sections.pop(i)
        if i >= len(sections) or not prefix:
            break
        for j in range(i, min(i + 128, len(sections))):
            if not re.match(prefix, get(j)):
                continue
            for _ in range(i, j):
                sections.pop(i)
            break
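
# When a paragraph ends with a colon, split off its last clause and insert it
# before the paragraph as a synthetic "title" section.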
def make_colon_as_title(sections):
    if not sections:
        return []
    if isinstance(sections[0], str):
        return sections
    i = 0
    while i < len(sections):
        txt, layout = sections[i]
        i += 1
        txt = txt.split("@")[0].strip()
        if not txt:
            continue
        if txt[-1] not in ":：":
            continue
        txt = txt[::-1]
        arr = re.split(r"([。？！!?；;]| \.)", txt)
        if len(arr) < 2 or len(arr[1]) < 32:
            continue
        sections.insert(i - 1, (arr[0][::-1], "title"))
        i += 1
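
# Assign an outline level to every section using the detected bullet group and
# return the most frequent title level (smaller means higher in the hierarchy).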
def title_frequency(bull, sections):
    bullets_size = len(BULLET_PATTERN[bull])
    levels = [bullets_size + 1 for _ in range(len(sections))]
    if not sections or bull < 0:
        return bullets_size + 1, levels

    for i, (txt, layout) in enumerate(sections):
        for j, p in enumerate(BULLET_PATTERN[bull]):
            if re.match(p, txt.strip()) and not not_bullet(txt):
                levels[i] = j
                break
        else:
            if re.search(r"(title|head)", layout) and not not_title(txt.split("@")[0]):
                levels[i] = bullets_size
    most_level = bullets_size + 1
    for level, _ in sorted(Counter(levels).items(), key=lambda x: -x[1]):
        if level <= bullets_size:
            most_level = level
            break
    return most_level, levels

def not_title(txt):
    if re.match(r"第[零一二三四五六七八九十百0-9]+条", txt):
        return False
    if len(txt.split(" ")) > 12 or (txt.find(" ") < 0 and len(txt) >= 32):
        return True
    return re.search(r"[,;，。；！!]", txt)
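
# Merge sections into chunks along the detected outline hierarchy: each heading
# collects the closest preceding entries from the other levels, and short
# chunks are then packed together under a token budget.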
def hierarchical_merge(bull, sections, depth):
    if not sections or bull < 0:
        return []
    if isinstance(sections[0], str):
        sections = [(s, "") for s in sections]
    sections = [(t, o) for t, o in sections if
                t and len(t.split("@")[0].strip()) > 1 and not re.match(r"[0-9]+$", t.split("@")[0].strip())]
    bullets_size = len(BULLET_PATTERN[bull])
    # one bucket per bullet level, plus one for titles and one for plain text
    levels = [[] for _ in range(bullets_size + 2)]

    for i, (txt, layout) in enumerate(sections):
        for j, p in enumerate(BULLET_PATTERN[bull]):
            if re.match(p, txt.strip()):
                levels[j].append(i)
                break
        else:
            if re.search(r"(title|head)", layout) and not not_title(txt):
                levels[bullets_size].append(i)
            else:
                levels[bullets_size + 1].append(i)
    sections = [t for t, _ in sections]
    # for s in sections: print("--", s)

    def binary_search(arr, target):
        # position of the largest element smaller than target; target itself
        # must not be present in arr
        if not arr:
            return -1
        if target > arr[-1]:
            return len(arr) - 1
        if target < arr[0]:
            return -1
        s, e = 0, len(arr)
        while e - s > 1:
            i = (e + s) // 2
            if target > arr[i]:
                s = i
                continue
            elif target < arr[i]:
                e = i
                continue
            else:
                assert False
        return s

    cks = []
    visited = [False] * len(sections)
    levels = levels[::-1]
    for i, arr in enumerate(levels[:depth]):
        for j in arr:
            if visited[j]:
                continue
            visited[j] = True
            cks.append([j])
            if i + 1 == len(levels) - 1:
                continue
            for ii in range(i + 1, len(levels)):
                jj = binary_search(levels[ii], j)
                if jj < 0:
                    continue
                if jj > cks[-1][-1]:
                    cks[-1].pop(-1)
                cks[-1].append(levels[ii][jj])
            for ii in cks[-1]:
                visited[ii] = True

    if not cks:
        return cks

    for i in range(len(cks)):
        cks[i] = [sections[j] for j in cks[i][::-1]]
        print("--------------\n", "\n* ".join(cks[i]))

    res = [[]]
    num = [0]
    for ck in cks:
        if len(ck) == 1:
            n = num_tokens_from_string(re.sub(r"@@[0-9]+.*", "", ck[0]))
            # pack adjacent single-section chunks while under the token budget
            if n + num[-1] < 218:
                res[-1].append(ck[0])
                num[-1] += n
                continue
            res.append(ck)
            num.append(n)
            continue
        res.append(ck)
        num.append(218)

    return res
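
# Sequentially merge sections into chunks of roughly chunk_token_num tokens,
# carrying each section's position tag along with it.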
def naive_merge(sections, chunk_token_num=128, delimiter="\n。;!?"):
    if not sections:
        return []
    if isinstance(sections[0], str):
        sections = [(s, "") for s in sections]
    cks = [""]
    tk_nums = [0]

    def add_chunk(t, pos):
        nonlocal cks, tk_nums, delimiter
        tnum = num_tokens_from_string(t)
        if not pos:
            pos = ""
        if tnum < 8:
            pos = ""
        # Ensure that the length of the merged chunk does not exceed chunk_token_num
        if tk_nums[-1] > chunk_token_num:
            if t.find(pos) < 0:
                t += pos
            cks.append(t)
            tk_nums.append(tnum)
        else:
            if cks[-1].find(pos) < 0:
                t += pos
            cks[-1] += t
            tk_nums[-1] += tnum

    for sec, pos in sections:
        add_chunk(sec, pos)
        continue
        # NOTE: unreachable -- the continue above disables the delimiter-based
        # splitting of oversized sections below
        s, e = 0, 1
        while e < len(sec):
            if sec[e] in delimiter:
                add_chunk(sec[s: e + 1], pos)
                s = e + 1
                e = s + 1
            else:
                e += 1
        if s < e:
            add_chunk(sec[s: e], pos)

    return cks
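
# Infer a heading level for a docx paragraph: explicit "Heading N" styles win;
# otherwise fall back to the bullet patterns of the detected style group.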
def docx_question_level(p, bull=-1):
    txt = re.sub(r"\u3000", " ", p.text).strip()
    if p.style.name.startswith('Heading'):
        return int(p.style.name.split(' ')[-1]), txt
    if bull < 0:
        return 0, txt
    for j, title in enumerate(BULLET_PATTERN[bull]):
        if re.match(title, txt):
            return j + 1, txt
    return len(BULLET_PATTERN[bull]), txt
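
# Stack two PIL images vertically on a new canvas; if either is None, the
# other is returned unchanged.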
def concat_img(img1, img2):
    if img1 and not img2:
        return img1
    if not img1 and img2:
        return img2
    if not img1 and not img2:
        return None
    width1, height1 = img1.size
    width2, height2 = img2.size
    new_width = max(width1, width2)
    new_height = height1 + height2
    new_image = Image.new('RGB', (new_width, new_height))
    new_image.paste(img1, (0, 0))
    new_image.paste(img2, (0, height1))
    return new_image
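
# docx variant of naive_merge: chunks are merged under the same token budget
# while their images are vertically concatenated alongside.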
def naive_merge_docx(sections, chunk_token_num=128, delimiter="\n。;!?"):
    if not sections:
        return [], []
    cks = [""]
    images = [None]
    tk_nums = [0]

    def add_chunk(t, image, pos=""):
        nonlocal cks, tk_nums, delimiter
        tnum = num_tokens_from_string(t)
        if tnum < 8:
            pos = ""
        if tk_nums[-1] > chunk_token_num:
            if t.find(pos) < 0:
                t += pos
            cks.append(t)
            images.append(image)
            tk_nums.append(tnum)
        else:
            if cks[-1].find(pos) < 0:
                t += pos
            cks[-1] += t
            images[-1] = concat_img(images[-1], image)
            tk_nums[-1] += tnum

    for sec, image in sections:
        add_chunk(sec, image, '')

    return cks, images
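
# Ask the chat model for the most important keyword/phrase of the content.
# Some model wrappers return a tuple; only its first element (the text) is kept.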
def keyword_extraction(chat_mdl, content):
    prompt = """
You're a question analyzer.
1. Please give me the most important keyword/phrase of this question.
Answer format: (in language of user's question)
 - keyword:
"""
    kwd = chat_mdl.chat(prompt, [{"role": "user", "content": content}], {"temperature": 0.2})
    if isinstance(kwd, tuple):
        return kwd[0]
    return kwd