import math

import torch
import numpy as np

from megatron import get_args


def slidingcrops(img, mask):
    # img: [b c h w]
    # mask: [b h w]
    args = get_args()
    assert args.img_h == args.img_w
    crop_size = args.img_h
    stride = args.seg_stride
    ignore_index = args.ignore_index
    n, c, h, w = img.shape
    assert h >= crop_size
    assert w >= crop_size

    long_size = max(h, w)
    img_slices, mask_slices, slices_info = [], [], []
    if long_size > crop_size:
        # Slide a crop_size x crop_size window over the image with the given
        # stride; windows that run past the border are padded back up to
        # crop_size below.
        assert stride <= crop_size
        h_step_num = int(math.ceil((h - crop_size) / float(stride))) + 1
        w_step_num = int(math.ceil((w - crop_size) / float(stride))) + 1
        for yy in range(h_step_num):
            for xx in range(w_step_num):
                sy, sx = yy * stride, xx * stride
                ey, ex = sy + crop_size, sx + crop_size
                img_sub = img[:, :, sy:ey, sx:ex]
                mask_sub = mask[:, sy:ey, sx:ex]

                # Pad border crops up to crop_size; the valid extent is
                # recorded as (sub_h, sub_w) so the padded region can be
                # trimmed again when stitching in slidingjoins.
                sub_h, sub_w = img_sub.shape[2:]
                pad_h = max(crop_size - sub_h, 0)
                pad_w = max(crop_size - sub_w, 0)
                img_sub = torch.nn.functional.pad(
                    img_sub, pad=(0, pad_w, 0, pad_h), value=ignore_index)
                mask_sub = torch.nn.functional.pad(
                    mask_sub, pad=(0, pad_w, 0, pad_h))

                img_slices.append(img_sub)
                mask_slices.append(mask_sub)
                slices_info.append([sy, ey, sx, ex, sub_h, sub_w])

        return torch.cat(img_slices), torch.cat(mask_slices), slices_info, (h, w)
    else:
        return img, mask, [[0, h, 0, w, h, w]], (h, w)


def slidingjoins(preds, probs, labels, slices_info, img_size):
    args = get_args()
    num_slices = len(slices_info)

    if num_slices == 1:
        return preds, labels

    h, w = img_size
    split_size = args.micro_batch_size

    # preds/probs/labels are concatenations of per-crop batches, one chunk of
    # split_size samples per crop produced by slidingcrops.
    preds_split = torch.split(preds, split_size)
    probs_split = torch.split(probs, split_size)
    labels_split = torch.split(labels, split_size)
    assert len(preds_split) == num_slices

    total_max_probs = torch.zeros((split_size, h, w), dtype=torch.float, device='cuda')
    total_preds = torch.zeros((split_size, h, w), dtype=torch.int, device='cuda')
    total_labels = torch.zeros((split_size, h, w), dtype=torch.int, device='cuda')

    for i in range(num_slices):
        sy, ey, sx, ex, sub_h, sub_w = slices_info[i]
        assert sy + sub_h <= h
        assert sx + sub_w <= w
        curr_max_probs = total_max_probs[:, sy:sy + sub_h, sx:sx + sub_w]
        curr_preds = total_preds[:, sy:sy + sub_h, sx:sx + sub_w]

        # Drop the padding added in slidingcrops before stitching.
        local_max_probs = probs_split[i][:, :sub_h, :sub_w]
        local_preds = preds_split[i][:, :sub_h, :sub_w]

        # Where crops overlap, keep the prediction with the higher confidence.
        result_max_probs = torch.maximum(curr_max_probs, local_max_probs)
        result_preds = torch.where(curr_max_probs >= local_max_probs,
                                   curr_preds, local_preds)

        total_max_probs[:, sy:sy + sub_h, sx:sx + sub_w] = result_max_probs
        total_preds[:, sy:sy + sub_h, sx:sx + sub_w] = result_preds
        total_labels[:, sy:sy + sub_h, sx:sx + sub_w] = labels_split[i][:, :sub_h, :sub_w]

    return total_preds, total_labels
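

# Minimal usage sketch showing how slidingcrops and slidingjoins are meant to
# pair up in a sliding-window evaluation loop. `model`, `images`, and `masks`
# are hypothetical placeholders, and Megatron's argument parsing/initialization
# is assumed to have run so that get_args() returns populated args, with the
# batch size of `images` equal to args.micro_batch_size.
def sliding_window_inference_sketch(model, images, masks):
    args = get_args()
    # Cut the full-resolution batch into crop_size x crop_size tiles.
    crops, crop_masks, slices_info, img_size = slidingcrops(images, masks)

    probs_list, preds_list = [], []
    # Each chunk of micro_batch_size crops corresponds to one tile position.
    for start in range(0, crops.shape[0], args.micro_batch_size):
        logits = model(crops[start:start + args.micro_batch_size])  # [b, num_classes, h, w]
        # Per-pixel confidence and predicted class for this tile.
        prob, pred = torch.max(torch.softmax(logits, dim=1), dim=1)
        probs_list.append(prob)
        preds_list.append(pred)

    # Stitch tile-level predictions back to full resolution, keeping the most
    # confident prediction wherever tiles overlap.
    return slidingjoins(torch.cat(preds_list), torch.cat(probs_list),
                        crop_masks, slices_info, img_size)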