File size: 4,656 Bytes
99a05f0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
from tqdm import tqdm
from utils.metrics import metric, precision_recall_f1score, det_error_metric
import torch
import numpy as np
from vis.visualize import gen_render


def trainer(epoch, train_loader, solver, hparams, compute_metrics=False):
    """Run one training epoch by calling ``solver.optimize`` on every batch.

    Args:
        epoch: Current epoch index, used only for the progress messages.
        train_loader: Sized iterable of training batches (``len()`` is used
            for the progress-bar total).
        solver: Object exposing ``optimize(batch) -> (losses, output)``.
        hparams: Config object; ``hparams.TRAINING.NUM_EPOCHS`` is read for
            the progress messages.
        compute_metrics: Accepted for API compatibility; currently unused.

    Returns:
        The ``(losses, output)`` pair returned by ``solver.optimize`` for the
        LAST batch of the epoch.

    Raises:
        ValueError: If ``train_loader`` yields no batches (there would be no
            ``(losses, output)`` to return).
    """
    total_epochs = hparams.TRAINING.NUM_EPOCHS
    print('Training Epoch {}/{}'.format(epoch, total_epochs))

    length = len(train_loader)
    iterator = tqdm(train_loader, total=length, leave=False, desc=f'Training Epoch: {epoch}/{total_epochs}')

    losses = output = None
    num_steps = 0
    for batch in iterator:
        losses, output = solver.optimize(batch)
        num_steps += 1

    # Without this guard an empty loader surfaced as an opaque
    # UnboundLocalError on the return statement.
    if num_steps == 0:
        raise ValueError(f'Training loader yielded no batches in epoch {epoch}')
    return losses, output

def _collect_batch_metrics(output, context):
    """Compute the per-sample evaluation metrics for one validation batch.

    Args:
        output: Dict returned by ``solver.evaluate``. It must hold
            ``contact_labels_3d_gt`` / ``contact_labels_3d_pred`` and, when
            ``context`` is true, ``sem_mask_gt`` / ``sem_mask_pred`` /
            ``part_mask_gt`` / ``part_mask_pred``.
        context: Whether to also compute the semantic- and part-mask scores.

    Returns:
        Dict mapping each result key (as used in the evaluator's
        ``eval_dict``) to the tensor produced by the metric helper for this
        batch.
    """
    contact_gt = output['contact_labels_3d_gt']
    contact_pred = output['contact_labels_3d_pred']

    cont_pre, cont_rec, cont_f1 = precision_recall_f1score(contact_gt, contact_pred)
    # NOTE(review): det_error_metric is called prediction-first (reverse of
    # precision_recall_f1score); presumably its signature is (pred, gt) — verify.
    fp_geo_err, fn_geo_err = det_error_metric(contact_pred, contact_gt)

    batch_metrics = {
        'cont_precision': cont_pre,
        'cont_recall': cont_rec,
        'cont_f1': cont_f1,
        'fp_geo_err': fp_geo_err,
        'fn_geo_err': fn_geo_err,
    }
    if context:
        batch_metrics['sem_iou'] = metric(output['sem_mask_gt'], output['sem_mask_pred'])
        batch_metrics['part_iou'] = metric(output['part_mask_gt'], output['part_mask_pred'])
    return batch_metrics


@torch.no_grad()
def evaluator(val_loader, solver, hparams, epoch=0, dataset_name='Unknown', normalize=True, return_dict=False):
    """Evaluate ``solver`` on ``val_loader`` and average per-sample metrics.

    Gradients are disabled for the whole call by ``@torch.no_grad()``.

    Args:
        val_loader: Loader exposing ``.dataset`` (sized) and yielding dict
            batches with an ``'img'`` entry whose first dimension is the
            batch size.
        solver: Object exposing ``evaluate(batch) -> (losses, output,
            time_taken)``; ``losses`` must contain ``'cont_loss'``.
        hparams: Config; reads ``TRAINING.NUM_EPOCHS``, ``TRAINING.CONTEXT``
            and (when CONTEXT is set) ``VALIDATION.SUMMARY_STEPS``.
        epoch: Epoch index, used only in the progress-bar label.
        dataset_name: Name shown in the progress-bar label and error message.
        normalize: Forwarded to ``gen_render`` for visualizations.
        return_dict: Selects the return form (see Returns).

    Returns:
        If ``return_dict`` is true, ``(eval_dict, mean_time)`` where
        ``eval_dict`` holds the mean of each metric over the evaluated
        samples (``cont_precision``, ``cont_recall``, ``cont_f1``,
        ``fp_geo_err``, ``fn_geo_err``, plus ``sem_iou``, ``part_iou`` and the
        rendered ``images`` list when CONTEXT is set, plus ``cont_loss``), and
        ``mean_time`` is the summed ``time_taken`` divided by the number of
        evaluated samples. Otherwise only ``eval_dict['cont_f1']``.

    Raises:
        ValueError: If any batch has a 0 in ``output['has_contact_3d']``, or
            if the loader yields no samples.
    """
    total_epochs = hparams.TRAINING.NUM_EPOCHS
    context = hparams.TRAINING.CONTEXT

    dataset_size = len(val_loader.dataset)
    print(f'Dataset size: {dataset_size}')

    # One per-sample buffer per reported metric, keyed by its eval_dict name
    # (insertion order matches the order results are reported in).
    metric_names = ['cont_precision', 'cont_recall', 'cont_f1', 'fp_geo_err', 'fn_geo_err']
    if context:
        metric_names += ['sem_iou', 'part_iou']
    per_sample = {name: np.zeros(dataset_size) for name in metric_names}
    per_sample_loss = np.zeros(dataset_size)

    total_time = 0
    rend_images = []
    # Running write offset into the buffers. Advancing by the real batch size
    # (rather than step * val_loader.batch_size) stays correct when the loader
    # uses a batch_sampler (batch_size is None) or yields uneven batches.
    num_seen = 0

    iterator = tqdm(enumerate(val_loader), total=len(val_loader), leave=False,
                    desc=f'Evaluating {dataset_name.capitalize()} Epoch: {epoch}/{total_epochs}')
    for step, batch in iterator:
        curr_batch_size = batch['img'].shape[0]
        start, end = num_seen, num_seen + curr_batch_size
        losses, output, time_taken = solver.evaluate(batch)

        per_sample_loss[start:end] = losses['cont_loss'].cpu().numpy()

        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if torch.any(output['has_contact_3d'] == 0):
            raise ValueError('has_contact_3d tensor has 0 values')

        for name, values in _collect_batch_metrics(output, context).items():
            per_sample[name][start:end] = values.cpu().numpy()

        total_time += time_taken
        num_seen = end

        # Render a visualization every SUMMARY_STEPS batches (CONTEXT only).
        if context and step % hparams.VALIDATION.SUMMARY_STEPS == 0:
            rend_images.append(gen_render(output, normalize))

    if num_seen == 0:
        raise ValueError(f'Validation loader for {dataset_name} yielded no samples')

    # Average over the samples actually evaluated: dividing by dataset_size
    # would count unfilled (zero) buffer slots when the loader yields fewer
    # samples than the dataset holds, e.g. with drop_last=True.
    eval_dict = {name: values[:num_seen].mean() for name, values in per_sample.items()}
    if context:
        eval_dict['images'] = rend_images
    eval_dict['cont_loss'] = per_sample_loss[:num_seen].mean()

    mean_time = total_time / num_seen

    if return_dict:
        return eval_dict, mean_time
    return eval_dict['cont_f1']