"""Visualization of predicted and ground truth for a single batch."""

from typing import Any, Dict

import numpy as np
import torch

from siclib.geometry.perspective_fields import get_latitude_field
from siclib.models.utils.metrics import latitude_error, up_error
from siclib.utils.conversions import rad2deg
from siclib.utils.tensor import batch_to_device
from siclib.visualization.viz2d import (
    plot_confidences,
    plot_heatmaps,
    plot_image_grid,
    plot_latitudes,
    plot_vector_fields,
)


def make_up_figure(
    pred: Dict[str, torch.Tensor], data: Dict[str, torch.Tensor], n_pairs: int = 2
) -> Dict[str, Any]:
    """Get predicted and ground truth up fields and errors.

    Args:
        pred (Dict[str, torch.Tensor]): Predicted up field.
        data (Dict[str, torch.Tensor]): Ground truth up field.
        n_pairs (int, optional): Number of pairs to visualize. Defaults to 2.

    Returns:
        Dict[str, Any]: Dictionary with figure.
    """
    pred = batch_to_device(pred, "cpu", detach=True)
    data = batch_to_device(data, "cpu", detach=True)

    n_pairs = min(n_pairs, len(data["image"]))

    if "up_field" not in pred.keys():
        return {}

    errors = up_error(pred["up_field"], data["up_field"])

    up_fields = []
    for i in range(n_pairs):
        row = [data["up_field"][i], pred["up_field"][i], errors[i]]
        titles = ["Up GT", "Up Pred", "Up Error"]

        if "up_confidence" in pred.keys():
            row += [pred["up_confidence"][i]]
            titles += ["Up Confidence"]

        row = [r.float().numpy() if isinstance(r, torch.Tensor) else r for r in row]
        up_fields.append(row)

    # create figure
    N, M = len(up_fields), len(up_fields[0]) + 1
    imgs = [[data["image"][i].permute(1, 2, 0).cpu().clip(0, 1)] * M for i in range(n_pairs)]
    fig, ax = plot_image_grid(imgs, titles=[["Image"] + titles] * N, return_fig=True, set_lim=True)
    ax = np.array(ax)

    for i in range(n_pairs):
        plot_vector_fields(up_fields[i][:2], axes=ax[i, [1, 2]])
        plot_heatmaps([up_fields[i][2]], cmap="turbo", colorbar=True, axes=ax[i, [3]])

        if "up_confidence" in pred.keys():
            plot_confidences([up_fields[i][3]], axes=ax[i, [4]])

    return {"up": fig}


def make_latitude_figure(
    pred: Dict[str, torch.Tensor], data: Dict[str, torch.Tensor], n_pairs: int = 2
) -> Dict[str, Any]:
    """Get predicted and ground truth latitude fields and errors.

    Args:
        pred (Dict[str, torch.Tensor]): Predicted latitude field.
        data (Dict[str, torch.Tensor]): Ground truth latitude field.
        n_pairs (int, optional): Number of pairs to visualize. Defaults to 2.

    Returns:
        Dict[str, Any]: Dictionary with figure.
    """
    pred = batch_to_device(pred, "cpu", detach=True)
    data = batch_to_device(data, "cpu", detach=True)

    n_pairs = min(n_pairs, len(data["image"]))
    latitude_fields = []

    if "latitude_field" not in pred.keys():
        return {}

    errors = latitude_error(pred["latitude_field"], data["latitude_field"])
    for i in range(n_pairs):
        row = [
            rad2deg(data["latitude_field"][i][0]),
            rad2deg(pred["latitude_field"][i][0]),
            errors[i],
        ]
        titles = ["Latitude GT", "Latitude Pred", "Latitude Error"]

        if "latitude_confidence" in pred.keys():
            row += [pred["latitude_confidence"][i]]
            titles += ["Latitude Confidence"]

        row = [r.float().numpy() if isinstance(r, torch.Tensor) else r for r in row]
        latitude_fields.append(row)

    # create figure
    N, M = len(latitude_fields), len(latitude_fields[0]) + 1
    imgs = [[data["image"][i].permute(1, 2, 0).cpu().clip(0, 1)] * M for i in range(n_pairs)]
    fig, ax = plot_image_grid(imgs, titles=[["Image"] + titles] * N, return_fig=True, set_lim=True)
    ax = np.array(ax)

    for i in range(n_pairs):
        plot_latitudes(latitude_fields[i][:2], is_radians=False, axes=ax[i, [1, 2]])
        plot_heatmaps([latitude_fields[i][2]], cmap="turbo", colorbar=True, axes=ax[i, [3]])

        if "latitude_confidence" in pred.keys():
            plot_confidences([latitude_fields[i][3]], axes=ax[i, [4]])

    return {"latitude": fig}


def make_camera_figure(
    pred: Dict[str, torch.Tensor], data: Dict[str, torch.Tensor], n_pairs: int = 2
) -> Dict[str, Any]:
    """Get predicted and ground truth camera parameters.

    Args:
        pred (Dict[str, torch.Tensor]): Predicted camera parameters.
        data (Dict[str, torch.Tensor]): Ground truth camera parameters.
        n_pairs (int, optional): Number of pairs to visualize. Defaults to 2.

    Returns:
        Dict[str, Any]: Dictionary with figure.
    """
    pred = batch_to_device(pred, "cpu", detach=True)
    data = batch_to_device(data, "cpu", detach=True)

    n_pairs = min(n_pairs, len(data["image"]))

    if "camera" not in pred.keys():
        return {}

    latitudes = []
    for i in range(n_pairs):
        titles = ["Cameras GT"]
        row = [get_latitude_field(data["camera"][i], data["gravity"][i])]

        if "camera" in pred.keys() and "gravity" in pred.keys():
            row += [get_latitude_field(pred["camera"][i], pred["gravity"][i])]
            titles += ["Cameras Pred"]

        row = [rad2deg(r).squeeze(-1).float().numpy()[0] for r in row]
        latitudes.append(row)

    # create figure
    N, M = len(latitudes), len(latitudes[0]) + 1
    imgs = [[data["image"][i].permute(1, 2, 0).cpu().clip(0, 1)] * M for i in range(n_pairs)]
    fig, ax = plot_image_grid(imgs, titles=[["Image"] + titles] * N, return_fig=True, set_lim=True)
    ax = np.array(ax)

    for i in range(n_pairs):
        plot_latitudes(latitudes[i], is_radians=False, axes=ax[i, 1:])

    return {"camera": fig}


def make_perspective_figures(
    pred: Dict[str, torch.Tensor], data: Dict[str, torch.Tensor], n_pairs: int = 2
) -> Dict[str, Any]:
    """Get predicted and ground truth perspective fields.

    Args:
        pred (Dict[str, torch.Tensor]): Predicted perspective fields.
        data (Dict[str, torch.Tensor]): Ground truth perspective fields.
        n_pairs (int, optional): Number of pairs to visualize. Defaults to 2.

    Returns:
        Dict[str, Any]: Dictionary with figure.
    """
    n_pairs = min(n_pairs, len(data["image"]))
    figures = make_up_figure(pred, data, n_pairs)
    figures |= make_latitude_figure(pred, data, n_pairs)
    figures |= make_camera_figure(pred, data, n_pairs)

    for fig in figures.values():
        fig.tight_layout()

    return figures
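

# Illustrative usage sketch (not part of the original module): build a dummy batch with
# random unit up vectors and render the figures. Shapes are assumptions inferred from the
# functions above -- "image" as (B, 3, H, W), "up_field" as (B, 2, H, W) unit vectors,
# "up_confidence" as (B, H, W); adjust to match the real data pipeline. Since pred lacks
# "latitude_field" and "camera", only the up-field figure is produced here.
if __name__ == "__main__":
    B, H, W = 2, 64, 64

    up = torch.randn(B, 2, H, W)
    up = up / up.norm(dim=1, keepdim=True)  # normalize to per-pixel unit vectors

    data = {"image": torch.rand(B, 3, H, W), "up_field": up}
    pred = {"up_field": up.clone(), "up_confidence": torch.rand(B, H, W)}

    figures = make_perspective_figures(pred, data, n_pairs=2)
    for name, fig in figures.items():
        fig.savefig(f"{name}.png")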