import torch
import numpy as np
import copy
def org_tcl_rois(batch_size, pos_lists, pos_masks, label_lists, tcl_bs):
    """Regroup text-center-line ROIs so each (virtual) GPU bucket holds
    exactly ``tcl_bs`` entries.

    Buckets are padded by repeating entries from the front, or shrunk by
    randomly popping entries, then flattened back into single lists.

    NOTE(review): ``img_bs`` is set to ``batch_size``, so ``ngpu`` is always 1
    here; the per-GPU split scaffolding is presumably kept for multi-GPU
    compatibility.

    Returns:
        (pos_lists_, pos_masks_, label_lists_): regrouped copies of the inputs.
    """
    img_bs = batch_size
    ngpu = int(batch_size / img_bs)  # always 1 given img_bs == batch_size
    # Column 0 of the first point of each pos_list stores the image index.
    img_ids = np.array(pos_lists, dtype=np.int32)[:, 0, 0].copy()

    pos_split = [[] for _ in range(ngpu)]
    mask_split = [[] for _ in range(ngpu)]
    label_split = [[] for _ in range(ngpu)]

    # Bucket each ROI by the GPU that owns its image, rewriting the image
    # index to be local to that GPU's sub-batch.
    for idx, img_id in enumerate(img_ids):
        gpu_id = int(img_id / img_bs)
        local_id = img_id % img_bs
        roi = pos_lists[idx].copy()
        roi[:, 0] = local_id
        pos_split[gpu_id].append(roi)
        mask_split[gpu_id].append(pos_masks[idx].copy())
        label_split[gpu_id].append(copy.deepcopy(label_lists[idx]))

    # Pad (repeat from the front) or shrink (random removal) every bucket
    # to exactly tcl_bs entries.
    for g in range(ngpu):
        n = len(pos_split[g])
        if n <= tcl_bs:
            for j in range(tcl_bs - n):
                pos_split[g].append(pos_split[g][j].copy())
                mask_split[g].append(mask_split[g][j].copy())
                label_split[g].append(copy.deepcopy(label_split[g][j]))
        else:
            for _ in range(n - tcl_bs):
                drop = np.random.permutation(len(pos_split[g]))[0]
                pos_split[g].pop(drop)
                mask_split[g].pop(drop)
                label_split[g].pop(drop)

    # Flatten the per-GPU buckets back into flat lists.
    pos_lists_, pos_masks_, label_lists_ = [], [], []
    for g in range(ngpu):
        pos_lists_.extend(pos_split[g])
        pos_masks_.extend(mask_split[g])
        label_lists_.extend(label_split[g])
    return pos_lists_, pos_masks_, label_lists_
def pre_process(label_list, pos_list, pos_mask, max_text_length, max_text_nums,
                pad_num, tcl_bs):
    """Turn batched TCL ground-truth tensors into per-ROI training tensors.

    Filters out (image, text) slots whose mask is entirely zero, regroups
    the survivors to exactly ``tcl_bs`` ROIs via ``org_tcl_rois``, and
    derives each ROI's label length as the count of entries before the
    first ``pad_num``.

    Returns:
        pos_list: tensor of kept/regrouped position lists.
        pos_mask: tensor of the matching masks.
        label_list: int tensor of labels, squeezed on dim 2.
        label: long tensor of per-ROI label lengths.
    """
    label_arr = label_list.numpy()
    batch = label_arr.shape[0]
    pos_arr = pos_list.numpy()
    mask_arr = pos_mask.numpy()

    # Keep only the slots whose mask has at least one active position.
    kept_pos, kept_mask, kept_label = [], [], []
    for b in range(batch):
        for t in range(max_text_nums):
            if mask_arr[b, t].any():
                kept_pos.append(pos_arr[b][t])
                kept_mask.append(mask_arr[b][t])
                kept_label.append(label_arr[b][t])

    pos_sel, mask_sel, label_sel = org_tcl_rois(
        batch, kept_pos, kept_mask, kept_label, tcl_bs)

    # Per-ROI label length: entries counted until the first pad_num.
    as_lists = [lab.tolist() for lab in label_sel]
    lengths = []
    for i in range(tcl_bs):
        count = 0
        for step in range(max_text_length):
            if as_lists[i][step][0] == pad_num:
                break
            count += 1
        lengths.append(count)

    label = torch.tensor(lengths).long()
    pos_out = torch.tensor(pos_sel)
    mask_out = torch.tensor(mask_sel)
    label_out = torch.squeeze(torch.tensor(label_sel), dim=2).int()
    return pos_out, mask_out, label_out, label