|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Vision-classification finetuning/evaluation.""" |
|
|
|
import torch |
|
import torch.nn.functional as F |
|
from functools import partial |
|
from megatron import get_args, get_timers |
|
from megatron import mpu, print_rank_0, print_rank_last |
|
from tasks.vision.finetune_utils import finetune |
|
from tasks.vision.finetune_utils import build_data_loader |
|
from megatron.utils import average_losses_across_data_parallel_group |
|
from megatron.schedules import get_forward_backward_func |
|
from tasks.vision.segmentation.metrics import CFMatrix |
|
from tasks.vision.segmentation.data import build_train_valid_datasets |
|
from tasks.vision.segmentation.seg_models import SetrSegmentationModel |
|
from tasks.vision.segmentation.utils import slidingcrops, slidingjoins |
|
|
|
def segmentation(): |
|
def train_valid_datasets_provider(): |
|
"""Build train and validation dataset.""" |
|
args = get_args() |
|
|
|
train_ds, valid_ds = build_train_valid_datasets( |
|
data_path=args.data_path, |
|
image_size=(args.img_h, args.img_w) |
|
|
|
) |
|
return train_ds, valid_ds |
|
|
|
def model_provider(pre_process=True, post_process=True): |
|
"""Build the model.""" |
|
args = get_args() |
|
|
|
return SetrSegmentationModel(num_classes=args.num_classes, |
|
pre_process=pre_process, |
|
post_process=post_process) |
|
|
|
def process_batch(batch): |
|
"""Process batch and produce inputs for the model.""" |
|
images = batch[0].cuda().contiguous() |
|
masks = batch[1].cuda().contiguous() |
|
return images, masks |
|
|
|
def calculate_weight(masks, num_classes): |
|
bins = torch.histc(masks, bins=num_classes, min=0.0, max=num_classes) |
|
hist_norm = bins.float()/bins.sum() |
|
hist = ((bins != 0).float() * (1. - hist_norm)) + 1.0 |
|
return hist |
|
|
|
def cross_entropy_loss_func(images, masks, output_tensor, non_loss_data=False): |
|
args = get_args() |
|
ignore_index = args.ignore_index |
|
color_table = args.color_table |
|
weight = calculate_weight(masks, args.num_classes) |
|
logits = output_tensor.contiguous().float() |
|
loss = F.cross_entropy(logits, masks, weight=weight, ignore_index=ignore_index) |
|
|
|
if not non_loss_data: |
|
|
|
averaged_loss = average_losses_across_data_parallel_group([loss]) |
|
|
|
return loss, {'lm loss': averaged_loss[0]} |
|
else: |
|
seg_mask = logits.argmax(dim=1) |
|
output_mask = F.embedding(seg_mask, color_table).permute(0, 3, 1, 2) |
|
gt_mask = F.embedding(masks, color_table).permute(0, 3, 1, 2) |
|
return torch.cat((images, output_mask, gt_mask), dim=2), loss |
|
|
|
def _cross_entropy_forward_step(batch, model): |
|
"""Simple forward step with cross-entropy loss.""" |
|
args = get_args() |
|
timers = get_timers() |
|
|
|
|
|
timers("batch generator").start() |
|
import types |
|
if isinstance(batch, types.GeneratorType): |
|
batch_ = next(batch) |
|
else: |
|
batch_ = batch |
|
images, masks = process_batch(batch_) |
|
timers("batch generator").stop() |
|
|
|
|
|
if not model.training: |
|
images, masks, _, _ = slidingcrops(images, masks) |
|
|
|
|
|
if not model.training: |
|
output_tensor = torch.cat([model(image) for image in torch.split(images, args.micro_batch_size)]) |
|
else: |
|
output_tensor = model(images) |
|
|
|
return output_tensor, partial(cross_entropy_loss_func, images, masks) |
|
|
|
def calculate_correct_answers(model, dataloader, epoch): |
|
"""Calculate correct over total answers""" |
|
|
|
forward_backward_func = get_forward_backward_func() |
|
for m in model: |
|
m.eval() |
|
|
|
def loss_func(labels, slices_info, img_size, output_tensor): |
|
args = get_args() |
|
logits = output_tensor |
|
|
|
loss_dict = {} |
|
|
|
probs = logits.contiguous().float().softmax(dim=1) |
|
max_probs, preds = torch.max(probs, 1) |
|
preds = preds.int() |
|
preds, labels = slidingjoins(preds, max_probs, labels, slices_info, img_size) |
|
_, performs = CFMatrix()(preds, labels, args.ignore_index) |
|
|
|
loss_dict['performs'] = performs |
|
return 0, loss_dict |
|
|
|
|
|
def correct_answers_forward_step(batch, model): |
|
args = get_args() |
|
try: |
|
batch_ = next(batch) |
|
except BaseException: |
|
batch_ = batch |
|
images, labels = process_batch(batch_) |
|
|
|
assert not model.training |
|
images, labels, slices_info, img_size = slidingcrops(images, labels) |
|
|
|
output_tensor = torch.cat([model(image) for image in torch.split(images, args.micro_batch_size)]) |
|
|
|
return output_tensor, partial(loss_func, labels, slices_info, img_size) |
|
|
|
with torch.no_grad(): |
|
|
|
performs = None |
|
for _, batch in enumerate(dataloader): |
|
loss_dicts = forward_backward_func(correct_answers_forward_step, |
|
batch, model, |
|
optimizer=None, |
|
timers=None, |
|
forward_only=True) |
|
for loss_dict in loss_dicts: |
|
if performs is None: |
|
performs = loss_dict['performs'] |
|
else: |
|
performs += loss_dict['performs'] |
|
|
|
for m in model: |
|
m.train() |
|
|
|
if mpu.is_pipeline_last_stage(): |
|
torch.distributed.all_reduce(performs, |
|
group=mpu.get_data_parallel_group()) |
|
|
|
|
|
true_positive = performs[:, 0] |
|
false_positive = performs[:, 1] |
|
false_negative = performs[:, 3] |
|
|
|
iou = true_positive / (true_positive + false_positive + false_negative) |
|
miou = iou[~torch.isnan(iou)].mean() |
|
|
|
return iou.tolist(), miou.item() |
|
|
|
def accuracy_func_provider(): |
|
"""Provide function that calculates accuracies.""" |
|
args = get_args() |
|
|
|
train_ds, valid_ds = build_train_valid_datasets( |
|
data_path=args.data_path, |
|
image_size=(args.img_h, args.img_w) |
|
) |
|
dataloader = build_data_loader( |
|
valid_ds, |
|
args.micro_batch_size, |
|
num_workers=args.num_workers, |
|
drop_last=(mpu.get_data_parallel_world_size() > 1), |
|
shuffle=False |
|
) |
|
|
|
def metrics_func(model, epoch): |
|
print_rank_0("calculating metrics ...") |
|
iou, miou = calculate_correct_answers(model, dataloader, epoch) |
|
print_rank_last( |
|
" >> |epoch: {}| overall: iou = {}," |
|
"miou = {:.4f} %".format(epoch, iou, miou*100.0) |
|
) |
|
return metrics_func |
|
|
|
def dump_output_data(data, iteration, writer): |
|
for (output_tb, loss) in data: |
|
|
|
|
|
writer.add_images("image-outputseg-realseg", output_tb, |
|
global_step=None, walltime=None, |
|
dataformats='NCHW') |
|
|
|
"""Finetune/evaluate.""" |
|
finetune( |
|
train_valid_datasets_provider, |
|
model_provider, |
|
forward_step=_cross_entropy_forward_step, |
|
process_non_loss_data_func=dump_output_data, |
|
end_of_epoch_callback_provider=accuracy_func_provider, |
|
) |
|
|
|
|
|
def main(): |
|
segmentation() |
|
|
|
|