deepvats / dvats_xai /visualization.py
misantamaria's picture
trying to fix cuda error
6d51833
raw
history blame
2.41 kB
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/visualization.ipynb.
# %% auto 0
__all__ = ['plot_TS', 'plot_validation_ts_ae', 'plot_mask']
# %% ../nbs/visualization.ipynb 3
from fastcore.all import *
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import torch
# %% ../nbs/visualization.ipynb 6
@delegates(pd.DataFrame.plot)
def plot_TS(df:pd.core.frame.DataFrame, **kwargs):
    """
    Plot a DataFrame with `subplots=True` and show the result.

    Input:
        df: DataFrame to plot
        kwargs: extra keyword arguments forwarded unchanged to `df.plot`.
            `subplots` is always set to True here, so passing it again
            raises a TypeError (duplicate keyword argument).
    Output:
        Nothing is returned; the figure is displayed through `plt.show()`.
    """
    # subplots=True asks pandas for separate subplots instead of one shared axes
    df.plot(subplots=True, **kwargs)
    plt.show()
# %% ../nbs/visualization.ipynb 8
def plot_validation_ts_ae(prediction:np.array, original:np.array, title_str = "Validation plot", fig_size = (15,15), anchor = (-0.01, 0.89), window_num = 0, return_fig=True, title_pos = 0.9):
    """
    Plot, for one window, the original data against its prediction, using
    one stacked subplot per index of the last axis of `original`.

    Both arrays are indexed as `[window_num, :, i]`, so they must be 3D with
    the window on the first axis and the plotted series on the last axis.

    Input:
        prediction: 3D array with the predicted values
        original: 3D array with the original values; `original.shape[2]`
            sets the number of subplots
        title_str: figure title
        fig_size: (width, height) of the figure
        anchor: currently unused; kept so existing callers keep working
        window_num: index (first axis) of the window to plot
        return_fig: if True, returns the figure
        title_pos: vertical position (`y`) of the figure title
    Output:
        The matplotlib figure if return_fig is True, otherwise None.
    """
    # Create the figure
    fig = plt.figure(figsize=(fig_size[0], fig_size[1]))
    # Create one subplot row per series; the axes are read back via fig.axes
    fig.subplots(nrows=original.shape[2], ncols=1)
    # Plot both the original and the prediction of each series on its own axes
    for i, ax in enumerate(fig.axes):
        ax.plot(original[window_num, :, i], label='Original Data')
        ax.plot(prediction[window_num, :, i], label='Prediction')
    # Every axes holds the same two labelled lines, so a single figure-level
    # legend is built from the handles of the last one
    lines, labels = fig.axes[-1].get_legend_handles_labels()
    fig.legend(lines, labels, loc='upper left', ncol=2)
    # Write the plot title (and position it closer to the top of the graph)
    fig.suptitle(title_str, y=title_pos)
    # Tight results:
    fig.tight_layout()
    # Returns
    if return_fig:
        return fig
    return None
# %% ../nbs/visualization.ipynb 12
def plot_mask(mask, i=0, fig_size=(10,10), title_str="Mask", return_fig=False):
    """
    Plot the mask passed as argument. The mask is a 3D boolean tensor. The first
    dimension is the window number (or item index), the second is the variable, and the third is the time step.
    Input:
        mask: 3D boolean tensor
        i: index of the window to plot
        fig_size: size of the figure
        title_str: title of the plot
        return_fig: if True, returns the figure
    Output:
        if return_fig is True, returns the figure, otherwise, it shows the
        plot and returns None
    """
    window = mask[i]
    # matplotlib converts its input to a numpy array, which fails for tensors
    # living on a CUDA device, so bring torch tensors back to the CPU first
    if isinstance(window, torch.Tensor):
        window = window.detach().cpu()
    fig = plt.figure(figsize=fig_size)
    plt.pcolormesh(window, cmap='cool')
    # The mean shown in the title belongs to the plotted window `i`
    # (it was previously always computed from window 0)
    plt.title(f'{title_str} {i}, mean: {window.float().mean().item():.3f}')
    if return_fig:
        return fig
    plt.show()
    return None