|
import math |
|
import torch |
|
import numpy as np |
|
from megatron import get_args |
|
|
|
def slidingcrops(img, mask, *, crop_size=None, stride=None, ignore_index=None):
    """Cut a batch of images and masks into square sliding-window crops.

    Windows of ``crop_size`` x ``crop_size`` are taken every ``stride`` pixels,
    iterating rows first and columns second. A window that runs past the
    bottom/right border is clipped by the slicing and then padded back to the
    full crop size: the image with zeros and the mask with ``ignore_index``.

    Args:
        img: 4-D tensor unpacked as ``(n, c, h, w)``.
        mask: tensor sliced as ``(n, h, w)``; presumably the per-pixel labels
            that go with ``img`` (confirm against the caller).
        crop_size: side length of the square window. When omitted it is read
            from ``get_args().img_h`` (which must equal ``img_w``).
        stride: step between window origins. When omitted it is read from
            ``get_args().seg_stride``.
        ignore_index: fill value for padded mask pixels. When omitted it is
            read from ``get_args().ignore_index``.

    Returns:
        A tuple ``(img_crops, mask_crops, slices_info, (h, w))``. The crops of
        all windows are concatenated along dim 0 in window order. Each entry of
        ``slices_info`` is ``[sy, ey, sx, ex, sub_h, sub_w]``: the unclipped
        window bounds followed by the valid (un-padded) extent of that crop.
        If neither side of the input exceeds ``crop_size`` the inputs are
        returned as they are, with a single full-image entry.

    Raises:
        AssertionError: if the configured image height and width differ, if
            the input is smaller than ``crop_size``, or if ``stride`` is not in
            ``(0, crop_size]`` when cropping is required.
    """
    # Only touch the global Megatron args for values the caller did not supply.
    if crop_size is None or stride is None or ignore_index is None:
        args = get_args()
        if crop_size is None:
            assert args.img_h == args.img_w, "sliding crops must be square"
            crop_size = args.img_h
        if stride is None:
            stride = args.seg_stride
        if ignore_index is None:
            ignore_index = args.ignore_index

    # Unpacking four values also checks that img is 4-D.
    _, _, h, w = img.shape
    assert h >= crop_size, "image height is smaller than the crop size"
    assert w >= crop_size, "image width is smaller than the crop size"

    # Input already fits in one window: nothing to slice or pad.
    if max(h, w) <= crop_size:
        return img, mask, [[0, h, 0, w, h, w]], (h, w)

    # A larger stride would leave gaps between windows; a non-positive one
    # would never advance.
    assert 0 < stride <= crop_size, "stride must be in (0, crop_size]"
    h_step_num = math.ceil((h - crop_size) / stride) + 1
    w_step_num = math.ceil((w - crop_size) / stride) + 1

    img_slices, mask_slices, slices_info = [], [], []
    for yy in range(h_step_num):
        sy = yy * stride
        ey = sy + crop_size
        for xx in range(w_step_num):
            sx = xx * stride
            ex = sx + crop_size
            # Slicing clips at the border, so the last row/column of windows
            # may come out smaller than crop_size.
            img_sub = img[:, :, sy:ey, sx:ex]
            mask_sub = mask[:, sy:ey, sx:ex]

            sub_h, sub_w = img_sub.shape[2:]
            pad_h = max(crop_size - sub_h, 0)
            pad_w = max(crop_size - sub_w, 0)
            if pad_h or pad_w:
                # Pad on the bottom/right only. The image gets zeros and the
                # mask gets ignore_index, so padded pixels are not mistaken
                # for valid class-0 labels.
                padding = (0, pad_w, 0, pad_h)
                img_sub = torch.nn.functional.pad(img_sub, pad=padding)
                mask_sub = torch.nn.functional.pad(
                    mask_sub, pad=padding, value=ignore_index
                )

            img_slices.append(img_sub)
            mask_slices.append(mask_sub)
            slices_info.append([sy, ey, sx, ex, sub_h, sub_w])

    return torch.cat(img_slices), torch.cat(mask_slices), slices_info, (h, w)
|
|
|
|
|
def slidingjoins(preds, probs, labels, slices_info, img_size, *, split_size=None):
    """Stitch per-crop predictions and labels back into full-size maps.

    The inputs are split along dim 0 into chunks of ``split_size`` rows, one
    chunk per entry of ``slices_info`` (presumably the layout produced by
    ``slidingcrops``; confirm against the caller). Where crops overlap, the
    prediction with the highest probability is kept; on ties the value that
    is already stored is kept.

    Args:
        preds: tensor indexed as ``(rows, crop_h, crop_w)`` holding the
            predicted class per pixel for every crop.
        probs: tensor with the same layout holding the confidence of each
            prediction.
        labels: tensor with the same layout holding the reference labels.
        slices_info: list of ``[sy, ey, sx, ex, sub_h, sub_w]`` entries; only
            the origin ``(sy, sx)`` and valid extent ``(sub_h, sub_w)`` are
            used, so padded pixels of a crop are discarded.
        img_size: ``(h, w)`` of the stitched output.
        split_size: number of rows per crop chunk. When omitted it is read
            from ``get_args().micro_batch_size``.

    Returns:
        ``(total_preds, total_labels)``, both of shape ``(split_size, h, w)``
        and dtype ``torch.int``, on the device of ``preds``. With a single
        slice, ``preds`` and ``labels`` are returned unchanged.

    Raises:
        AssertionError: if the number of chunks does not match
            ``len(slices_info)`` or a slice does not fit inside ``img_size``.
    """
    num_slices = len(slices_info)

    # A single slice means the image was never cropped: nothing to stitch.
    if num_slices == 1:
        return preds, labels

    if split_size is None:
        split_size = get_args().micro_batch_size

    h, w = img_size

    preds_split = torch.split(preds, split_size)
    probs_split = torch.split(probs, split_size)
    labels_split = torch.split(labels, split_size)

    assert len(preds_split) == num_slices
    assert len(probs_split) == num_slices
    assert len(labels_split) == num_slices

    # Allocate next to the inputs instead of assuming the current CUDA device.
    device = preds.device
    total_max_probs = torch.zeros((split_size, h, w), dtype=torch.float, device=device)
    total_preds = torch.zeros((split_size, h, w), dtype=torch.int, device=device)
    total_labels = torch.zeros((split_size, h, w), dtype=torch.int, device=device)

    for info, crop_preds, crop_probs, crop_labels in zip(
        slices_info, preds_split, probs_split, labels_split
    ):
        sy, _, sx, _, sub_h, sub_w = info
        assert sy + sub_h <= h
        assert sx + sub_w <= w
        rows = slice(sy, sy + sub_h)
        cols = slice(sx, sx + sub_w)

        curr_max_probs = total_max_probs[:, rows, cols]
        curr_preds = total_preds[:, rows, cols]

        # Drop the padded bottom/right part of the crop.
        local_max_probs = crop_probs[:, :sub_h, :sub_w]
        # Match the output dtype so torch.where never sees mixed int types.
        local_preds = crop_preds[:, :sub_h, :sub_w].to(total_preds.dtype)

        # Compute both results before writing: the curr_* tensors are views
        # into the buffers that are about to be updated.
        result_max_probs = torch.maximum(curr_max_probs, local_max_probs)
        result_preds = torch.where(
            curr_max_probs >= local_max_probs, curr_preds, local_preds
        )

        total_max_probs[:, rows, cols] = result_max_probs
        total_preds[:, rows, cols] = result_preds
        # Copy the labels of every row in the chunk, not just the first one,
        # so each batch entry keeps its own labels.
        total_labels[:, rows, cols] = crop_labels[:, :sub_h, :sub_w]

    return total_preds, total_labels
|
|
|
|