Shawn001's picture
Upload 53 files
c2c125c
# coding=utf-8
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Vision-segmentation (SETR) finetuning/evaluation."""
import torch
import torch.nn.functional as F
from functools import partial
from megatron import get_args, get_timers
from megatron import mpu, print_rank_0, print_rank_last
from tasks.vision.finetune_utils import finetune
from tasks.vision.finetune_utils import build_data_loader
from megatron.utils import average_losses_across_data_parallel_group
from megatron.schedules import get_forward_backward_func
from tasks.vision.segmentation.metrics import CFMatrix
from tasks.vision.segmentation.data import build_train_valid_datasets
from tasks.vision.segmentation.seg_models import SetrSegmentationModel
from tasks.vision.segmentation.utils import slidingcrops, slidingjoins
def segmentation():
    """Finetune and evaluate the SETR semantic-segmentation model.

    Defines the dataset/model providers, the training forward step with its
    class-weighted cross-entropy loss, the end-of-epoch IoU evaluation and the
    TensorBoard image dump as closures, then hands them to ``finetune``.
    """

    def train_valid_datasets_provider():
        """Build train and validation dataset."""
        args = get_args()
        train_ds, valid_ds = build_train_valid_datasets(
            data_path=args.data_path,
            image_size=(args.img_h, args.img_w)
        )
        return train_ds, valid_ds

    def model_provider(pre_process=True, post_process=True):
        """Build the model."""
        args = get_args()
        return SetrSegmentationModel(num_classes=args.num_classes,
                                     pre_process=pre_process,
                                     post_process=post_process)

    def process_batch(batch):
        """Move the first two entries of ``batch`` to the GPU as contiguous tensors.

        Returns:
            ``(images, masks)`` taken from ``batch[0]`` and ``batch[1]``.
        """
        images = batch[0].cuda().contiguous()
        masks = batch[1].cuda().contiguous()
        return images, masks

    def calculate_weight(masks, num_classes):
        """Compute per-class loss weights from the label histogram of ``masks``.

        A class that occurs with relative frequency ``f`` gets weight
        ``2.0 - f``; a class that does not occur gets weight ``1.0``.
        ``torch.histc`` ignores values outside ``[0, num_classes]``, so labels
        beyond that range do not contribute to the histogram.

        Returns:
            A float tensor with ``num_classes`` entries.
        """
        # Cast to float first: histc rejects integer inputs on some backends
        # (e.g. CPU), and the label values are small enough to be exact.
        bins = torch.histc(masks.float(), bins=num_classes,
                           min=0.0, max=num_classes)
        # Avoid 0/0 (NaN weights) when no pixel carries a counted label.
        total = bins.sum().clamp(min=1.0)
        hist_norm = bins / total
        return ((bins != 0).float() * (1. - hist_norm)) + 1.0

    def cross_entropy_loss_func(images, masks, output_tensor, non_loss_data=False):
        """Class-weighted cross-entropy between ``output_tensor`` and ``masks``.

        Args:
            images: Input images; only used for the visualization output.
            masks: Target label tensor.
            output_tensor: Model logits with the class scores on dim 1.
            non_loss_data: When True, return visualization data instead of
                the logging dictionary.

        Returns:
            ``(loss, {'lm loss': averaged_loss})`` when ``non_loss_data`` is
            False. Otherwise ``(visualization, loss)``, where the
            visualization concatenates the images, the color-coded predicted
            mask and the color-coded target mask along dim 2.
        """
        args = get_args()
        weight = calculate_weight(masks, args.num_classes)
        logits = output_tensor.contiguous().float()
        loss = F.cross_entropy(logits, masks, weight=weight,
                               ignore_index=args.ignore_index)

        if not non_loss_data:
            # Reduce loss for logging.
            averaged_loss = average_losses_across_data_parallel_group([loss])
            return loss, {'lm loss': averaged_loss[0]}

        # Map class indices to colors, then move the color channel to dim 1.
        color_table = args.color_table
        seg_mask = logits.argmax(dim=1)
        output_mask = F.embedding(seg_mask, color_table).permute(0, 3, 1, 2)
        gt_mask = F.embedding(masks, color_table).permute(0, 3, 1, 2)
        return torch.cat((images, output_mask, gt_mask), dim=2), loss

    def _cross_entropy_forward_step(batch, model):
        """Simple forward step with cross-entropy loss.

        ``batch`` may be a generator yielding batches or an already
        materialized batch. In eval mode the inputs are cut into sliding
        crops and run through the model in chunks of ``micro_batch_size``.

        Returns:
            ``(output_tensor, loss_func)`` where ``loss_func`` is
            ``cross_entropy_loss_func`` bound to the images and masks.
        """
        import types

        args = get_args()
        timers = get_timers()

        # Get the batch.
        timers("batch generator").start()
        if isinstance(batch, types.GeneratorType):
            batch_ = next(batch)
        else:
            batch_ = batch
        images, masks = process_batch(batch_)
        timers("batch generator").stop()

        # Forward model.
        if model.training:
            output_tensor = model(images)
        else:
            images, masks, _, _ = slidingcrops(images, masks)
            output_tensor = torch.cat(
                [model(image)
                 for image in torch.split(images, args.micro_batch_size)])

        return output_tensor, partial(cross_entropy_loss_func, images, masks)

    def calculate_correct_answers(model, dataloader, epoch):
        """Evaluate ``model`` on ``dataloader`` and compute IoU statistics.

        Every module in ``model`` is switched to eval mode for the pass and
        back to train mode afterwards. The per-class counts returned by
        ``CFMatrix`` are summed over all batches and all-reduced across the
        data-parallel group.

        Returns:
            ``(iou, miou)`` on the last pipeline stage, where ``iou`` is the
            list of per-class IoU values and ``miou`` their mean over the
            non-NaN entries. ``(None, None)`` on every other stage.
        """
        forward_backward_func = get_forward_backward_func()
        for m in model:
            m.eval()

        def loss_func(labels, slices_info, img_size, output_tensor):
            """Join the crop predictions and return their confusion counts."""
            args = get_args()
            # Compute the correct answers.
            probs = output_tensor.contiguous().float().softmax(dim=1)
            max_probs, preds = torch.max(probs, 1)
            preds = preds.int()
            preds, labels = slidingjoins(preds, max_probs, labels,
                                         slices_info, img_size)
            _, performs = CFMatrix()(preds, labels, args.ignore_index)
            return 0, {'performs': performs}

        def correct_answers_forward_step(batch, model):
            """Forward the sliding crops of one batch through ``model``."""
            args = get_args()
            try:
                batch_ = next(batch)
            except TypeError:
                # ``batch`` is an already materialized batch, not an iterator.
                batch_ = batch
            images, labels = process_batch(batch_)

            assert not model.training
            images, labels, slices_info, img_size = slidingcrops(images, labels)
            # Forward model.
            output_tensor = torch.cat(
                [model(image)
                 for image in torch.split(images, args.micro_batch_size)])

            return output_tensor, partial(loss_func, labels, slices_info, img_size)

        with torch.no_grad():
            # For all the batches in the dataset.
            performs = None
            for batch in dataloader:
                loss_dicts = forward_backward_func(correct_answers_forward_step,
                                                   batch, model,
                                                   optimizer=None,
                                                   timers=None,
                                                   forward_only=True)
                for loss_dict in loss_dicts:
                    if performs is None:
                        performs = loss_dict['performs']
                    else:
                        performs += loss_dict['performs']

        for m in model:
            m.train()

        # Only the last pipeline stage has results to reduce; return an
        # explicit placeholder pair elsewhere so callers can always unpack.
        if not mpu.is_pipeline_last_stage():
            return None, None

        # Reduce.
        torch.distributed.all_reduce(performs,
                                     group=mpu.get_data_parallel_group())
        # performs[int(ch), :] = [nb_tp, nb_fp, nb_tn, nb_fn]
        true_positive = performs[:, 0]
        false_positive = performs[:, 1]
        false_negative = performs[:, 3]

        iou = true_positive / (true_positive + false_positive + false_negative)
        # NaN entries are left out of the mean.
        miou = iou[~torch.isnan(iou)].mean()

        return iou.tolist(), miou.item()

    def accuracy_func_provider():
        """Provide function that calculates accuracies."""
        args = get_args()
        _, valid_ds = build_train_valid_datasets(
            data_path=args.data_path,
            image_size=(args.img_h, args.img_w)
        )
        dataloader = build_data_loader(
            valid_ds,
            args.micro_batch_size,
            num_workers=args.num_workers,
            drop_last=(mpu.get_data_parallel_world_size() > 1),
            shuffle=False
        )

        def metrics_func(model, epoch):
            """Compute and print the validation IoU metrics for ``epoch``."""
            print_rank_0("calculating metrics ...")
            iou, miou = calculate_correct_answers(model, dataloader, epoch)
            if miou is None:
                # No metrics are produced on this pipeline stage.
                return
            print_rank_last(
                " >> |epoch: {}| overall: iou = {},"
                "miou = {:.4f} %".format(epoch, iou, miou*100.0)
            )
        return metrics_func

    def dump_output_data(data, iteration, writer):
        """Write the collected visualization tensors to ``writer``.

        Args:
            data: Iterable of ``(image_tensor, loss)`` pairs.
            iteration: Training iteration, recorded as the global step.
            writer: Object exposing ``add_images`` (e.g. a TensorBoard
                ``SummaryWriter``).
        """
        for output_tb, _ in data:
            writer.add_images("image-outputseg-realseg", output_tb,
                              global_step=iteration, walltime=None,
                              dataformats='NCHW')

    # Finetune/evaluate.
    finetune(
        train_valid_datasets_provider,
        model_provider,
        forward_step=_cross_entropy_forward_step,
        process_non_loss_data_func=dump_output_data,
        end_of_epoch_callback_provider=accuracy_func_provider,
    )
def main():
    """Entry point: run the segmentation finetuning/evaluation task."""
    segmentation()