|
import yaml |
|
import torch |
|
import random |
|
import numpy as np |
|
import os |
|
import sys |
|
import matplotlib.pyplot as plt |
|
from einops import repeat |
|
import cv2 |
|
import time |
|
import torch.nn.functional as F |
|
|
|
|
|
# Names exported by ``from <this module> import *``.
__all__ = [
    'decode_mask_to_onehot',
    'encode_onehot_to_mask',
    'Logger',
    'get_coords_grid',
    'get_coords_grid_float',
    'draw_bboxes',
    'Infos',
    'inv_normalize_img',
    'make_numpy_img',
    'get_metrics',
]
|
|
|
|
|
class Infos(object):
    """Collect per-batch losses and detection results and report them.

    Only the ``'od'`` phase is accepted. Results handed to ``append_result``
    are accumulated per key; ``cal_metrics`` reads the ``'pred_all'`` and
    ``'target_all'`` lists and evaluates them with the COCO bbox protocol.
    """

    def __init__(self, phase, class_names=None):
        """
        Args:
            phase: must be ``'od'``.
            class_names: sequence of class labels; used for the COCO
                categories and when drawing boxes.
        """
        assert phase in ['od'], "Error in Infos"
        self.phase = phase
        self.class_names = class_names
        self.register()
        self.pattern = 'train'
        self.epoch_id = 0
        self.max_epoch = 0
        self.batch_id = 0
        self.batch_num = 0
        self.lr = 0
        self.fps_data_load = 0
        self.fps = 0
        self.val_metric = 0
        # Defined up front so the reporting methods cannot hit an
        # AttributeError when they run before the matching setter/evaluation.
        self.epoch_training_time = 0
        self.metrics = None

    def set_epoch_training_time(self, data):
        self.epoch_training_time = data

    def set_pattern(self, data):
        self.pattern = data

    def set_epoch_id(self, data):
        self.epoch_id = data

    def set_max_epoch(self, data):
        self.max_epoch = data

    def set_batch_id(self, data):
        self.batch_id = data

    def set_batch_num(self, data):
        self.batch_num = data

    def set_lr(self, data):
        self.lr = data

    def set_fps_data_load(self, data):
        self.fps_data_load = data

    def set_fps(self, data):
        self.fps = data

    def clear_cache(self):
        """Drop every accumulated result and loss."""
        self.register()

    def get_val_metric(self):
        """Return the validation metric stored by the last ``cal_metrics``."""
        return self.val_metric

    def _build_coco_api(self, boxes_per_image, with_score):
        """Wrap per-image box arrays in an indexed ``COCO`` object.

        Args:
            boxes_per_image: list with one 2-D array per image. Rows are
                ``[class_idx, l, t, r, b]`` or, when ``with_score`` is True,
                ``[class_idx, score, l, t, r, b]``.
            with_score: whether rows carry a score in column 1.
        """
        # NOTE(review): COCO / COCOeval are not imported in the visible part
        # of this file; presumably they come from pycocotools -- confirm.
        first_coord = 2 if with_score else 1
        images = []
        annotations = []
        for image_id, boxes in enumerate(boxes_per_image):
            if boxes.shape[0] > 0:
                # One record per image is enough for the index. As before,
                # images without any box are not registered.
                # NOTE(review): this means predictions on images without
                # ground truth are not evaluated -- confirm this is intended.
                images.append({'id': image_id})
            for j in range(boxes.shape[0]):
                left_top = boxes[j, first_coord:first_coord + 2]
                width_height = boxes[j, first_coord + 2:first_coord + 4] - left_top
                annotation = {
                    'image_id': image_id,
                    "category_id": int(boxes[j, 0]),
                    # COCO boxes are [x, y, width, height]
                    "bbox": np.hstack([left_top, width_height]),
                    "area": np.prod(width_height),
                    "id": len(annotations),
                    "iscrowd": 0,
                }
                if with_score:
                    annotation['score'] = boxes[j, 1]
                annotations.append(annotation)

        coco_api = COCO()
        coco_api.dataset['images'] = images
        coco_api.dataset['annotations'] = annotations
        coco_api.dataset['categories'] = [
            {"id": i, "supercategory": c, "name": c} for i, c in enumerate(self.class_names)
        ]
        coco_api.createIndex()
        return coco_api

    def cal_metrics(self):
        """Evaluate the accumulated predictions against the targets.

        Stores the summary values in ``self.metrics`` and the second entry
        (index 1) in ``self.val_metric``.
        """
        if self.phase == 'od':
            coco_api_gt = self._build_coco_api(self.result_all['target_all'], with_score=False)
            coco_api_pred = self._build_coco_api(self.result_all['pred_all'], with_score=True)

            coco_eval = COCOeval(coco_api_gt, coco_api_pred, "bbox")
            coco_eval.params.imgIds = coco_api_gt.getImgIds()
            coco_eval.evaluate()
            coco_eval.accumulate()
            summary = coco_eval.summarize()
            # Stock pycocotools stores the summary in ``stats`` and returns
            # None; accept a fork that returns the values as well.
            self.metrics = coco_eval.stats if summary is None else summary
            self.val_metric = self.metrics[1]

    def print_epoch_state_infos(self, logger):
        """Log the epoch summary, run the evaluation and log its 12 values."""
        infos_str = 'Pattern: %s Epoch [%d,%d], time: %d loss: %.4f' % \
                    (self.pattern, self.epoch_id, self.max_epoch, self.epoch_training_time,
                     np.mean(self.loss_all['loss']))
        logger.write(infos_str + '\n')
        time_start = time.time()
        self.cal_metrics()
        time_end = time.time()
        logger.write('Pattern: %s Epoch Eval_time: %d\n' % (self.pattern, (time_end - time_start)))

        if self.phase == 'od':
            titleStr = 6 * ['Average Precision'] + 6 * ['Average Recall']
            typeStr = 6 * ['(AP)'] + 6 * ['(AR)']
            iouStr = 12 * ['0.50:0.95']
            iouStr[1] = '0.50'
            iouStr[2] = '0.75'
            areaRng = 3 * ['all'] + ['small', 'medium', 'large'] + 3 * ['all'] + ['small', 'medium', 'large']
            maxDets = 6 * [100] + [1, 10, 100] + 3 * [100]
            infos_str = '{:<18} {} @[ IoU={:<9} | area={:>6s} | maxDets={:>3d} ] = {:0.3f}\n'
            for i in range(12):
                logger.write(infos_str.format(titleStr[i], typeStr[i], iouStr[i], areaRng[i], maxDets[i],
                                              self.metrics[i]))

    def save_epoch_state_infos(self, writer):
        """Write the 12 evaluation values to ``writer`` under the epoch id."""
        step = self.epoch_id
        keys = [
            'AP_m_all_100',
            'AP_50_all_100',
            'AP_75_all_100',
            'AP_m_small_100',
            'AP_m_medium_100',
            'AP_m_large_100',
            'AR_m_all_1',
            'AR_m_all_10',
            'AR_m_all_100',
            'AR_m_small_100',
            'AR_m_medium_100',
            'AR_m_large_100',
        ]
        for i, key in enumerate(keys):
            writer.add_scalar('%s/epoch/%s' % (self.pattern, key), self.metrics[i], step)

    def print_batch_state_infos(self, logger):
        """Log the current batch position, learning rate, speeds and last loss."""
        infos_str = 'Pattern: %s [%d,%d][%d,%d], lr: %5f, fps_data_load: %.2f, fps: %.2f' % \
                    (self.pattern, self.epoch_id, self.max_epoch, self.batch_id,
                     self.batch_num, self.lr, self.fps_data_load, self.fps)

        infos_str += ', loss: %.4f' % self.loss_all['loss'][-1]
        logger.write(infos_str + '\n')

    def save_batch_state_infos(self, writer):
        """Write the learning rate and the latest value of every loss."""
        step = self.epoch_id * self.batch_num + self.batch_id
        writer.add_scalar('%s/lr' % self.pattern, self.lr, step)
        for key, value in self.loss_all.items():
            writer.add_scalar('%s/%s' % (self.pattern, key), value[-1], step)

    def save_results(self, img_batch, prior_mean, prior_std, vis_dir, *args, **kwargs):
        """Save image / prediction / target visualisations for ~30% of a batch.

        Each picked image produces one PNG stacking the de-normalised image,
        the predictions (red) and the targets (green) vertically.
        """
        batch_size = img_batch.size(0)
        k = np.clip(int(0.3 * batch_size), a_min=1, a_max=batch_size)
        ids = np.random.choice(range(batch_size), k, replace=False)
        for img_id in ids:
            img = img_batch[img_id].detach().cpu()
            # Negative index: assumes the last ``batch_size`` stored results
            # belong to ``img_batch`` -- TODO confirm against the caller.
            pred = self.result_all['pred_all'][img_id - batch_size]
            target = self.result_all['target_all'][img_id - batch_size]

            img = make_numpy_img(inv_normalize_img(img, prior_mean, prior_std))
            # Keywords matter: draw_bboxes takes ``color`` before
            # ``class_names``, so positional passing swapped the two.
            pred_draw = draw_bboxes(img, pred, color=(255, 0, 0), class_names=self.class_names)
            target_draw = draw_bboxes(img, target, color=(0, 255, 0), class_names=self.class_names)

            vis = np.concatenate([img / 255., pred_draw / 255., target_draw / 255.], axis=0)
            vis = np.clip(vis, a_min=0, a_max=1)
            file_name = os.path.join(vis_dir, self.pattern, f'{self.epoch_id}_{self.batch_id}_{img_id}.png')
            plt.imsave(file_name, vis)

    def register(self):
        """Reset the result and loss stores to their empty state."""
        self.is_registered_result = False
        self.result_all = {}

        self.is_registered_loss = False
        self.loss_all = {}

    def register_result(self, data: dict):
        """Create an empty list for every key of ``data``."""
        for key in data.keys():
            self.result_all[key] = []
        self.is_registered_result = True

    def append_result(self, data: dict):
        """Extend each stored result list with the matching list in ``data``."""
        if not self.is_registered_result:
            self.register_result(data)
        for key, value in data.items():
            self.result_all[key] += value

    def register_loss(self, data: dict):
        """Create an empty list for every key of ``data``."""
        for key in data.keys():
            self.loss_all[key] = []
        self.is_registered_loss = True

    def append_loss(self, data: dict):
        """Append each loss tensor in ``data`` as a detached numpy value."""
        if not self.is_registered_loss:
            self.register_loss(data)
        for key, value in data.items():
            self.loss_all[key].append(value.detach().cpu().numpy())
|
|
|
|
|
|
|
def draw_bboxes(img, bboxes, color=(255, 0, 0), class_names=None, is_show_score=True):
    '''Draw boxes, and optionally their labels, on a uint8 copy of ``img``.

    Args:
        img: image as ndarray or tensor; it is converted to a uint8 numpy copy,
            so the input itself is not drawn on.
        bboxes: [n, 5], class_idx, l, t, r, b
                [n, 6], class_idx, score, l, t, r, b
        color: colour of the box outline.
        class_names: optional sequence indexed by class_idx. When it is empty
            or None the class index itself is used as the label text.
        is_show_score: when True a filled label patch and the label text are
            drawn at the top-left corner (the score is appended for 6-column rows).
    Returns:
        The uint8 ndarray copy with the drawings.
    '''
    assert img is not None, "In draw_bboxes, img is None"
    if torch.is_tensor(img):
        img = img.cpu().numpy()
    img = img.astype(np.uint8).copy()

    if torch.is_tensor(bboxes):
        bboxes = bboxes.cpu().numpy()
    for bbox in bboxes:
        class_idx = int(bbox[0])
        # Fall back to the index so the label is always defined.
        class_name = class_names[class_idx] if class_names else str(class_idx)
        has_score = len(bbox) == 6
        score = bbox[1] if has_score else None
        coords = bbox[2:] if has_score else bbox[1:]
        # Plain Python ints: ``np.int`` no longer exists in NumPy and cv2
        # point arguments are safest as built-in integers.
        left, top, right, bottom = (int(v) for v in coords[:4])
        if is_show_score:
            cv2.rectangle(img, pt1=(left - 2, top - 15), pt2=(left + 15, top + 1),
                          color=(0, 0, 255), thickness=-1)
            if has_score:
                label = '%s:%.2f' % (class_name, score)
            else:
                label = '%s' % class_name
            cv2.putText(img, text=label, org=(left - 1, top - 7), fontFace=cv2.FONT_HERSHEY_SIMPLEX,
                        fontScale=0.2, color=(255, 255, 255), thickness=1)
        cv2.rectangle(img, pt1=(left, top), pt2=(right, bottom), color=color, thickness=2)
    return img
|
|
|
|
|
def get_coords_grid(h_end, w_end, h_start=0, w_start=0, h_steps=None, w_steps=None, is_normalize=False):
    """Build a 2 x H x W grid of coordinates; channel 0 is x, channel 1 is y.

    The y axis is sampled with ``h_steps`` points from ``h_start`` to ``h_end``
    and the x axis with ``w_steps`` points from ``w_start`` to ``w_end``. A
    missing step count defaults to ``int(end - start) + 1``. With
    ``is_normalize`` the y values are divided by ``h_end`` and the x values by
    ``w_end``.
    """
    h_steps = int(h_end - h_start) + 1 if h_steps is None else h_steps
    w_steps = int(w_end - w_start) + 1 if w_steps is None else w_steps

    rows = torch.linspace(h_start, h_end, h_steps)
    cols = torch.linspace(w_start, w_end, w_steps)
    if is_normalize:
        rows, cols = rows / h_end, cols / w_end
    grid_y, grid_x = torch.meshgrid(rows, cols)
    # x first, then y
    return torch.stack((grid_x, grid_y), dim=0)
|
|
|
|
|
def get_coords_grid_float(ht, wd, scale, is_normalize=False):
    """Build a 2 x ht x wd grid of interior sample points in ``(0, scale)``.

    Each axis is sampled with two extra points and the two end points
    (0 and ``scale``) are dropped. Channel 0 holds x, channel 1 holds y. With
    ``is_normalize`` the values are divided by ``scale``.
    """
    rows = torch.linspace(0, scale, ht + 2)
    cols = torch.linspace(0, scale, wd + 2)
    if is_normalize:
        rows, cols = rows / scale, cols / scale
    grid_y, grid_x = torch.meshgrid(rows[1:-1], cols[1:-1])
    return torch.stack((grid_x, grid_y), dim=0)
|
|
|
|
|
def get_coords_vector_float(len, scale, is_normalize=False):
    """Build a 2 x len x 1 grid for a single column of interior sample points.

    ``len + 2`` points are sampled from 0 to ``scale`` and the two end points
    are dropped. Channel 0 is all zeros, channel 1 holds the samples. With
    ``is_normalize`` the samples are divided by ``scale``.
    """
    samples = torch.linspace(0, scale, len + 2)
    if is_normalize:
        samples = samples / scale
    sample_grid, zero_grid = torch.meshgrid(samples[1:-1], torch.tensor([0.]))
    return torch.stack((zero_grid, sample_grid), dim=0)
|
|
|
|
|
class Logger(object):
    """Append messages to a log file and optionally echo them to stdout.

    The file is opened in append mode on construction. Call ``close()`` or use
    the instance as a context manager to release the file handle.
    """

    def __init__(self, filename="Default.log", is_terminal_show=True):
        """
        Args:
            filename: path of the log file, opened in append mode.
            is_terminal_show: when True every message is also written to
                ``sys.stdout``.
        """
        self.is_terminal_show = is_terminal_show
        if self.is_terminal_show:
            self.terminal = sys.stdout
        self.log = open(filename, "a")

    def write(self, message):
        """Write ``message`` to the enabled sinks and flush them."""
        if self.is_terminal_show:
            self.terminal.write(message)
        self.log.write(message)
        self.flush()

    def flush(self):
        """Flush the terminal (if enabled) and the log file."""
        if self.is_terminal_show:
            self.terminal.flush()
        self.log.flush()

    def close(self):
        """Close the log file; calling it more than once is harmless."""
        if not self.log.closed:
            self.log.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
|
|
|
|
|
class ParamsParser:
    """Expose the top-level keys of a YAML file as attributes.

    Attributes that are not keys of the loaded mapping evaluate to ``None``.
    """

    def __init__(self, project_file):
        """Load ``project_file`` with ``yaml.safe_load``."""
        # ``with`` closes the handle instead of leaving it to the garbage collector.
        with open(project_file) as config_file:
            # An empty file loads as None; fall back to an empty mapping so
            # attribute access keeps working.
            self.params = yaml.safe_load(config_file.read()) or {}

    def __getattr__(self, item):
        # __getattr__ only runs when normal lookup fails. If ``params`` itself
        # is missing (instance created without __init__, e.g. by copy/pickle),
        # looking it up here would recurse forever.
        if item == 'params':
            raise AttributeError(item)
        return self.params.get(item, None)
|
|
|
|
|
def get_all_dict(dict_infos: dict) -> dict:
    """Flatten a nested dict into a single-level dict.

    Non-dict values are kept under their own key; dict values are flattened
    recursively and merged in, and their own key is dropped. On a key clash
    the entry encountered later wins.
    """
    flat = {}
    for key, value in dict_infos.items():
        if isinstance(value, dict):
            # update() accepts any hashable key, whereas ``dict(..., **nested)``
            # raises TypeError for non-string keys.
            flat.update(get_all_dict(value))
        else:
            flat[key] = value
    return flat
|
|
|
|
|
def make_numpy_img(tensor_data):
    """Convert an image tensor to an H x W x 3 numpy array.

    Accepted layouts: a 2-D map (replicated to three channels), 1 x H x W
    (replicated), 3 x H x W (moved to channel-last) and H x W x 3 (kept as is).
    Any other layout raises ``Exception``.
    """
    if len(tensor_data.shape) == 2:
        hwc = torch.cat([tensor_data.unsqueeze(2)] * 3, dim=2)
    else:
        leading = tensor_data.size(0)
        if leading in (1, 3):
            hwc = tensor_data.permute((1, 2, 0))
            if leading == 1:
                # grey image: repeat the single channel three times
                hwc = torch.cat([hwc] * 3, dim=2)
        elif tensor_data.size(2) == 3:
            hwc = tensor_data
        else:
            raise Exception('tensor_data apply to make_numpy_img error')

    return hwc.detach().cpu().numpy()
|
|
|
|
|
def print_infos(logger, writer, infos: dict):
    """Log one status line and mirror the scalar values to ``writer``.

    The first eight values of ``infos`` (in insertion order) fill the fixed
    part of the line; every further entry is appended as ``, key: value`` and
    written to ``writer`` as its own scalar. The keys ``pattern``, ``lr``,
    ``epoch_id``, ``batch_num`` and ``batch_id`` are read by name.
    """
    names = list(infos.keys())
    vals = list(infos.values())
    extras = list(zip(names[8:], vals[8:]))

    line = 'Pattern: %s [%d,%d][%d,%d], lr: %5f, fps_data_load: %.2f, fps: %.2f' % tuple(vals[:8])
    line += ''.join(f', {name}: {val:.4f}' for name, val in extras)
    logger.write(line + '\n')

    prefix = infos['pattern']
    # global step: batches seen so far
    step = infos['epoch_id'] * infos['batch_num'] + infos['batch_id']
    writer.add_scalar('%s/lr' % prefix, infos['lr'], step)
    for name, val in extras:
        writer.add_scalar('%s/%s' % (prefix, name), val, step)
|
|
|
|
|
def invert_affine(origin_imgs, preds, pattern='train'):
    """Rescale predicted rois from a 512-sized frame to the original images.

    Only acts when ``pattern == 'val'``. For every prediction with at least
    one roi, columns 0 and 2 are divided by ``512 / width`` and columns 1 and 3
    by ``512 / height`` of the matching original image. The rois are modified
    in place and ``preds`` is returned.
    """
    if pattern != 'val':
        return preds

    for idx, pred in enumerate(preds):
        rois = pred['rois']
        if len(rois) == 0:
            continue
        old_h, old_w, _ = origin_imgs[idx].shape
        rois[:, [0, 2]] = rois[:, [0, 2]] / (512 / old_w)
        rois[:, [1, 3]] = rois[:, [1, 3]] / (512 / old_h)
    return preds
|
|
|
|
|
def save_output_infos(input, output, vis_dir, pattern, epoch_id, batch_id):
    """Save flow visualisations for a random subset of a batch.

    Args:
        input: mapping with ``'ori_img1'`` and ``'ori_img2'`` image batches.
        output: triple ``(flows, pf1s, pf2s)``; only ``flows[0]`` is used.
        vis_dir, pattern: the image is written to ``<vis_dir>/<pattern>/``.
        epoch_id, batch_id: used in the file name.

    For each picked sample one JPEG is written that stacks the first image,
    the second image, the second image warped by the flow and the min-max
    normalised flow magnitude.
    """
    flows, pf1s, pf2s = output
    flow_batch = flows[0]
    batch_size = len(flow_batch)
    k = np.clip(int(0.2 * batch_size), a_min=2, a_max=batch_size)
    ids = np.random.choice(range(batch_size), k, replace=False)
    for img_id in ids:
        # keep the batch dimension: slices of length one
        img1 = input['ori_img1'][img_id:img_id + 1].to(flow_batch.device)
        img2 = input['ori_img2'][img_id:img_id + 1].to(flow_batch.device)

        flow = flow_batch[img_id:img_id + 1]
        # NOTE(review): flow_to_warp / resample are not defined in the visible
        # part of this file -- confirm where they come from.
        warps = flow_to_warp(flow)
        warped_img2 = resample(img2, warps)

        ori_img1 = make_numpy_img(img1[0]) / 255.
        ori_img2 = make_numpy_img(img2[0]) / 255.
        warped_img2 = make_numpy_img(warped_img2[0]) / 255.
        flow_amplitude = torch.sqrt(flow[0, 0:1, ...] ** 2 + flow[0, 1:2, ...] ** 2)
        flow_amplitude = make_numpy_img(flow_amplitude)
        amplitude_min = np.min(flow_amplitude)
        # the small epsilon avoids division by zero for a constant flow
        flow_amplitude = (flow_amplitude - amplitude_min) / (np.max(flow_amplitude) - amplitude_min + 1e-10)

        vis = np.concatenate([ori_img1, ori_img2, warped_img2, flow_amplitude], axis=0)
        vis = np.clip(vis, a_min=0, a_max=1)
        # img_id is part of the name so the picked samples do not overwrite
        # each other.
        file_name = os.path.join(vis_dir, pattern, f'{epoch_id}_{batch_id}_{img_id}.jpg')
        plt.imsave(file_name, vis)
|
|
|
|
|
def inv_normalize_img(img, prior_mean=[0, 0, 0], prior_std=[1, 1, 1]):
    """Undo mean/std normalisation and rescale to the 0-255 range.

    ``prior_mean`` and ``prior_std`` are reshaped to ``(img.size(0), 1, 1)``,
    so they need one entry per leading-dimension slice of ``img``. Returns
    ``(img * std + mean) * 255`` clamped to ``[0, 255]``.
    """
    channels = img.size(0)
    mean = torch.tensor(prior_mean, dtype=torch.float).to(img.device).view(channels, 1, 1)
    std = torch.tensor(prior_std, dtype=torch.float).to(img.device).view(channels, 1, 1)
    restored = (img * std + mean) * 255.
    return torch.clamp(restored, min=0, max=255)
|
|
|
|
|
def save_seg_output_infos(input, output, vis_dir, pattern, epoch_id, batch_id, prior_mean, prior_std):
    """Save segmentation visualisations for a random subset of a batch.

    Args:
        input: mapping with ``'img'`` (normalised images) and ``'label'``
            (one-hot targets) batches.
        output: per-class scores; the arg-max over dimension 1 is the
            predicted label map.
        vis_dir, pattern: the image is written to ``<vis_dir>/<pattern>/``.
        epoch_id, batch_id: used in the file name.
        prior_mean, prior_std: statistics used to de-normalise the image.

    For each picked sample one JPEG is written that stacks the image, the
    predicted label map and the target label map.
    """
    pred_label = torch.argmax(output, 1)
    batch_size = len(pred_label)
    # The upper bound is the batch size (not the height of a label map), so
    # the sample count can never exceed what ``choice`` can draw without
    # replacement.
    k = np.clip(int(0.2 * batch_size), a_min=1, a_max=batch_size)
    ids = np.random.choice(range(batch_size), k, replace=False)
    for img_id in ids:
        img = input['img'][img_id].to(pred_label.device)
        target = input['label'][img_id].to(pred_label.device)

        img = make_numpy_img(inv_normalize_img(img, prior_mean, prior_std)) / 255.
        target = make_numpy_img(encode_onehot_to_mask(target))
        pred = make_numpy_img(pred_label[img_id])

        vis = np.concatenate([img, pred, target], axis=0)
        # label indices above 1 saturate to 1 here
        vis = np.clip(vis, a_min=0, a_max=1)
        # img_id is part of the name so the picked samples do not overwrite
        # each other.
        file_name = os.path.join(vis_dir, pattern, f'{epoch_id}_{batch_id}_{img_id}.jpg')
        plt.imsave(file_name, vis)
|
|
|
|
|
def set_requires_grad(nets, requires_grad=False):
    """Set ``requires_grad`` on every parameter of the given network(s).

    Parameters:
        nets (network list)   -- a single network or a list of networks;
                                 ``None`` entries are skipped
        requires_grad (bool)  -- whether the networks require gradients or not
    """
    networks = nets if isinstance(nets, list) else [nets]
    for net in networks:
        if net is None:
            continue
        for param in net.parameters():
            param.requires_grad = requires_grad
|
|
|
|
|
def boolean_string(s):
    """Map the strings ``'True'`` / ``'False'`` to the matching bool.

    Raises ``ValueError`` for any other value.
    """
    parsed = {'False': False, 'True': True}
    if s not in parsed:
        raise ValueError('Not a valid boolean string')
    return parsed[s]
|
|
|
|
|
def cpt_pxl_cls_acc(pred_idx, target):
    """Return the fraction of elements whose integer class index matches."""
    flat_pred = pred_idx.reshape(-1).int()
    flat_target = target.reshape(-1).int()
    return (flat_pred == flat_target).float().mean()
|
|
|
|
|
def cpt_batch_psnr(img, img_gt, PIXEL_MAX):
    """Return the mean PSNR over a batch of 4-D image tensors.

    The squared error is averaged over dimensions 1-3 for each sample, turned
    into ``20 * log10(PIXEL_MAX / rmse)`` and averaged over the batch.
    """
    per_sample_mse = ((img - img_gt) ** 2).mean(dim=[1, 2, 3])
    per_sample_psnr = 20 * torch.log10(PIXEL_MAX / per_sample_mse.sqrt())
    return per_sample_psnr.mean()
|
|
|
|
|
def cpt_psnr(img, img_gt, PIXEL_MAX):
    """Return ``20 * log10(PIXEL_MAX / rmse)`` between two arrays."""
    squared_error = (img - img_gt) ** 2
    rmse = np.sqrt(np.mean(squared_error))
    return 20 * np.log10(PIXEL_MAX / rmse)
|
|
|
|
|
def cpt_rgb_ssim(img, img_gt):
    """Return the SSIM averaged over the first three channels (last axis).

    Both images are passed through ``clip_01`` first and every channel is
    scored with ``sk_cpt_ssim``.
    """
    # NOTE(review): clip_01 and sk_cpt_ssim are not defined in the visible
    # part of this file -- confirm where they come from.
    img = clip_01(img)
    img_gt = clip_01(img_gt)
    channel_scores = [sk_cpt_ssim(img[:, :, c], img_gt[:, :, c]) for c in range(3)]
    return sum(channel_scores) / 3.0
|
|
|
|
|
def cpt_ssim(img, img_gt):
    """Return ``sk_cpt_ssim`` of the two images after ``clip_01``."""
    # NOTE(review): clip_01 and sk_cpt_ssim are not defined in the visible
    # part of this file -- confirm where they come from.
    clipped = clip_01(img)
    clipped_gt = clip_01(img_gt)
    return sk_cpt_ssim(clipped, clipped_gt)
|
|
|
|
|
def decode_mask_to_onehot(mask, n_class):
    '''Expand an index mask into a float one-hot tensor.

    mask : BxWxH or WxH
    n_class : n
    return : BxnxWxH or nxWxH (matching the rank of the input)

    Indices outside ``[0, n_class)`` produce an all-zero vector.
    '''
    assert len(mask.shape) in [2, 3], "decode_mask_to_onehot error!"
    # Remember the input rank before adding the batch dimension; checking the
    # rank afterwards can never detect the un-batched case.
    is_single = len(mask.shape) == 2
    if is_single:
        mask = mask.unsqueeze(0)
    onehot = torch.zeros((mask.size(0), n_class, mask.size(1), mask.size(2))).to(mask.device)
    for i in range(n_class):
        onehot[:, i, ...] = mask == i
    if is_single:
        onehot = onehot.squeeze(0)
    return onehot
|
|
|
|
|
def encode_onehot_to_mask(onehot):
    '''Collapse a one-hot tensor to an index mask via arg-max over the class axis.

    onehot: tensor, BxnxWxH or nxWxH
    output: tensor, BxWxH or WxH
    '''
    rank = len(onehot.shape)
    assert rank in [3, 4], "encode_onehot_to_mask error!"
    # the class axis is the third from the end
    return torch.argmax(onehot, dim=rank - 3)
|
|
|
|
|
def _decode_scale_locations(first_reg, second_reg, prior_box_wh, grid_end, stride):
    """Turn the two regression maps of one scale into absolute box coordinates."""
    loc = first_reg + second_reg
    # even channels are scaled by the prior width, odd channels by the prior height
    loc[:, 0::2, ...] *= prior_box_wh[0]
    loc[:, 1::2, ...] *= prior_box_wh[1]
    cell_origin = get_coords_grid(grid_end, grid_end, 0, 0)
    cell_origin *= stride
    # (x, y) repeated twice so it lines up with the four regression channels
    cell_origin = torch.cat([cell_origin, cell_origin], dim=0)
    loc += cell_origin.to(loc.device)
    return loc


def decode(pred, target=None, *args, **kwargs):
    """Decode raw network outputs into fused per-image detections.

    Args:
        pred: big_cls_1(0), big_reg_1, small_cls_1(2), small_reg_1,
            big_cls_2(4), big_reg_2, small_cls_2(6), small_reg_2
        target: [[n,5], [n,5]] list of tensor, or None
        **kwargs: must provide ``phase`` and ``img_size``; for phase ``'od'``
            also ``prior_box_wh``, ``conf_thres``, ``iou_thres`` and
            ``conf_type``.

    Returns:
        For phase ``'od'`` a dict with ``"pred_all"`` (one array per image with
        rows ``label, score, box``) and ``"target_all"`` (targets as numpy
        arrays, or None when no target was given).
    """
    phase = kwargs['phase']
    img_size = kwargs['img_size']
    if phase != 'od':
        # NOTE(review): only 'od' is decoded; other phases yield None --
        # confirm this matches what callers expect.
        return None

    prior_box_wh = kwargs['prior_box_wh']
    conf_thres = kwargs['conf_thres']
    iou_thres = kwargs['iou_thres']
    conf_type = kwargs['conf_type']

    # probability of class 1 from the second-stage classification maps
    conf_32 = F.softmax(pred[4], dim=1)[:, 1, ...]
    conf_64 = F.softmax(pred[6], dim=1)[:, 1, ...]
    keep_32 = conf_32 > conf_thres
    keep_64 = conf_64 > conf_thres

    # 32x32 grid with stride 8 and 64x64 grid with stride 4
    loc_map_32 = _decode_scale_locations(pred[1], pred[5], prior_box_wh, grid_end=31, stride=8)
    loc_map_64 = _decode_scale_locations(pred[3], pred[7], prior_box_wh, grid_end=63, stride=4)

    max_candidates = 150
    pred_all = []
    for i in range(loc_map_32.size(0)):
        kept_scores = torch.cat((conf_32[i][keep_32[i]], conf_64[i][keep_64[i]]), dim=0)
        kept_boxes = torch.cat((loc_map_32[i].permute((1, 2, 0))[keep_32[i]],
                                loc_map_64[i].permute((1, 2, 0))[keep_64[i]]), dim=0)

        score_list = kept_scores.detach().cpu().numpy()
        boxes_list = kept_boxes.detach().cpu().numpy()
        # scale to fractions of the image size before fusion
        boxes_list[:, 0::2] /= img_size[0]
        boxes_list[:, 1::2] /= img_size[1]
        label_list = np.ones_like(score_list)

        # NOTE(review): this keeps the first 150 candidates in grid order,
        # not the 150 highest scores -- confirm that is intended.
        boxes_list = boxes_list[:max_candidates, :]
        score_list = score_list[:max_candidates]
        label_list = label_list[:max_candidates]
        # NOTE(review): weighted_boxes_fusion is not imported in the visible
        # part of this file -- confirm where it comes from.
        boxes, scores, labels = weighted_boxes_fusion([boxes_list], [score_list], [label_list], weights=None,
                                                      iou_thr=iou_thres, conf_type=conf_type)
        boxes[:, 0::2] *= img_size[0]
        boxes[:, 1::2] *= img_size[1]
        pred_boxes = np.concatenate((labels.reshape(-1, 1), scores.reshape(-1, 1), boxes), axis=1)
        pred_all.append(pred_boxes)

    if target is not None:
        target_all = [x.cpu().numpy() for x in target]
    else:
        target_all = None
    return {"pred_all": pred_all, "target_all": target_all}
|
|
|
|
|
|
|
def get_metrics(phase, pred, target):
    '''Compute per-class IoU, overall accuracy and F1 from one-hot maps.

    pred: logits, tensor, nBatch*nClass*W*H
    target: labels, tensor, nBatch*nClass*W*H

    For phase 'seg' the logits are first turned into a one-hot prediction by
    arg-max over the class dimension. Counts are summed over dimensions
    0, 2 and 3, so every returned tensor has one entry per class.
    '''
    if phase == 'seg':
        pred = decode_mask_to_onehot(torch.argmax(pred.detach(), dim=1), target.size(1))
    # NOTE(review): for any other phase ``pred`` is used as given -- confirm
    # that this is the intended behaviour.

    reduce_dims = (0, 2, 3)
    target_pos = target == 1
    pred_pos = pred == 1

    gt_pos_sum = torch.sum(target_pos, dim=reduce_dims)
    pred_pos_sum = torch.sum(pred_pos, dim=reduce_dims)
    true_pos_sum = torch.sum(target_pos * pred_pos, dim=reduce_dims)

    # the tiny epsilon keeps empty classes from dividing by zero
    precision = true_pos_sum / (pred_pos_sum + 1e-15)
    recall = true_pos_sum / (gt_pos_sum + 1e-15)
    IoU = true_pos_sum / (pred_pos_sum + gt_pos_sum - true_pos_sum + 1e-15)

    # mismatches = false positives + false negatives
    mismatches = pred_pos_sum + gt_pos_sum - 2 * true_pos_sum
    OA = 1 - mismatches / torch.sum(target >= 0, dim=reduce_dims)

    F1_score = 2 * precision * recall / (precision + recall + 1e-15)
    return IoU, OA, F1_score
|
|
|
|