Spaces:
Sleeping
Sleeping
from tqdm import tqdm | |
from utils.metrics import metric, precision_recall_f1score, det_error_metric | |
import torch | |
import numpy as np | |
from vis.visualize import gen_render | |
def trainer(epoch, train_loader, solver, hparams, compute_metrics=False):
    """Run one training epoch and return the results of its final step.

    Args:
        epoch: Index of the current epoch; only used in progress messages.
        train_loader: Sized iterable of training batches (``len()`` is called
            on it to size the progress bar).
        solver: Object whose ``optimize(batch)`` is called once per batch and
            returns a ``(losses, output)`` pair.
        hparams: Config object; ``hparams.TRAINING.NUM_EPOCHS`` is read for
            the progress messages.
        compute_metrics: Accepted for interface compatibility; not used here.

    Returns:
        The ``(losses, output)`` pair returned by ``solver.optimize`` for the
        last batch of the epoch.

    Raises:
        ValueError: If ``train_loader`` yields no batches, since there is
            then no result to return.
    """
    total_epochs = hparams.TRAINING.NUM_EPOCHS
    print('Training Epoch {}/{}'.format(epoch, total_epochs))

    progress = tqdm(
        train_loader,
        total=len(train_loader),
        leave=False,
        desc=f'Training Epoch: {epoch}/{total_epochs}',
    )

    losses = output = None
    ran_step = False
    for batch in progress:
        losses, output = solver.optimize(batch)
        ran_step = True

    # Without this guard an empty loader would surface as an opaque
    # UnboundLocalError on the return statement.
    if not ran_step:
        raise ValueError('train_loader yielded no batches; nothing was trained')

    return losses, output
def evaluator(val_loader, solver, hparams, epoch=0, dataset_name='Unknown', normalize=True, return_dict=False):
    """Evaluate ``solver`` on ``val_loader`` and average the contact metrics.

    For every batch, ``solver.evaluate(batch)`` is called and its ``output``
    dict is scored with ``precision_recall_f1score`` and ``det_error_metric``
    (plus ``metric`` on the semantic/part masks when
    ``hparams.TRAINING.CONTEXT`` is truthy). Metric values are stored per
    sample and averaged over the number of samples actually evaluated.

    Args:
        val_loader: Loader exposing ``.dataset`` and ``len()``; each batch is
            a mapping whose ``'img'`` entry has the batch size as its first
            dimension.
        solver: Object whose ``evaluate(batch)`` returns
            ``(losses, output, time_taken)``. The losses are not part of the
            reported metrics and are ignored here.
        hparams: Config object; ``TRAINING.NUM_EPOCHS``, ``TRAINING.CONTEXT``
            and ``VALIDATION.SUMMARY_STEPS`` are read.
        epoch: Current epoch, used only in the progress-bar description.
        dataset_name: Name shown (capitalised) in the progress-bar description.
        normalize: Forwarded unchanged to ``gen_render``.
        return_dict: If true, return ``(eval_dict, mean_time)``; otherwise
            return only the mean contact F1 score.

    Returns:
        ``eval_dict['cont_f1']`` by default. With ``return_dict=True``, a
        tuple ``(eval_dict, mean_time)`` where ``eval_dict`` holds the mean of
        ``cont_precision``, ``cont_recall``, ``cont_f1``, ``fp_geo_err`` and
        ``fn_geo_err`` (plus ``sem_iou`` and ``part_iou`` in context mode) and
        the list of rendered ``images``, and ``mean_time`` is the accumulated
        ``time_taken`` divided by the number of evaluated samples.

    Raises:
        AssertionError: If any entry of ``output['has_contact_3d']`` is 0.
    """
    total_epochs = hparams.TRAINING.NUM_EPOCHS
    use_context = hparams.TRAINING.CONTEXT

    dataset_size = len(val_loader.dataset)
    print(f'Dataset size: {dataset_size}')

    metric_names = ['cont_precision', 'cont_recall', 'cont_f1', 'fp_geo_err', 'fn_geo_err']
    if use_context:
        metric_names += ['sem_iou', 'part_iou']
    # One float64 slot per dataset sample and metric.
    per_sample = {name: np.zeros(dataset_size) for name in metric_names}

    total_time = 0
    rend_images = []
    # Running write offset. Tracking it explicitly (instead of deriving it from
    # step * val_loader.batch_size) keeps the bookkeeping correct when
    # batch_size is None or batches have uneven sizes.
    num_seen = 0

    iterator = tqdm(
        enumerate(val_loader),
        total=len(val_loader),
        leave=False,
        desc=f'Evaluating {dataset_name.capitalize()} Epoch: {epoch}/{total_epochs}',
    )
    for step, batch in iterator:
        curr_batch_size = batch['img'].shape[0]
        _, output, time_taken = solver.evaluate(batch)

        contact_labels_3d = output['contact_labels_3d_gt']
        contact_labels_3d_pred = output['contact_labels_3d_pred']
        # Every sample in the batch must carry a non-zero has_contact_3d flag.
        assert not torch.any(output['has_contact_3d'] == 0), 'has_contact_3d tensor has 0 values'

        cont_pre, cont_rec, cont_f1 = precision_recall_f1score(contact_labels_3d, contact_labels_3d_pred)
        fp_geo_err, fn_geo_err = det_error_metric(contact_labels_3d_pred, contact_labels_3d)
        batch_metrics = {
            'cont_precision': cont_pre,
            'cont_recall': cont_rec,
            'cont_f1': cont_f1,
            'fp_geo_err': fp_geo_err,
            'fn_geo_err': fn_geo_err,
        }
        if use_context:
            batch_metrics['sem_iou'] = metric(output['sem_mask_gt'], output['sem_mask_pred'])
            batch_metrics['part_iou'] = metric(output['part_mask_gt'], output['part_mask_pred'])

        batch_slice = slice(num_seen, num_seen + curr_batch_size)
        for name, values in batch_metrics.items():
            per_sample[name][batch_slice] = values.cpu().numpy()
        num_seen += curr_batch_size
        total_time += time_taken

        # Render a visualisation every SUMMARY_STEPS steps (context mode only).
        if use_context and step % hparams.VALIDATION.SUMMARY_STEPS == 0:
            rend_images.append(gen_render(output, normalize))

    # Average over the samples that were actually evaluated so that skipped
    # samples (e.g. drop_last=True) do not drag the means towards zero. If
    # nothing was evaluated all sums are zero; dividing by 1 then reports 0
    # instead of raising ZeroDivisionError.
    denom = max(num_seen, 1)
    eval_dict = {name: np.sum(values[:num_seen]) / denom for name, values in per_sample.items()}
    eval_dict['images'] = rend_images
    total_time /= denom

    if return_dict:
        return eval_dict, total_time
    return eval_dict['cont_f1']