File size: 4,555 Bytes
205a7af |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 |
import pprint
import numpy as np
from . import viz2d
from .tools import RadioHideTool, ToggleTool, __plot_dict__
# flake8: noqa
# mypy: ignore-errors
class FormatPrinter(pprint.PrettyPrinter):
    """Pretty-printer that renders chosen types through %-style format strings.

    ``formats`` maps an exact type to a format string, for example
    ``{np.float32: "%.4f"}``. Objects whose exact type is not a key of the
    mapping are formatted by the stock ``pprint.PrettyPrinter``.
    """

    def __init__(self, formats):
        """Store the type -> format-string mapping used by ``format``."""
        super().__init__()
        self.formats = formats

    def format(self, obj, ctx, maxlvl, lvl):
        """Return the ``(text, readable, recursive)`` triple for ``obj``."""
        obj_type = type(obj)
        if obj_type not in self.formats:
            # No custom format registered for this exact type.
            return super().format(obj, ctx, maxlvl, lvl)
        rendered = self.formats[obj_type] % obj
        # Flags mirror the original: readable=1, recursive=0.
        return rendered, 1, 0
class TwoViewFrame:
    """Interactive matplotlib frame with one image column per prediction set.

    The frame shows ``data["image"][0]`` once for every entry of ``preds``,
    registers a "switch plot" radio tool and a "toggle summary" tool on the
    figure's tool manager, and draws the plot named by ``conf.default``.
    """

    default_conf = {
        "default": "image",
        "summary_visible": False,
    }

    # Registry of available plots (name -> plot class), imported from ``.tools``.
    plot_dict = __plot_dict__

    # NOTE(review): class-level mutable list shared by every instance; it is
    # not read or written anywhere in this class.
    childs = []

    # Maps the ``event`` index passed to ``__init__`` to a plot name.
    event_to_image = [None, "image", "horizon_line", "lat_pred", "lat_gt"]

    def __init__(self, conf, data, preds, title=None, event=1, summaries=None):
        """Build the figure, register the tools and draw the default plot.

        Args:
            conf: configuration object read and written through the
                attributes ``default`` and ``summary_visible``.
            data: mapping holding at least ``"image"``; its keys count as
                available when selecting the plots that can be offered.
            preds: mapping from a display name to a prediction mapping.
            title: optional window title for the figure.
            event: index into ``event_to_image``; stored as ``self.plot``.
            summaries: optional mapping from display name to a summary object
                that is pretty-printed onto the matching image.
        """
        self.conf = conf
        self.data = data
        self.preds = preds
        self.names = list(preds.keys())
        self.plot = self.event_to_image[event]
        self.summaries = summaries

        self.fig, self.axes, self.summary_arts = self.init_frame()
        if title is not None:
            self.fig.canvas.manager.set_window_title(title)

        # Keys present in *every* prediction; stays None when preds is empty.
        keys = None
        for pred in preds.values():
            keys = set(pred.keys()) if keys is None else keys.intersection(pred.keys())
        # Fall back to an empty set so an empty ``preds`` does not raise here.
        keys = (keys or set()).union(data.keys())

        # Offer only the plots whose required keys are all available.
        self.options = [
            name
            for name, plot in self.plot_dict.items()
            if set(plot.required_keys).issubset(keys)
        ]
        self.handle = None

        self.radios = self.fig.canvas.manager.toolmanager.add_tool(
            "switch plot",
            RadioHideTool,
            options=self.options,
            callback_fn=self.draw,
            active=conf.default,
            keymap="R",
        )

        self.toggle_summary = self.fig.canvas.manager.toolmanager.add_tool(
            "toggle summary",
            ToggleTool,
            toggled=self.conf.summary_visible,
            callback_fn=self.set_summary_visible,
            keymap="t",
        )

        if self.fig.canvas.manager.toolbar is not None:
            self.fig.canvas.manager.toolbar.add_tool("switch plot", "navigation")
        self.draw(conf.default)

    def init_frame(self):
        """Create the figure, label each column and add the summary texts.

        Returns:
            A ``(fig, axes, toggle_artists)`` tuple, where ``toggle_artists``
            holds the values returned by ``viz2d.add_text`` for the summaries
            (an empty list when ``self.summaries`` is None).
        """
        # A single row with the same image repeated once per prediction set.
        imgs = [[self.data["image"][0].permute(1, 2, 0) for _ in self.names]]

        fig, axes = viz2d.plot_image_grid(imgs, return_fig=True, titles=None, figs=5)
        for col, name in enumerate(self.names):
            viz2d.add_text(col, name, axes=axes[0])

        fig.canvas.mpl_connect("pick_event", self.click_artist)

        toggle_artists = []
        if self.summaries is not None:
            font_size = 7
            formatter = FormatPrinter({np.float32: "%.4f", np.float64: "%.4f"})
            for col, name in enumerate(self.names):
                toggle_artists.append(
                    viz2d.add_text(
                        col,
                        formatter.pformat(self.summaries[name]),
                        axes=axes[0],
                        pos=(0.01, 0.01),
                        va="bottom",
                        backgroundcolor=(0, 0, 0, 0.5),
                        visible=self.conf.summary_visible,
                        fs=font_size,
                    )
                )
        return fig, axes, toggle_artists

    def draw(self, value):
        """Clear the frame and draw the plot registered under ``value``.

        Also records ``value`` in ``self.conf.default``. Returns the newly
        created plot handle.
        """
        self.clear()
        self.conf.default = value
        self.handle = self.plot_dict[value](self.fig, self.axes, self.data, self.preds)
        return self.handle

    def clear(self):
        """Remove the current plot's artists and request a redraw."""
        if self.handle is not None:
            try:
                self.handle.clear()
            except AttributeError:
                # Best effort, kept from the original; presumably for plot
                # handles that do not define ``clear``.
                pass
        self.handle = None

        for row in self.axes:
            for ax in row:
                # Iterate over snapshots: ``Artist.remove`` mutates the
                # container being walked when it is a plain list (older
                # matplotlib), which would skip every other artist.
                for line in list(ax.lines):
                    line.remove()
                for collection in list(ax.collections):
                    collection.remove()
        self.fig.artists.clear()
        self.fig.canvas.draw_idle()

    def click_artist(self, event):
        """Handle a pick event by toggling the picked artist's arrow style.

        The style switches between ``"-"`` and ``"<|-|>"``; an artist that
        becomes selected gets zorder 1. The event is forwarded to the active
        plot handle when that handle defines ``click_artist``.
        """
        # NOTE(review): assumes every pickable artist exposes
        # ``get_arrowstyle`` / ``set_arrowstyle`` - confirm against the plots.
        art = event.artist
        select = art.get_arrowstyle().arrow == "-"
        art.set_arrowstyle("<|-|>" if select else "-")
        if select:
            art.set_zorder(1)
        if hasattr(self.handle, "click_artist"):
            self.handle.click_artist(event)
        self.fig.canvas.draw_idle()

    def set_summary_visible(self, visible):
        """Show or hide all summary texts and remember the state in ``conf``."""
        self.conf.summary_visible = visible
        for summary_art in self.summary_arts:
            summary_art.set_visible(visible)
        self.fig.canvas.draw_idle()
|