Spaces:
Running
Running
import copy | |
import random | |
import numpy as np | |
from openrec.preprocess.ctc_label_encode import BaseRecLabelEncode | |
class IGTRLabelEncode(BaseRecLabelEncode):
    """Turn a text label into the prompt/question index arrays used by IGTR.

    For every sample, ``encode`` builds several (prompt, question) pairs out
    of ``[position, char_id, occurrence_index]`` triples, and ``__call__``
    pads them to fixed sizes and stores them in the sample dict.

    NOTE(review): ``self.dict``, ``self.lower``, ``self.max_text_len`` and
    the call to ``add_special_char`` are presumably provided by
    ``BaseRecLabelEncode`` - confirm against that class.

    NOTE: ``encode`` produces ``k + 2`` pairs for every label length only
    when ``k >= 6``.  With ``k < 4`` (including the default ``k=1``) the
    ``random.sample`` call for labels longer than two characters receives a
    negative size and raises ``ValueError``.
    """

    def __init__(self,
                 max_text_length,
                 character_dict_path=None,
                 use_space_char=False,
                 k=1,
                 ch=False,
                 prompt_error=False,
                 **kwargs):
        """Set up the encoder and load the optional character tables.

        Args:
            max_text_length: forwarded to the base class.
            character_dict_path: forwarded to the base class.
            use_space_char: forwarded to the base class.
            k: controls how many (prompt, question) pairs ``encode`` emits.
            ch: when true, ``encode`` returns a sampled list of
                ``[char_id, count]`` pairs instead of the dense count list.
            prompt_error: stored on the instance; not read in this class.
            **kwargs: ``siml_file`` - tab separated table whose lines are
                ``char<TAB>similar<TAB>weight<TAB>similar<TAB>weight...``;
                ``rare_file`` - tab separated table whose lines are
                ``char<TAB>count``.  Both are optional and read as UTF-8.
        """
        super(IGTRLabelEncode,
              self).__init__(max_text_length, character_dict_path,
                             use_space_char)
        self.ignore_index = self.dict['<pad>']
        self.k = k
        self.prompt_error = prompt_error
        self.ch = ch
        rare_file = kwargs.get('rare_file', None)
        siml_file = kwargs.get('siml_file', None)

        # siml_char_dict: char_id -> [similar char ids, normalised weights]
        # siml_char_list: 0/1 flag per char id, 1 when an entry exists above
        siml_char_dict = {}
        siml_char_list = [0] * self.num_character
        if siml_file is not None:
            # Explicit encoding: the table holds dictionary characters and
            # must not depend on the platform's default codec.
            with open(siml_file, 'r', encoding='utf-8') as f:
                for lin in f:
                    lin_s = lin.strip().split('\t')
                    char_siml = lin_s[0]
                    if char_siml not in self.dict:
                        continue
                    siml_list = []
                    siml_prob = []
                    # Stop one short of the end so that a dangling,
                    # unpaired trailing field is ignored instead of
                    # raising IndexError on lin_s[i + 1].
                    for i in range(1, len(lin_s) - 1, 2):
                        c = lin_s[i]
                        prob = int(lin_s[i + 1])
                        if c in self.dict and prob >= 1:
                            siml_list.append(self.dict[c])
                            siml_prob.append(prob)
                    siml_weight = np.array(
                        siml_prob, dtype=np.float32) / sum(siml_prob)
                    siml_char_dict[self.dict[char_siml]] = [
                        siml_list, siml_weight.tolist()
                    ]
                    siml_char_list[self.dict[char_siml]] = 1
        self.siml_char_dict = siml_char_dict
        self.siml_char_list = siml_char_list

        # rare_char_list: 0/1 flag per char id, 1 when the table lists the
        # character with fewer than 1000 appearances.
        rare_char_list = [0] * self.num_character
        if rare_file is not None:
            with open(rare_file, 'r', encoding='utf-8') as f:
                for lin in f:
                    lin_s = lin.strip().split('\t')
                    if len(lin_s) < 2:
                        # blank or incomplete line: nothing to record
                        continue
                    char_rare = lin_s[0]
                    num_appear = int(lin_s[1])
                    if char_rare in self.dict and num_appear < 1000:
                        rare_char_list[self.dict[char_rare]] = 1
        self.rare_char_list = rare_char_list

    def __call__(self, data):
        """Encode ``data['label']`` and add the padded training arrays.

        Returns ``None`` when ``encode`` rejects the label or when the
        encoded label is longer than ``self.max_text_len``; otherwise
        returns ``data`` with ``label`` replaced by the padded index array
        and the ``length``, ``ques*`` and ``prompt*`` entries added.
        """
        text = data['label']  # raw label string, e.g. 'coffee'
        encoder_result = self.encode(text)
        if encoder_result is None:
            return None
        text, text_char_num, ques_list_s, prompt_list_s = encoder_result
        if len(text) > self.max_text_len:
            return None
        data['length'] = np.array(len(text))

        bos_id = self.dict['<s>']
        pad_id = self.dict['<pad>']
        text = [bos_id] + text + [self.dict['</s>']]
        text = text + [pad_id] * (self.max_text_len + 2 - len(text))
        data['label'] = np.array(text)

        ques_len_list = []
        ques2_len_list = []
        prompt_len_list = []
        prompt_pos_idx_list = []
        prompt_char_idx_list = []
        ques_pos_idx_list = []
        ques1_answer_list = []
        ques2_char_idx_list = []
        ques2_answer_list = []
        ques4_char_num_list = []
        for train_step, (prompt_list, ques_list) in enumerate(
                zip(prompt_list_s, ques_list_s)):
            # Prompt: a leading <s> entry, the prompt triples, then padding.
            prompt_len_list.append(len(prompt_list) + 1)
            prompt_arr = np.array(
                [[0, bos_id, 0]] + prompt_list +
                [[self.max_text_len + 2, pad_id, 0]] *
                (self.max_text_len - len(prompt_list)))
            prompt_pos_idx_list.append(prompt_arr[:, 0])
            prompt_char_idx_list.append(prompt_arr[:, 1])

            ques_len = len(ques_list)
            ques_len_list.append(ques_len)
            ques_arr = np.array(
                ques_list + [[self.max_text_len + 2, pad_id, 0]] *
                (self.max_text_len + 1 - ques_len))
            ques_pos_idx_list.append(ques_arr[:, 0])
            # Question 1: what is the character at each queried position?
            # Question 2: is the character at position p equal to 'c'?
            # Question 3: how many of each character are in the image?
            ques1_answer_list.append(ques_arr[:, 1])

            # Build the yes/no questions.  tolist() already yields fresh
            # Python lists, so no extra copy of the slice is needed.
            new_ques2_char_idx = []
            ques2_answer = []
            for q_2, ques2_idx in enumerate(ques_arr[:ques_len, :2].tolist()):
                if train_step in (2, 3) and q_2 == ques_len - 1:
                    # last question of steps 2 and 3 is always kept as-is
                    new_ques2_char_idx.append(ques2_idx)
                    ques2_answer.append(1)
                    continue
                if ques2_idx[1] != pad_id and random.random() > 0.5:
                    # ask about a randomly drawn character instead
                    select_idx = random.randint(0, self.num_character - 3)
                    new_ques2_char_idx.append([ques2_idx[0], select_idx])
                    ques2_answer.append(int(select_idx == ques2_idx[1]))
                    if self.siml_char_list[
                            ques2_idx[1]] == 1 and random.random() > 0.5:
                        # additionally ask about up to 3 similar characters
                        siml_ids = self.siml_char_dict[ques2_idx[1]][0]
                        select_idx_sim_list = random.sample(
                            siml_ids, min(3, len(siml_ids)))
                        for select_idx in select_idx_sim_list:
                            new_ques2_char_idx.append(
                                [ques2_idx[0], select_idx])
                            ques2_answer.append(
                                int(select_idx == ques2_idx[1]))
                else:
                    new_ques2_char_idx.append(ques2_idx)
                    ques2_answer.append(1)
            ques2_len_list.append(len(new_ques2_char_idx))
            ques2_pad = self.max_text_len * 4 + 1 - len(new_ques2_char_idx)
            ques2_char_idx_list.append(
                np.array(new_ques2_char_idx +
                         [[self.max_text_len + 2, pad_id]] * ques2_pad))
            ques2_answer_list.append(np.array(ques2_answer + [0] * ques2_pad))
            ques4_char_num_list.append(ques_arr[:, 2])

        data['ques_len_list'] = np.array(ques_len_list, dtype=np.int64)
        data['ques2_len_list'] = np.array(ques2_len_list, dtype=np.int64)
        data['prompt_len_list'] = np.array(prompt_len_list, dtype=np.int64)
        data['prompt_pos_idx_list'] = np.array(prompt_pos_idx_list,
                                               dtype=np.int64)
        data['prompt_char_idx_list'] = np.array(prompt_char_idx_list,
                                                dtype=np.int64)
        data['ques_pos_idx_list'] = np.array(ques_pos_idx_list, dtype=np.int64)
        data['ques1_answer_list'] = np.array(ques1_answer_list, dtype=np.int64)
        data['ques2_char_idx_list'] = np.array(ques2_char_idx_list,
                                               dtype=np.int64)
        data['ques2_answer_list'] = np.array(ques2_answer_list,
                                             dtype=np.float32)
        data['ques3_answer'] = np.array(text_char_num, dtype=np.int64)
        data['ques4_char_num_list'] = np.array(ques4_char_num_list)
        return data

    def add_special_char(self, dict_character):
        """Wrap the character list with ``</s>`` first and ``<s>``, ``<pad>`` last.

        Also records the resulting size in ``self.num_character``.
        """
        dict_character = ['</s>'] + dict_character + ['<s>'] + ['<pad>']
        self.num_character = len(dict_character)
        return dict_character

    def encode(self, text):
        """Convert a label string into ids plus prompt/question lists.

        Characters missing from ``self.dict`` are skipped.

        Args:
            text: the label string of one sample.

        Returns:
            ``None`` when ``text`` is empty or contains no known character,
            otherwise a tuple ``(text_list, char_num, ques_list,
            prompt_list)``:

            * ``text_list`` - character ids of the kept characters.
            * ``char_num`` - with ``self.ch`` false, a list of length
              ``self.num_character - 2`` holding the count of every id
              (index 0 is preset to 1); with ``self.ch`` true, a sampled
              list of ``[char_id, count]`` pairs sorted by id.
            * ``ques_list`` / ``prompt_list`` - parallel lists; each item
              is a list of ``[position, char_id, earlier_occurrences]``
              triples, positions starting at 1.
        """
        if len(text) == 0:
            return None
        if self.lower:
            text = text.lower()
        char_num = [0] * (self.num_character - 2)
        char_num[0] = 1
        text_list = []
        qa_text = []
        pos_i = 0
        rare_char_qa = []
        unrare_char_qa = []
        for char in text:
            if char not in self.dict:
                continue
            char_id = self.dict[char]
            text_list.append(char_id)
            qa_text.append([pos_i + 1, char_id, char_num[char_id]])
            if self.rare_char_list[char_id] == 1:
                rare_char_qa.append([pos_i + 1, char_id, char_num[char_id]])
            else:
                unrare_char_qa.append([pos_i + 1, char_id, char_num[char_id]])
            char_num[char_id] += 1
            pos_i += 1

        if self.ch:
            # Large alphabets: keep the present characters and fill up with
            # sampled absent ones (up to 37 entries), then add sampled rare
            # characters (up to 40 entries in total).
            char_num_ch = []
            char_num_ch_none = []
            rare_char_num_ch_none = []
            for i, num in enumerate(char_num):
                if self.rare_char_list[i] == 1:
                    rare_char_num_ch_none.append([i, num])
                if num > 0:
                    char_num_ch.append([i, num])
                else:
                    char_num_ch_none.append([i, 0])
            none_char_index = random.sample(
                char_num_ch_none,
                min(37 - len(char_num_ch), len(char_num_ch_none)))
            if len(rare_char_num_ch_none) > 0:
                none_rare_char_index = random.sample(
                    rare_char_num_ch_none,
                    min(40 - len(char_num_ch) - len(none_char_index),
                        len(rare_char_num_ch_none)),
                )
                char_num_ch = (char_num_ch + none_char_index +
                               none_rare_char_index)
            else:
                char_num_ch = char_num_ch + none_char_index
            char_num_ch.sort(key=lambda x: x[0])
            char_num = char_num_ch

        len_ = len(text_list)
        if len_ == 0:
            return None

        eos_id = self.dict['</s>']
        pad_id = self.dict['<pad>']
        # Pair 0: empty prompt, read the whole text.
        # Pair 1: whole text as prompt, predict only the end token.
        ques_list = [
            qa_text + [[pos_i + 1, eos_id, 0]],
            [[pos_i + 1, eos_id, 0]],
        ]
        # NOTE: the second prompt is the qa_text object itself, not a copy,
        # so the in-place shuffles further down also reorder this prompt.
        prompt_list = [qa_text[len_:], qa_text]
        if len_ == 1:
            ques_list.append([[self.max_text_len + 1, eos_id, 0]])
            prompt_list.append(
                [[self.max_text_len + 2, pad_id, 0]] * 4 + qa_text)
            for _ in range(1, self.k):
                ques_list.append([[self.max_text_len + 2, pad_id, 0]])
                prompt_list.append(qa_text[1:])
        else:
            # Two pairs whose prompt is a window of (at most) the five
            # characters before a random cut, left-padded to length five.
            next_id = random.sample(range(1, len_ + 1), 2)
            for slice_id in next_id:
                b_i = max(slice_id - 5, 0)
                if slice_id == len_:
                    ques_list.append([[self.max_text_len + 1, eos_id, 0]])
                else:
                    ques_list.append(
                        qa_text[slice_id:] +
                        [[self.max_text_len + 1, qa_text[slice_id][1], 0]])
                prompt_list.append(
                    [[self.max_text_len + 2, pad_id, 0]] *
                    (5 - slice_id + b_i) + qa_text[b_i:slice_id])

            # Two pairs that split the text at a random position.
            shuffle_id1 = random.sample(range(1, len_),
                                        2) if len_ > 2 else [1, 0]
            for slice_id in shuffle_id1:
                if slice_id == 0:
                    ques_list.append([[self.max_text_len + 2, pad_id, 0]])
                    prompt_list.append(qa_text[:0])
                else:
                    ques_list.append(qa_text[slice_id:] +
                                     [[pos_i + 1, eos_id, 0]])
                    prompt_list.append(qa_text[:slice_id])

            if len_ > 2:
                # k - 4 further pairs built from shuffled text.
                num_extra = self.k - 4
                shuffle_id2 = random.sample(range(1, len_),
                                            min(num_extra, len_ - 1))
                # Top up in chunks no larger than the population, so short
                # labels combined with a large k no longer make
                # random.sample raise ValueError.
                while len(shuffle_id2) < num_extra:
                    shuffle_id2 += random.sample(
                        range(1, len_),
                        min(num_extra - len(shuffle_id2), len_ - 1))
                rare_slice_id = len(rare_char_qa)
                unrare_slice_id = len(unrare_char_qa)
                for slice_id in shuffle_id2:
                    random.shuffle(qa_text)
                    if len(rare_char_qa) > 0 and random.random() < 0.5:
                        # Ask for rare characters, prompt with common ones.
                        unrare_head = unrare_char_qa[:unrare_slice_id]
                        rare_tail = rare_char_qa[rare_slice_id:]
                        ques_list.append(rare_char_qa[:rare_slice_id] +
                                         unrare_char_qa[unrare_slice_id:] +
                                         [[pos_i + 1, eos_id, 0]])
                        if len(unrare_head) > 0:
                            prompt_list1 = random.sample(
                                unrare_head,
                                random.randint(1, len(unrare_head))
                                if len(unrare_head) > 1 else 1,
                            )
                        else:
                            prompt_list1 = []
                        if len(rare_tail) > 0:
                            prompt_list2 = random.sample(
                                rare_tail,
                                random.randint(1, len(rare_tail)),
                            )
                        else:
                            prompt_list2 = []
                        prompt_list.append(prompt_list1 + prompt_list2)
                        # Re-draw the split used by the next rare pair.
                        random.shuffle(rare_char_qa)
                        random.shuffle(unrare_char_qa)
                        rare_slice_id = random.randint(
                            1,
                            len(rare_char_qa)) if len(rare_char_qa) > 1 else 1
                        unrare_slice_id = random.randint(
                            1, len(unrare_char_qa)
                        ) if len(unrare_char_qa) > 1 else 1
                    else:
                        ques_list.append(qa_text[slice_id:] +
                                         [[pos_i + 1, eos_id, 0]])
                        prompt_list.append(qa_text[:slice_id])
            else:
                # Two-character label: both single-character splits, then
                # padding pairs.
                ques_list.append(qa_text[1:] + [[pos_i + 1, eos_id, 0]])
                prompt_list.append(qa_text[:1])
                ques_list.append(qa_text[:1] + [[pos_i + 1, eos_id, 0]])
                prompt_list.append(qa_text[1:])
                ques_list += [[[self.max_text_len + 2, pad_id, 0]]
                              ] * (self.k - 6)
                prompt_list += [qa_text[:0]] * (self.k - 6)
        return text_list, char_num, ques_list, prompt_list