|
import functools |
|
import traceback |
|
from copy import deepcopy |
|
|
|
import matplotlib.pyplot as plt |
|
import numpy as np |
|
from matplotlib.widgets import Button |
|
from omegaconf import OmegaConf |
|
|
|
from ..datasets.base_dataset import collate |
|
from ..models.cache_loader import CacheLoader |
|
from .tools import RadioHideTool |
|
|
|
|
|
|
|
|
|
|
|
class GlobalFrame:
    """Interactive scatter plot comparing per-sample metrics of several models.

    Each model in ``results`` is drawn as one scatter series of metric
    ``conf.x`` against metric ``conf.y``. Hovering a point links the same
    sample across all models; picking a point draws red connectors and, if a
    ``child_frame`` class was given, opens a child frame for that sample.
    """

    default_conf = {
        "x": "???",
        "y": "???",
        "diff": False,
        "child": {},
        "remove_outliers": False,
    }

    child_frame = None

    def __init__(self, conf, results, loader, predictions, title=None, child_frame=None):
        """Build the frame and register the x/y metric selector tools.

        Args:
            conf: User config merged over ``default_conf``. ``x``/``y`` fall
                back to the first/second metric (sorted by name) when falsy.
            results: Mapping model name -> {metric name -> per-sample values}.
            loader: Object exposing ``dataset``, indexed by sample index.
            predictions: Mapping model name -> cached prediction file (path-like).
            title: Optional window title.
            child_frame: Optional frame class spawned when a point is picked.
        """
        self.child_frame = child_frame
        # Per-instance state: class-level mutable containers would be shared
        # (and keep growing) across every GlobalFrame instance.
        self.childs = []
        self.lines = []
        self.scatters = {}

        # Work on a copy so one instance's child config does not leak into the
        # class-level ``default_conf`` seen by all other instances.
        default_conf = dict(self.default_conf)
        if self.child_frame is not None:
            default_conf["child"] = self.child_frame.default_conf

        self.conf = OmegaConf.merge(default_conf, conf)
        self.results = results
        self.loader = loader
        self.predictions = predictions
        # Union of metric names over all models, sorted for stable menus.
        self.metrics = sorted({metric for res in results.values() for metric in res})

        self.conf.x = conf["x"] or self.metrics[0]
        self.conf.y = conf["y"] or self.metrics[1]

        assert self.conf.x in self.metrics
        assert self.conf.y in self.metrics

        self.names = list(results)
        self.fig, self.axes = self.init_frame()
        if title is not None:
            self.fig.canvas.manager.set_window_title(title)

        self.xradios = self.fig.canvas.manager.toolmanager.add_tool(
            "x",
            RadioHideTool,
            options=self.metrics,
            callback_fn=self.update_x,
            active=self.conf.x,
            keymap="x",
        )

        self.yradios = self.fig.canvas.manager.toolmanager.add_tool(
            "y",
            RadioHideTool,
            options=self.metrics,
            callback_fn=self.update_y,
            active=self.conf.y,
            keymap="y",
        )
        if self.fig.canvas.manager.toolbar is not None:
            self.fig.canvas.manager.toolbar.add_tool("x", "navigation")
            self.fig.canvas.manager.toolbar.add_tool("y", "navigation")

    def init_frame(self):
        """Create the figure, the "diff_only" toggle button and event hooks.

        Returns:
            The ``(figure, axes)`` pair from ``plt.subplots()``.
        """
        fig, ax = plt.subplots()
        ax.set_title("click on points")
        diffb_ax = fig.add_axes([0.01, 0.02, 0.12, 0.06])
        self.diffb = Button(diffb_ax, label="diff_only")
        self.diffb.on_clicked(self.diff_clicked)
        fig.canvas.mpl_connect("pick_event", self.on_scatter_pick)
        fig.canvas.mpl_connect("motion_notify_event", self.hover)
        return fig, ax

    def draw(self):
        """Redraw all model series for the current ``conf.x`` / ``conf.y``.

        String/bytes metrics are treated as categorical (digit-only strings
        are converted to ints). When ``conf.diff`` is set, numeric metrics are
        shown relative to the first model's values. Numeric axes get a solid
        mean line and a hidden dashed median line per model.
        """
        self.scatters = {}
        self.axes.clear()
        # clear() already detached any connector lines; forget them.
        self.lines = []
        self.axes.set_xlabel(self.conf.x)
        self.axes.set_ylabel(self.conf.y)

        ref_results = self.results[self.names[0]]
        x_cat = isinstance(ref_results[self.conf.x][0], (bytes, str))
        y_cat = isinstance(ref_results[self.conf.y][0], (bytes, str))

        refx = 0.0
        refy = 0.0
        if self.conf.diff:
            if not x_cat:
                refx = np.array(ref_results[self.conf.x])
            if not y_cat:
                refy = np.array(ref_results[self.conf.y])

        # Only computed for categorical-x / numeric-y plots (step curve).
        xunique = None
        sort_ax = None
        x = None
        for name in self.names:
            x = np.array(self.results[name][self.conf.x])
            y = np.array(self.results[name][self.conf.y])

            if x_cat and np.char.isdigit(x.astype(str)).all():
                x = x.astype(int)
            if y_cat and np.char.isdigit(y.astype(str)).all():
                y = y.astype(int)

            x = x if x_cat else x - refx
            y = y if y_cat else y - refy

            (s,) = self.axes.plot(x, y, "o", markersize=3, label=name, picker=True, pickradius=5)
            self.scatters[name] = s

            if x_cat and not y_cat:
                # Mean of y per category, in order of first appearance.
                xunique, ind, xinv, xbin = np.unique(
                    x, return_inverse=True, return_counts=True, return_index=True
                )
                ybin = np.bincount(xinv, weights=y)
                sort_ax = np.argsort(ind)
                self.axes.step(
                    xunique[sort_ax],
                    (ybin / xbin)[sort_ax],
                    where="mid",
                    color=s.get_color(),
                )

            if not x_cat:
                xavg = np.nan_to_num(x).mean()
                self.axes.axvline(xavg, c=s.get_color(), zorder=1, alpha=1.0)
                # x is already relative to refx in diff mode: don't subtract twice.
                xmed = np.median(x)
                self.axes.axvline(
                    xmed,
                    c=s.get_color(),
                    zorder=0,
                    alpha=0.5,
                    linestyle="dashed",
                    visible=False,
                )

            if not y_cat:
                yavg = np.nan_to_num(y).mean()
                self.axes.axhline(yavg, c=s.get_color(), zorder=1, alpha=0.5)
                # y is already relative to refy in diff mode: don't subtract twice.
                ymed = np.median(y)
                self.axes.axhline(
                    ymed,
                    c=s.get_color(),
                    zorder=0,
                    alpha=0.5,
                    linestyle="dashed",
                    visible=False,
                )

        # xunique is unbound when y is categorical too; guard before using it.
        if x_cat and xunique is not None and x.dtype == object and xunique.shape[0] > 5:
            self.axes.set_xticklabels(xunique[sort_ax], rotation=90)
        self.axes.legend()

    def on_scatter_pick(self, handle):
        """Handle a pick event by spawning a child frame for the picked sample."""
        try:
            art = handle.artist
            try:
                event = handle.mouseevent.button.value
            except AttributeError:
                # No usable mouse button on this pick (e.g. button is None).
                return
            self.spawn_child(art.get_label(), handle.ind[0], event=event)
        except Exception:
            traceback.print_exc()
            # Abort the tool on any callback error. ``exit`` is only meant for
            # the interactive interpreter; SystemExit(0) is the equivalent.
            raise SystemExit(0)

    def _clear_lines(self):
        """Remove the connector lines currently drawn on the axes."""
        for line in self.lines:
            line.remove()
        self.lines = []

    def _connect_sample(self, ind, x0, y0, *args, **kwargs):
        """Draw a line from ``(x0, y0)`` to sample ``ind`` of every model.

        Extra positional/keyword arguments are forwarded to ``Axes.plot``.
        """
        self._clear_lines()
        for oname in self.names:
            scatter = self.scatters[oname]
            xn = scatter.get_xdata()[ind]
            yn = scatter.get_ydata()[ind]
            (ln,) = self.axes.plot([x0, xn], [y0, yn], *args, **kwargs)
            self.lines.append(ln)
        self.fig.canvas.draw_idle()

    def spawn_child(self, model_name, ind, event=None):
        """Highlight sample ``ind`` and open a child frame for it if configured.

        Args:
            model_name: Series the connector lines start from.
            ind: Sample index into the dataset / results arrays.
            event: Mouse button value forwarded to the child frame.
        """
        source = self.scatters[model_name]
        self._connect_sample(ind, source.get_xdata()[ind], source.get_ydata()[ind], "r")

        if self.child_frame is None:
            return

        data = collate([self.loader.dataset[ind]])
        preds = {
            name: CacheLoader({"path": str(pfile), "add_data_path": False})(data)
            for name, pfile in self.predictions.items()
        }
        # Per-model metric values of this one sample (the "names" entry is skipped).
        summaries_i = {
            name: {k: v[ind] for k, v in res.items() if k != "names"}
            for name, res in self.results.items()
        }
        frame = self.child_frame(
            self.conf.child,
            deepcopy(data),
            preds,
            title=str(data["name"][0]),
            event=event,
            summaries=summaries_i,
        )

        frame.fig.canvas.mpl_connect(
            "key_press_event",
            functools.partial(self.on_childframe_key_event, frame=frame, ind=ind, event=event),
        )
        self.childs.append(frame)
        self.childs[-1].fig.show()

    def hover(self, event):
        """Link the hovered sample to the same sample of every other model."""
        if event.inaxes != self.axes:
            return

        for scatter in self.scatters.values():
            hit, info = scatter.contains(event)
            if not hit:
                continue
            ind = info["ind"][0]
            xdata, ydata = scatter.get_data()
            self._connect_sample(ind, xdata[ind], ydata[ind], "black", zorder=0, alpha=0.5)
            break

    def diff_clicked(self, args):
        """Toggle plotting relative to the first model and redraw."""
        self.conf.diff = not self.conf.diff
        self.draw()
        self.fig.canvas.draw_idle()

    def update_x(self, x):
        """Switch the x-axis metric and redraw."""
        self.conf.x = x
        self.draw()
        self.fig.canvas.draw_idle()

    def update_y(self, y):
        """Switch the y-axis metric and redraw."""
        self.conf.y = y
        self.draw()
        self.fig.canvas.draw_idle()

    def on_childframe_key_event(self, key_event, frame, ind, event):
        """Keyboard navigation inside a child frame.

        ``delete`` closes the child. ``left``/``right`` replace it with the
        previous/next sample; with ``shift`` the current child stays open.
        Indices wrap around the dataset length.
        """
        key = key_event.key
        if key == "delete":
            plt.close(frame.fig)
            self.childs.remove(frame)
        elif key in ("left", "right", "shift+left", "shift+right"):
            if key.startswith("shift+"):
                key = key.replace("shift+", "")
            else:
                plt.close(frame.fig)
                self.childs.remove(frame)
            # Use the stripped key so shift+right also moves forward.
            new_ind = ind + 1 if key == "right" else ind - 1
            self.spawn_child(
                self.names[0],
                new_ind % len(self.loader.dataset),
                event=event,
            )
|
|