Shawn001's picture
Upload 53 files
c2c125c
import math
import torch
import numpy as np
from megatron import get_args
def slidingcrops(img, mask):
    """Cut a batch of images and label masks into square sliding-window crops.

    The crop edge length is ``args.img_h`` (asserted equal to ``args.img_w``)
    and neighbouring windows start ``args.seg_stride`` pixels apart. Windows
    that overhang the bottom/right border are padded back to the full crop
    size so that all crops can be concatenated.

    Args:
        img: image batch, laid out as [b c h w].
        mask: label batch, laid out as [b h w].

    Returns:
        A 4-tuple ``(img_crops, mask_crops, slices_info, (h, w))``.
        ``img_crops`` / ``mask_crops`` hold the crops of every window
        concatenated along dim 0, one full batch per window, windows in
        row-major order. ``slices_info`` has one
        ``[sy, ey, sx, ex, sub_h, sub_w]`` entry per window: the window's
        start/end coordinates and the un-padded height/width actually cut
        from the input. When neither side exceeds the crop size the inputs
        are returned unchanged with a single full-image entry.

    Raises:
        AssertionError: if ``args.img_h != args.img_w``, if the input is
            smaller than the crop size, or if the stride exceeds the crop
            size when cropping is needed.
    """
    args = get_args()
    assert args.img_h == args.img_w
    crop_size = args.img_h
    stride = args.seg_stride
    ignore_index = args.ignore_index

    # Unpacking four values also enforces a 4-D input.
    _, _, h, w = img.shape
    assert h >= crop_size
    assert w >= crop_size

    # Nothing to slide over: the image is exactly one crop.
    if max(h, w) <= crop_size:
        return img, mask, [[0, h, 0, w, h, w]], (h, w)

    assert stride <= crop_size
    # ceil() adds one final window covering the remainder when the stride
    # does not evenly divide (size - crop_size).
    h_step_num = int(math.ceil((h - crop_size) / float(stride))) + 1
    w_step_num = int(math.ceil((w - crop_size) / float(stride))) + 1

    img_slices, mask_slices, slices_info = [], [], []
    for yy in range(h_step_num):
        for xx in range(w_step_num):
            sy, sx = yy * stride, xx * stride
            ey, ex = sy + crop_size, sx + crop_size
            # Slicing clamps at the border, so edge windows may come out
            # smaller than crop_size.
            img_sub = img[:, :, sy:ey, sx:ex]
            mask_sub = mask[:, sy:ey, sx:ex]

            sub_h, sub_w = img_sub.shape[2:]
            pad_h = max(crop_size - sub_h, 0)
            pad_w = max(crop_size - sub_w, 0)
            pad = (0, pad_w, 0, pad_h)  # (left, right, top, bottom)

            # NOTE(review): the image is filled with ignore_index, as in the
            # original code; a neutral pixel value may have been intended.
            # Left unchanged so the network input stays the same -- confirm.
            img_sub = torch.nn.functional.pad(img_sub, pad=pad, value=ignore_index)
            # Padded label pixels must carry ignore_index; the previous
            # default fill of 0 marked them as class 0.
            mask_sub = torch.nn.functional.pad(mask_sub, pad=pad, value=ignore_index)

            img_slices.append(img_sub)
            mask_slices.append(mask_sub)
            slices_info.append([sy, ey, sx, ex, sub_h, sub_w])

    return torch.cat(img_slices), torch.cat(mask_slices), slices_info, (h, w)
def slidingjoins(preds, probs, labels, slices_info, img_size):
    """Stitch per-crop predictions back into full-size prediction/label maps.

    The inputs are split along dim 0 into chunks of ``args.micro_batch_size``
    rows, one chunk per entry of ``slices_info``. Each chunk is pasted at
    its window position; where windows overlap, the prediction with the
    higher ``probs`` value wins and ties keep the value written earlier.

    Args:
        preds: per-pixel predictions for all crops, indexed as
            [n, crop_h, crop_w], crops stacked along dim 0 in
            ``slices_info`` order.
        probs: per-pixel scores matching ``preds`` (presumably the
            probability of the predicted class -- confirm with the caller).
        labels: per-pixel labels matching ``preds``.
        slices_info: list of ``[sy, ey, sx, ex, sub_h, sub_w]`` entries;
            only the start coordinates and the un-padded extent are used.
        img_size: ``(h, w)`` of the full image.

    Returns:
        ``(total_preds, total_labels)``, two int32 tensors of shape
        ``(micro_batch_size, h, w)`` on the device of ``preds``. If there is
        only one slice, ``preds`` and ``labels`` are returned unchanged.

    Raises:
        AssertionError: if the number of chunks differs from the number of
            slices, or a slice reaches outside ``img_size``.
    """
    args = get_args()
    num_slices = len(slices_info)
    if num_slices == 1:
        return preds, labels

    h, w = img_size
    split_size = args.micro_batch_size
    preds_split = torch.split(preds, split_size)
    probs_split = torch.split(probs, split_size)
    labels_split = torch.split(labels, split_size)
    assert(len(preds_split) == num_slices)

    # Allocate on the inputs' device instead of a hard-coded 'cuda' so the
    # join also works for CPU tensors.
    device = preds.device
    total_max_probs = torch.zeros((split_size, h, w), dtype=torch.float, device=device)
    total_preds = torch.zeros((split_size, h, w), dtype=torch.int, device=device)
    total_labels = torch.zeros((split_size, h, w), dtype=torch.int, device=device)

    for i, (sy, _ey, sx, _ex, sub_h, sub_w) in enumerate(slices_info):
        assert sy + sub_h <= h
        assert sx + sub_w <= w
        rows = slice(sy, sy + sub_h)
        cols = slice(sx, sx + sub_w)

        curr_max_probs = total_max_probs[:, rows, cols]
        curr_preds = total_preds[:, rows, cols]

        # Cut away the padding that was added to edge crops.
        local_max_probs = probs_split[i][:, :sub_h, :sub_w]
        local_preds = preds_split[i][:, :sub_h, :sub_w]

        # Decide per pixel before overwriting the running maximum, since
        # curr_max_probs is a view into total_max_probs.
        keep_current = curr_max_probs >= local_max_probs
        result_preds = torch.where(keep_current, curr_preds, local_preds)
        result_max_probs = torch.maximum(curr_max_probs, local_max_probs)

        total_max_probs[:, rows, cols] = result_max_probs
        total_preds[:, rows, cols] = result_preds
        # Copy the labels of every batch item; indexing row 0 only (as the
        # code previously did) broadcast the first item's labels to all rows.
        total_labels[:, rows, cols] = labels_split[i][:, :sub_h, :sub_w]

    return total_preds, total_labels