File size: 4,186 Bytes
499e141
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
# -*- coding: utf-8 -*-
# @Author  : xuelun

import os
import argparse

import numpy as np

from os.path import join
from datetime import datetime

angular_thresholds = ['5.0°']
dist_thresholds = ['0.1m']
intt = lambda x: list(map(int, x))
floatt = lambda x: list(map(float, x))
strr = lambda x: list(map(lambda x:f'{x:.18f}', x))

datasets = [
    'GL3D',
    'BlendedMVS',
    'ETH3DI',
    'ETH3DO',
    'KITTI',
    'RobotcarWeather',
    'RobotcarSeason',
    'RobotcarNight',
    'Multi-FoV',
    'SceneNetRGBD',
    'ICL-NUIM',
    'GTA-SfM',
]


def error_auc(errs0, errs1, thres, metric):
    if isinstance(errs0, list): errs0 = np.array(errs0)
    if isinstance(errs1, list): errs1 = np.array(errs1)
    if any(np.isnan(errs0)): errs0[np.isnan(errs0)] = 180
    if any(np.isnan(errs1)): errs1[np.isnan(errs1)] = 180
    if any(np.isinf(errs0)): errs0[np.isinf(errs0)] = 180
    if any(np.isinf(errs1)): errs1[np.isinf(errs1)] = 180
    errors = np.max(np.stack([errs0, errs1]), axis=0)
    errors = [0] + sorted(list(errors))
    recall = list(np.linspace(0, 1, len(errors)))

    aucs = []
    for thr in thres:
        thr = float(thr[:-1])
        last_index = np.searchsorted(errors, thr)
        y = recall[:last_index] + [recall[last_index-1]]
        x = errors[:last_index] + [thr]
        aucs.append(np.trapz(y, x) / thr)

    return {f'{metric}@ {t}': auc for t, auc in zip(thres, aucs)}


if __name__ == '__main__':

    parser = argparse.ArgumentParser()
    parser.add_argument('--dir', type=str, default='.')
    parser.add_argument('--wid', type=str, required=True)
    parser.add_argument('--version', type=str, default=None)
    parser.add_argument('--verbose', action='store_true')
    parser.add_argument('--log', action='store_true')
    parser.add_argument('--sceids', type=str, choices=datasets, nargs='+',
                        default=None, help=f'Test Datasets: {datasets}', )
    opt = parser.parse_args()

    dir = opt.dir
    wid = opt.wid
    version = opt.version

    _data = \
        {
            x.rpartition('.txt')[0].split()[2]:x for x in
            [
                d for d in os.listdir(dir) if not os.path.isdir(os.path.join(dir, d))
            ] if wid == x.rpartition('.txt')[0].split()[1] and version is not None and version == x.rpartition('.txt')[0].split()[-1]
        }
    _data = {k:_data[k] for k in datasets if k in _data.keys()}

    sceids = opt.sceids
    sceids = sceids if sceids is not None else _data.keys()
    results = {}
    for sceid in sceids:
        results[sceid] = {}
        if not opt.verbose: print('{:^13} {}'.format(sceid, wid))

        # read txt
        with open(join(dir, _data[sceid]), 'r') as f:
            data = f.readlines()
        head = data[0].split()
        content = [x.split() for x in data[1:]]
        details = {k: [] for k in head[3:]}

        stacks = []
        for x in content:
            ids = x[0]
            if ids in stacks: continue

            for k, v in zip(head[3:], x[3:]): details[k].append(v)
            stacks.append(ids)

        mAP = error_auc(floatt(details['R_errs']), floatt(details['t_errs']), angular_thresholds, 'auc')
        for k, v in mAP.items(): results[sceid][k] = v

    # print head
    output = ''

    num = 56+25*len(sceids)
    output += '='*num
    output += "\n"

    output += '{:<25}'.format(datetime.now().strftime("%Y-%m-%d, %H:%M:%S"))
    output += '{:<15} '.format('Model')
    output += '{:<14} '.format('Metric')
    for sceid in sceids: output += '{:<25} '.format(sceid)
    output += "\n"

    output += '-'*num
    output += "\n"

    for k in list(results.values())[0].keys():
        output += '{:<25}'.format(datetime.now().strftime("%Y-%m-%d, %H:%M:%S")) if opt.log else '{:<25}'.format(' ')
        output += '{:<15} '.format(wid)
        output += '{:<14} '.format(k)

        for sceid in sceids:
            output += '{:<25} '.format(results[sceid][k])
        output += "\n"

    output += '='*num
    output += "\n"
    output += "\n"

    if opt.verbose:
        print(output)

    if opt.log:
        path = 'ANALYSIS RESULTS.txt'
        with open(path, 'a') as file:
            file.write(output)