deepvats / dvats /xai.py
misantamaria's picture
bring dvats & requirements & entrypoint
7399708
raw
history blame
36.1 kB
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/xai.ipynb.
# %% auto 0
__all__ = ['get_embeddings', 'get_dataset', 'umap_parameters', 'get_prjs', 'plot_projections', 'plot_projections_clusters',
'calculate_cluster_stats', 'anomaly_score', 'detector', 'plot_anomaly_scores_distribution',
'plot_clusters_with_anomalies', 'update_plot', 'plot_clusters_with_anomalies_interactive_plot',
'get_df_selected', 'shift_datetime', 'get_dateformat', 'get_anomalies', 'get_anomaly_styles',
'InteractiveAnomalyPlot', 'plot_save', 'plot_initial_config', 'merge_overlapping_windows',
'InteractiveTSPlot', 'add_selected_features', 'add_windows', 'setup_style', 'toggle_trace',
'set_features_buttons', 'move_left', 'move_right', 'move_down', 'move_up', 'delta_x_bigger',
'delta_y_bigger', 'delta_x_lower', 'delta_y_lower', 'add_movement_buttons', 'setup_boxes', 'initial_plot',
'show']
# %% ../nbs/xai.ipynb 1
#Weight & Biases
import wandb
#Yaml
from yaml import load, FullLoader
#Embeddings
from .all import *
from tsai.data.preparation import prepare_forecasting_data
from tsai.data.validation import get_forecasting_splits
from fastcore.all import *
#Dimensionality reduction
from tsai.imports import *
#Clustering
import hdbscan
import time
from .dr import get_PCA_prjs, get_UMAP_prjs, get_TSNE_prjs
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
import ipywidgets as widgets
from IPython.display import display
from functools import partial
from IPython.display import display, clear_output, HTML as IPHTML
from ipywidgets import Button, Output, VBox, HBox, HTML, Layout, FloatSlider
import plotly.graph_objs as go
import plotly.offline as py
import plotly.io as pio
#! pip install kaleido
import kaleido
# %% ../nbs/xai.ipynb 4
def get_embeddings(config_lrp, run_lrp, api, print_flag = False):
    """Fetch the embeddings artifact named by `config_lrp.emb_artifact`.

    The artifact is resolved with `run_lrp.use_artifact` when
    `config_lrp.use_wandb` is truthy, and with `api.artifact` otherwise.

    Returns a tuple `(embeddings, artifact, producer_config)` where
    `embeddings` is `artifact.to_obj()` and `producer_config` is the
    `config` of the run returned by `artifact.logged_by()`.
    """
    if config_lrp.use_wandb:
        fetch_artifact = run_lrp.use_artifact
    else:
        fetch_artifact = api.artifact
    artifact = fetch_artifact(config_lrp.emb_artifact, type='embeddings')
    if print_flag:
        print(artifact.name)
    producer_config = artifact.logged_by().config
    embeddings = artifact.to_obj()
    return embeddings, artifact, producer_config
# %% ../nbs/xai.ipynb 5
def get_dataset(
    config_lrp,
    config_emb,
    config_dr,
    run_lrp,
    api,
    print_flag = False
):
    """Load the encoder artifact and the dataset to be projected.

    Artifacts are resolved with `run_lrp.use_artifact` when
    `config_lrp.use_wandb` is truthy and with `api.artifact` otherwise
    (the latter allows working offline).

    Parameters
    ----------
    config_lrp : object exposing `use_wandb`.
    config_emb : mapping with the key `'enc_artifact'`.
    config_dr : mapping with the key `'dr_artifact'` (may be `None`).
    run_lrp, api : artifact providers (see above).
    print_flag : when truthy, print/display debugging information.

    Returns
    -------
    (df, dr_artifact, enc_artifact, enc_learner)
        `df` is `dr_artifact.to_df()`; `dr_artifact` is the artifact named
        by `config_dr['dr_artifact']` when that entry is not `None`, and
        the encoder's training artifact otherwise.
    """
    # Botch to use artifacts offline
    artifacts_gettr = run_lrp.use_artifact if config_lrp.use_wandb else api.artifact
    enc_artifact = artifacts_gettr(config_emb['enc_artifact'], type='learner')
    if print_flag: print(enc_artifact.name)
    # TODO: the first `to_obj()` call has been observed to fail, so it is
    # retried once. Only `Exception` is caught so that KeyboardInterrupt /
    # SystemExit are not swallowed; a second failure propagates.
    try:
        enc_learner = enc_artifact.to_obj()
    except Exception:
        enc_learner = enc_artifact.to_obj()
    # Recover the datasets used by the run that logged the encoder
    enc_logger = enc_artifact.logged_by()
    enc_artifact_train = artifacts_gettr(enc_logger.config['train_artifact'], type='dataset')
    if enc_logger.config['valid_artifact'] is not None:
        enc_artifact_valid = artifacts_gettr(enc_logger.config['valid_artifact'], type='dataset')
        if print_flag: print("enc_artifact_valid:", enc_artifact_valid.name)
    if print_flag: print("enc_artifact_train: ", enc_artifact_train.name)
    if config_dr['dr_artifact'] is not None:
        if print_flag: print("Is not none")
        # Fetch the artifact that was actually checked for (the original
        # looked up 'enc_artifact' here, i.e. the encoder, not the dataset).
        dr_artifact = artifacts_gettr(config_dr['dr_artifact'])
    else:
        dr_artifact = enc_artifact_train
    if print_flag: print("DR artifact train: ", dr_artifact.name)
    if print_flag: print("--> DR artifact name", dr_artifact.name)
    df = dr_artifact.to_df()
    if print_flag: print("--> DR After to df", df.shape)
    if print_flag: display(df.head())
    return df, dr_artifact, enc_artifact, enc_learner
# %% ../nbs/xai.ipynb 6
def umap_parameters(config_dr, config):
    """Build the keyword arguments used for the UMAP projection.

    `config_dr` must expose `n_neighbors`, `min_dist`, `metric` and
    `cpu_flag`. A smaller parameter set is returned when `cpu_flag` is
    truthy; otherwise the set with extra keys (`a`, `b`, `target_metric`,
    `target_n_neighbors`, `n_epochs`, `init`, `hash_input`) is returned.
    `config` is accepted but not used by this function.
    """
    n_neighbors = config_dr.n_neighbors
    min_dist = config_dr.min_dist
    metric = config_dr.metric

    if config_dr.cpu_flag:
        return {
            'n_neighbors': n_neighbors,
            'min_dist': min_dist,
            'random_state': np.uint64(822569775),
            'metric': metric,
            'verbose': 4,
        }

    return {
        'n_neighbors': n_neighbors,
        'min_dist': min_dist,
        'random_state': np.uint64(1234),
        'metric': metric,
        'a': 1.5769434601962196,
        'b': 0.8950608779914887,
        'target_metric': 'euclidean',
        'target_n_neighbors': n_neighbors,
        'verbose': 4,
        'n_epochs': 200 * 3 * 2,
        'init': 'random',
        'hash_input': True,
    }
# %% ../nbs/xai.ipynb 7
def get_prjs(embs_no_nan, config_dr, config, print_flag = False):
    """Project embeddings with PCA and then with UMAP.

    The parameters produced by `umap_parameters(config_dr, config)` are
    forwarded to both `get_PCA_prjs` and `get_UMAP_prjs`. The PCA step is
    always called with `cpu = False`; the UMAP step uses
    `config_dr.cpu_flag`.

    Returns the result of `get_UMAP_prjs`.
    """
    umap_params = umap_parameters(config_dr, config)
    prjs_pca = get_PCA_prjs(
        X = embs_no_nan,
        cpu = False,
        print_flag = print_flag,
        **umap_params
    )
    if print_flag:
        print(prjs_pca.shape)
    prjs_umap = get_UMAP_prjs(
        input_data = prjs_pca,
        cpu = config_dr.cpu_flag,
        print_flag = print_flag,
        **umap_params
    )
    # The original evaluated `prjs_umap.shape` without printing it.
    if print_flag:
        print(prjs_umap.shape)
    return prjs_umap
# %% ../nbs/xai.ipynb 9
def plot_projections(prjs, umap_params, fig_size = (25,25)):
    """Draw the 2D projections as hollow markers joined by a line.

    `prjs` is turned into a two-column frame (`x1`, `x2`); the title shows
    `umap_params['n_neighbors']` and `umap_params['min_dist']`.
    Returns the matplotlib axes that were drawn on.
    """
    points = pd.DataFrame(prjs, columns = ['x1', 'x2'])
    xs, ys = points['x1'], points['x2']
    figure = plt.figure(figsize=(fig_size[0], fig_size[1]))
    axes = figure.add_subplot(111)
    axes.scatter(xs, ys, marker='o', facecolors='none', edgecolors='b', alpha=0.1)
    axes.plot(xs, ys, alpha=0.5, picker=1)
    title = 'DR params - n_neighbors:{:d} min_dist:{:f}'.format(
        umap_params['n_neighbors'], umap_params['min_dist'])
    plt.title(title)
    return axes
# %% ../nbs/xai.ipynb 10
def plot_projections_clusters(prjs, clusters_labels, umap_params, fig_size = (25,25)):
    """Scatter the 2D projections, one scatter call per cluster label.

    The distinct labels are printed, then each cluster is drawn with the
    legend label `Cluster <label>`. The title shows
    `umap_params['n_neighbors']` and `umap_params['min_dist']`.
    Returns the matplotlib axes that were drawn on.
    """
    points = pd.DataFrame(prjs, columns = ['x1', 'x2'])
    points['cluster'] = clusters_labels
    figure = plt.figure(figsize=(fig_size[0], fig_size[1]))
    axes = figure.add_subplot(111)
    distinct_labels = points['cluster'].unique()
    print(distinct_labels)
    for cluster_id in distinct_labels:
        members = points[points['cluster'] == cluster_id]
        axes.scatter(members['x1'], members['x2'], label=f'Cluster {cluster_id}')
    title = 'DR params - n_neighbors:{:d} min_dist:{:f}'.format(
        umap_params['n_neighbors'], umap_params['min_dist'])
    plt.title(title)
    return axes
# %% ../nbs/xai.ipynb 11
def calculate_cluster_stats(data, labels, per_cluster=False):
    """Compute the (mean, standard deviation) pair associated with each label.

    Parameters
    ----------
    data : array-like; statistics are taken along axis 0.
    labels : array-like of cluster labels.
    per_cluster : bool, default False
        When False (the historical behaviour) every label is mapped to the
        statistics of the *whole* `data`. When True, each label is mapped
        to the statistics of only the rows whose label matches; this
        assumes `labels` has one entry per row of `data`.

    Returns
    -------
    dict mapping each distinct label (as returned by `np.unique`) to a
    `(mean, std)` tuple. With `per_cluster=False` all labels share the
    same two arrays, since the global statistics do not depend on the label.
    """
    unique_labels = np.unique(labels)
    if unique_labels.size == 0:
        return {}

    if per_cluster:
        rows = np.asarray(data)
        label_array = np.asarray(labels)
        stats = {}
        for label in unique_labels:
            members = rows[label_array == label]
            stats[label] = (np.mean(members, axis = 0), np.std(members, axis = 0))
        return stats

    # Global statistics are loop-invariant: compute them once.
    mean = np.mean(data, axis = 0)
    std = np.std(data, axis = 0)
    return {label: (mean, std) for label in unique_labels}
# %% ../nbs/xai.ipynb 12
def anomaly_score(point, cluster_stats, label):
    """Return the norm of `point` standardised by the stats stored for `label`.

    `cluster_stats[label]` must be a `(mean, std)` pair; the score is
    `np.linalg.norm((point - mean) / std)`.
    """
    center, spread = cluster_stats[label]
    standardized = (point - center) / spread
    return np.linalg.norm(standardized)
# %% ../nbs/xai.ipynb 13
def detector(data, labels):
    """Return an array with one anomaly score per (point, label) pair.

    Statistics come from `calculate_cluster_stats(data, labels)`; each
    point is then scored with `anomaly_score` against the entry for its
    own label.
    """
    stats_by_label = calculate_cluster_stats(data, labels)
    return np.array([
        anomaly_score(sample, stats_by_label, sample_label)
        for sample, sample_label in zip(data, labels)
    ])
# %% ../nbs/xai.ipynb 15
def plot_anomaly_scores_distribution(anomaly_scores):
    """Show a 30-bin histogram (with KDE overlay) of the anomaly scores."""
    plt.figure(figsize=(10, 6))
    sns.histplot(anomaly_scores, kde=True, bins=30)
    # Axis labels and title are runtime strings and are kept verbatim.
    plt.xlabel("Anomaly Score")
    plt.ylabel("Frecuencia")
    plt.title("Distribución de Anomaly Scores")
    plt.show()
# %% ../nbs/xai.ipynb 16
def plot_clusters_with_anomalies(prjs, clusters_labels, anomaly_scores, threshold, fig_size=(25, 25)):
    """Scatter the clusters and overlay, in red, the points flagged as anomalous.

    A point is flagged when its entry in `anomaly_scores` is strictly
    greater than `threshold`. The figure is shown with `plt.show()`;
    nothing is returned.
    """
    points = pd.DataFrame(prjs, columns=['x1', 'x2'])
    points['cluster'] = clusters_labels
    points['anomaly'] = anomaly_scores > threshold

    figure = plt.figure(figsize=(fig_size[0], fig_size[1]))
    axes = figure.add_subplot(111)

    # One scatter call per cluster so that each one gets its own colour
    for cluster_id in points['cluster'].unique():
        members = points[points['cluster'] == cluster_id]
        axes.scatter(members['x1'], members['x2'], label=f'Cluster {cluster_id}', alpha=0.7)

    # Flagged points are drawn last so they sit on top of the clusters
    flagged = points[points['anomaly']]
    axes.scatter(flagged['x1'], flagged['x2'], color='red', label='Anomalies', edgecolor='k', s=50)

    plt.title('Clusters and anomalies')
    plt.legend()
    plt.show()
def update_plot(threshold, prjs_umap, clusters_labels, anomaly_scores, fig_size):
    """Redraw the cluster/anomaly plot for the given `threshold`."""
    plot_clusters_with_anomalies(
        prjs=prjs_umap,
        clusters_labels=clusters_labels,
        anomaly_scores=anomaly_scores,
        threshold=threshold,
        fig_size=fig_size,
    )
def plot_clusters_with_anomalies_interactive_plot(threshold, prjs_umap, clusters_labels, anomaly_scores, fig_size):
    """Display the cluster/anomaly plot together with a threshold slider.

    The slider starts at `threshold` and ranges over [0.001, 3] in steps
    of 0.001; every change re-runs `update_plot`. All other arguments are
    passed to `update_plot` unchanged as fixed values.

    `fig_size` is now honoured (it was previously ignored in favour of a
    hard-coded `(25, 25)`).
    """
    threshold_slider = widgets.FloatSlider(
        value=threshold, min=0.001, max=3, step=0.001, description='Threshold'
    )
    interactive_plot = widgets.interactive(
        update_plot,
        threshold = threshold_slider,
        prjs_umap = widgets.fixed(prjs_umap),
        clusters_labels = widgets.fixed(clusters_labels),
        anomaly_scores = widgets.fixed(anomaly_scores),
        fig_size = widgets.fixed(fig_size),
    )
    display(interactive_plot)
# %% ../nbs/xai.ipynb 18
import plotly.express as px
from datetime import timedelta
# %% ../nbs/xai.ipynb 19
def get_df_selected(df, selected_indices, w, stride = 1):
    """Map selected window ids back to row ranges of `df`.

    Each id `i` in `selected_indices` produces the range
    `(i * stride, i * stride + w)`. The rows `start .. end` (both
    positions included, i.e. `df.iloc[start:end + 1]`) of every range are
    concatenated, keeping the original index, so overlapping windows
    contribute repeated rows.

    Parameters
    ----------
    df : pandas.DataFrame to slice positionally.
    selected_indices : sized iterable of integer window ids.
    w : window length added to the start position.
    stride : multiplier converting a window id into a start position.

    Returns
    -------
    (window_ranges, n_windows, df_selected)
        For an empty selection `df_selected` is an empty frame with the
        same columns as `df` (previously `pd.concat` raised `ValueError`).
    """
    n_windows = len(selected_indices)
    window_ranges = [
        (window_id * stride, window_id * stride + w)
        for window_id in selected_indices
    ]
    if not window_ranges:
        # pd.concat refuses an empty list of objects
        return window_ranges, n_windows, df.iloc[0:0]
    segments = [df.iloc[start:end + 1] for start, end in window_ranges]
    df_selected = pd.concat(segments, ignore_index=False)
    return window_ranges, n_windows, df_selected
# %% ../nbs/xai.ipynb 20
def shift_datetime(dt, seconds, sign, dateformat="%Y-%m-%d %H:%M:%S.%f", print_flag = False):
    """Shift a date by a number of seconds and return it as a string.

    Parameters
    ----------
    dt : str or datetime
        Date to shift. Strings are parsed with `dateformat` first and then
        with a few fallback layouts (with microseconds, with seconds, with
        minutes only, date only).
    seconds : number of seconds to move.
    sign : '+' moves to the future; any other value moves to the past.
    dateformat : first `strptime` format tried for `dt`.
    print_flag : when truthy, print debugging information.

    Returns
    -------
    str
        The shifted date formatted as "%Y-%m-%d %H:%M:%S.%f".

    Raises
    ------
    ValueError
        If `dt` is a string that matches none of the candidate formats.
        (Previously this situation ended in an unbound-variable error.)
    """
    # Local import: only `timedelta` is imported explicitly at file level.
    from datetime import datetime, timedelta

    output_format = "%Y-%m-%d %H:%M:%S.%f"
    if print_flag:
        print("dt ", dt, "seconds", seconds, "sign", sign, "dateformat", dateformat)

    if isinstance(dt, datetime):
        new_dt = dt
    else:
        candidate_formats = [
            dateformat,
            "%Y-%m-%d %H:%M:%S.%f",
            "%Y-%m-%d %H:%M:%S",
            "%Y-%m-%d %H:%M",
            "%Y-%m-%d",
        ]
        new_dt = None
        # dict.fromkeys removes duplicates while keeping the trial order
        for candidate in dict.fromkeys(candidate_formats):
            try:
                new_dt = datetime.strptime(dt, candidate)
                break
            except ValueError as e:
                if print_flag: print("Error: ", e)
        if new_dt is None:
            raise ValueError(f"Could not parse date {dt!r} with any known format")
    if print_flag: print("ndt", new_dt)

    # Kept from the original implementation: when hour, minute and second
    # are all zero the microseconds are dropped as well.
    if new_dt.hour == 0 and new_dt.minute == 0 and new_dt.second == 0:
        new_dt = new_dt.replace(microsecond=0)

    offset = timedelta(seconds = seconds)
    new_dt = new_dt + offset if sign == '+' else new_dt - offset

    if new_dt.hour == 0 and new_dt.minute == 0 and new_dt.second == 0:
        new_dt = new_dt.replace(microsecond=0)

    new_dt_str = new_dt.strftime(output_format)
    if print_flag: print("new dt ", new_dt)
    return new_dt_str
# %% ../nbs/xai.ipynb 21
def get_dateformat(text_date):
    """Guess the `strptime` format of `text_date` from its layout.

    The text is split on whitespace:
    - one token -> "%Y-%m-%d";
    - two tokens whose second has three ':'-separated fields ->
      "%Y-%m-%d %H:%M:%S.%f" when the last field splits into exactly two
      parts on '.', otherwise "%Y-%m-%d %H:%M:%S";
    - two tokens with another number of ':' fields -> "unknown format 1";
    - any other token count -> "unknown format 2".
    """
    tokens = text_date.split()
    if len(tokens) == 1:
        return "%Y-%m-%d"
    if len(tokens) != 2:
        return "unknown format 2"

    clock_fields = tokens[1].split(':')
    if len(clock_fields) != 3:
        return "unknown format 1"

    has_fraction = len(clock_fields[2].split('.')) == 2
    if has_fraction:
        return "%Y-%m-%d %H:%M:%S.%f"
    return "%Y-%m-%d %H:%M:%S"
# %% ../nbs/xai.ipynb 23
def get_anomalies(df, threshold, flag):
    """Write the `anomaly` column of `df` in place.

    Each entry is `(score > threshold) and flag`, evaluated for every
    value of `df['anomaly_score']`. Nothing is returned.
    """
    marks = []
    for score in df['anomaly_score']:
        marks.append((score > threshold) and flag)
    df['anomaly'] = marks
def get_anomaly_styles(df, threshold, anomaly_scores, flag = False, print_flag = False):
    """Return the marker symbols and marker-outline colours for every row of `df`.

    `df['anomaly']` is first filled from `df['anomaly_score']` through
    `get_anomalies`. When `flag` is truthy the column is then recomputed
    from the `anomaly_scores` argument, and flagged rows get the symbol
    'x' with a black outline; every other row (and every row when `flag`
    is falsy) gets 'circle' with a transparent outline.

    Returns `(symbols, line_colors)`, two lists with one entry per row.
    `df` is modified in place.
    """
    if print_flag:
        print("Threshold: ", threshold)
        print("Flag", flag)
        print("df ~", df.shape)
    get_anomalies(df, threshold, flag)
    if print_flag: print(df)
    anomalies = df[df['anomaly']]

    transparent = 'rgba(0,0,0,0)'
    if not flag:
        n_rows = len(df['x1'])
        symbols = ['circle'] * n_rows
        line_colors = [transparent] * n_rows
    else:
        df['anomaly'] = [(score > threshold) and flag for score in anomaly_scores]
        symbols, line_colors = [], []
        for is_anomaly in df['anomaly']:
            # `flag` is truthy in this branch, so only `is_anomaly` decides
            if is_anomaly:
                symbols.append('x')
                line_colors.append('black')
            else:
                symbols.append('circle')
                line_colors.append(transparent)

    if print_flag: print(anomalies)
    return symbols, line_colors
### Example of use
#prjs_df = pd.DataFrame(prjs_umap, columns = ['x1', 'x2'])
#prjs_df['anomaly_score'] = anomaly_scores
#s, l = get_anomaly_styles(prjs_df, 1, anomaly_scores, True)
# %% ../nbs/xai.ipynb 24
class InteractiveAnomalyPlot():
    """Interactive plotly view of 2D projections with lasso selection and
    optional anomaly highlighting.

    Attributes set by `__init__`:
    - `selected_indices`: the confirmed selection (updated by the
      "Update selected_indices" button).
    - `selected_indices_tmp`: the most recent lasso selection.
    - `threshold` / `threshold_`: initial and current anomaly threshold.
    - `anomaly_flag`: whether anomalies are highlighted.
    - `w`, `name`, `path`: window size, the label `w=<w>` and the image
      path `<path>/w=<w>.png`.
    - `interaction_enabled`: toggled by the pause/resume buttons.
    """
    def __init__(
        self, selected_indices = None,
        threshold = 0.15,
        anomaly_flag = False,
        path = "../imgs", w = 0
    ):
        import os
        # A fresh list per instance: a shared `[]` default would leak
        # selections from one plot object into the next.
        if selected_indices is None:
            selected_indices = []
        self.selected_indices = selected_indices
        self.selected_indices_tmp = selected_indices
        self.threshold = threshold
        self.threshold_ = threshold
        self.anomaly_flag = anomaly_flag
        self.w = w
        self.name = f"w={self.w}"
        # Join with a separator (the original produced "../imgsw=0.png")
        self.path = os.path.join(path, f"{self.name}.png")
        self.interaction_enabled = True

    def plot_projections_clusters_interactive(
        self, prjs, cluster_labels, umap_params, anomaly_scores=None, fig_size=(7,7), print_flag = False
    ):
        """Build and display the interactive figure and its controls.

        Parameters
        ----------
        prjs : 2D projections, turned into columns `x1`/`x2`.
        cluster_labels : one label per projection (drives marker colour).
        umap_params : mapping with `n_neighbors` and `min_dist` (title only).
        anomaly_scores : one score per projection; `None` means `[]`.
        fig_size : accepted for API compatibility; the layout uses a fixed
            700x500 size.
        print_flag : when truthy, print debugging output in the widgets.

        Side effects: saves a PNG through `plot_save(fig, self.w)` and
        displays the widget box. Nothing is returned.
        """
        if anomaly_scores is None:
            anomaly_scores = []
        self.selected_indices_tmp = self.selected_indices
        py.init_notebook_mode()

        prjs_df, cluster_colors = plot_initial_config(prjs, cluster_labels, anomaly_scores)
        legend_items = [widgets.HTML(f'<b>Cluster {cluster}:</b> <span style="color:{color};">■</span>')
                        for cluster, color in cluster_colors.items()]
        legend = widgets.VBox(legend_items)

        marker_colors = prjs_df['cluster'].map(cluster_colors)

        symbols, line_colors = get_anomaly_styles(prjs_df, self.threshold_, anomaly_scores, self.anomaly_flag, print_flag)

        fig = go.FigureWidget(
            [
                go.Scatter(
                    x=prjs_df['x1'], y=prjs_df['x2'],
                    mode="markers",
                    marker= {
                        'color': marker_colors,
                        'line': { 'color': line_colors, 'width': 1 },
                        'symbol': symbols
                    },
                    text = prjs_df.index
                )
            ]
        )

        # Grey line joining consecutive projections (trace index 1)
        line_trace = go.Scatter(
            x=prjs_df['x1'],
            y=prjs_df['x2'],
            mode="lines",
            line=dict(color='rgba(128, 128, 128, 0.5)', width=1)
        )
        fig.add_trace(line_trace)

        sca = fig.data[0]

        fig.update_layout(
            dragmode='lasso',
            width=700,
            height=500,
            title={
                'text': '<span style="font-weight:bold">DR params - n_neighbors:{:d} min_dist:{:f}</span>'.format(
                    umap_params['n_neighbors'], umap_params['min_dist']),
                'y':0.98,
                'x':0.5,
                'xanchor': 'center',
                'yanchor': 'top'
            },
            plot_bgcolor='white',
            paper_bgcolor='#f0f0f0',
            xaxis=dict(gridcolor='lightgray', zerolinecolor='black', title = 'x'),
            yaxis=dict(gridcolor='lightgray', zerolinecolor='black', title = 'y'),
            margin=dict(l=10, r=20, t=30, b=10)
        )

        # Output widgets that capture the debug prints of each callback
        output_tmp = Output()
        output_button = Output()
        output_anomaly = Output()
        output_threshold = Output()
        output_width = Output()

        def select_action(trace, points, selector):
            # Remember the lasso selection until the user confirms it
            self.selected_indices_tmp = points.point_inds
            with output_tmp:
                output_tmp.clear_output(wait=True)
                if print_flag: print("Selected indices tmp:", self.selected_indices_tmp)

        def button_action(b):
            # Confirm the temporary selection
            self.selected_indices = self.selected_indices_tmp
            with output_button:
                output_button.clear_output(wait = True)
                if print_flag: print("Selected indices:", self.selected_indices)

        def update_anomalies():
            if print_flag: print("About to update anomalies")
            symbols, line_colors = get_anomaly_styles(prjs_df, self.threshold_, anomaly_scores, self.anomaly_flag, print_flag)
            if print_flag: print("Anomaly styles got")
            with fig.batch_update():
                fig.data[0].marker.symbol = symbols
                fig.data[0].marker.line.color = line_colors
            if print_flag: print("Anomalies updated")
            if print_flag: print("Threshold: ", self.threshold_)
            if print_flag: print("Scores: ", anomaly_scores)

        def anomaly_action(b):
            with output_anomaly:
                output_anomaly.clear_output(wait=True)
                # Fixed: the original tested the misspelt name `print_fllag`,
                # which raised NameError on every click.
                if print_flag: print("Negate anomaly flag")
                self.anomaly_flag = not self.anomaly_flag
                if print_flag: print("Show anomalies:", self.anomaly_flag)
                update_anomalies()

        sca.on_selection(select_action)
        layout = widgets.Layout(width='auto', height='40px')
        button = Button(
            description="Update selected_indices",
            style = {'button_color': 'lightblue'},
            display = 'flex',
            flex_row = 'column',
            align_items = 'stretch',
            layout = layout
        )
        anomaly_button = Button(
            description = "Show anomalies",
            style = {'button_color': 'lightgray'},
            display = 'flex',
            flex_row = 'column',
            align_items = 'stretch',
            layout = layout
        )

        button.on_click(button_action)
        anomaly_button.on_click(anomaly_action)

        ##### Reactivity buttons
        pause_button = Button(
            description = "Pause interactiveness",
            style = {'button_color': 'pink'},
            display = 'flex',
            flex_row = 'column',
            align_items = 'stretch',
            layout = layout
        )
        resume_button = Button(
            description = "Resume interactiveness",
            style = {'button_color': 'lightgreen'},
            display = 'flex',
            flex_row = 'column',
            align_items = 'stretch',
            layout = layout
        )

        threshold_slider = FloatSlider(
            value=self.threshold_,
            min=0.0,
            max=float(np.ceil(self.threshold+5)),
            step=0.0001,
            description='Anomaly threshold:',
            continuous_update=False
        )

        def pause_interaction(b):
            self.interaction_enabled = False
            fig.update_layout(dragmode='pan')

        def resume_interaction(b):
            self.interaction_enabled = True
            fig.update_layout(dragmode='lasso')

        def update_threshold(change):
            with output_threshold:
                output_threshold.clear_output(wait = True)
                if print_flag: print("Update threshold")
                self.threshold_ = change.new
                if print_flag: print("Update anomalies threshold = ", self.threshold_)
                update_anomalies()

        #### Width
        width_slider = FloatSlider(
            value = 0.5,
            min = 0.0,
            max = 1.0,
            step = 0.0001,
            description = 'Line width:',
            continuous_update = False
        )

        def update_width(change):
            with output_width:
                try:
                    output_width.clear_output(wait = True)
                    if print_flag:
                        print("Change line width")
                        print("Trace to update:", fig.data[1])
                    with fig.batch_update():
                        fig.data[1].line.width = change.new  # update the width of the connecting line
                    if print_flag: print("ChangeD line width")
                except Exception as e:
                    print("Error updating line width:", e)

        pause_button.on_click(pause_interaction)
        resume_button.on_click(resume_interaction)

        threshold_slider.observe(update_threshold, 'value')

        ####
        width_slider.observe(update_width, names = 'value')

        #####
        space = HTML("&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;")

        vbox = VBox((output_tmp, output_button, output_anomaly, output_threshold, fig))
        hbox = HBox((space, button, space, pause_button, space, resume_button, anomaly_button))

        # Centre both boxes horizontally inside the outer VBox
        box_layout = widgets.Layout(display='flex',
                                    flex_flow='column',
                                    align_items='center',
                                    width='100%')

        # The threshold slider is only shown when anomalies start enabled
        if self.anomaly_flag:
            box = VBox((hbox,threshold_slider,width_slider, output_width, vbox), layout = box_layout)
        else:
            box = VBox((hbox, width_slider, output_width, vbox), layout = box_layout)
        box.add_class("layout")
        plot_save(fig, self.w)

        display(box)
# %% ../nbs/xai.ipynb 25
def plot_save(fig, w):
    """Render `fig` as PNG and write it to `../imgs/w=<w>.png`."""
    png_bytes = pio.to_image(fig, format='png')
    target = f"../imgs/w={w}.png"
    with open(target, 'wb') as out_file:
        out_file.write(png_bytes)
# %% ../nbs/xai.ipynb 26
def plot_initial_config(prjs, cluster_labels, anomaly_scores):
    """Build the projections frame and a colour for every cluster.

    Returns
    -------
    (prjs_df, cluster_colors)
        `prjs_df` has the columns `x1`, `x2`, `cluster` and
        `anomaly_score`. `cluster_colors` maps each distinct label, in
        order of first appearance, to a colour of plotly's `Set1`
        qualitative palette.

    The palette is reused cyclically when there are more clusters than
    colours; the original sliced the palette, which made the column
    assignment fail as soon as the number of clusters exceeded the
    palette length.
    """
    prjs_df = pd.DataFrame(prjs, columns = ['x1', 'x2'])
    prjs_df['cluster'] = cluster_labels
    prjs_df['anomaly_score'] = anomaly_scores

    palette = px.colors.qualitative.Set1
    distinct_clusters = pd.DataFrame({'cluster': cluster_labels})['cluster'].drop_duplicates()
    cluster_colors = {
        cluster: palette[position % len(palette)]
        for position, cluster in enumerate(distinct_clusters)
    }
    return prjs_df, cluster_colors
# %% ../nbs/xai.ipynb 27
def merge_overlapping_windows(windows):
    """Merge `(start, end)` windows that overlap or touch.

    Windows are sorted by start; a window whose start is less than or
    equal to the end of the previously kept window is fused with it.
    Fused windows are returned as tuples, untouched windows as given.
    An empty input yields an empty list.
    """
    if not windows:
        return []

    ordered = sorted(windows, key=lambda bounds: bounds[0])
    result = ordered[:1]
    for current in ordered[1:]:
        previous = result[-1]
        if current[0] > previous[1]:
            # Disjoint from the last kept window: start a new one
            result.append(current)
            continue
        result[-1] = (previous[0], max(current[1], previous[1]))
    return result
# %% ../nbs/xai.ipynb 29
class InteractiveTSPlot:
    """State holder for the interactive time-series plot.

    `__init__` only prepares data and widgets; the plotting methods
    (`add_selected_features`, `add_windows`, `show`, ...) are attached to
    the class further down in this module.
    """
    def __init__(
        self,
        df,
        selected_indices,
        meaningful_features_subset_ids,
        w,
        stride=1,
        print_flag=False,
        num_points=10000,
        dateformat='%Y-%m-%d %H:%M:%S',
        delta_x = 10,
        delta_y = 0.1
    ):
        """Store the configuration and compute the windows to highlight.

        Parameters
        ----------
        df : data frame to plot. NOTE: its index is converted to strings
            in place, so the caller's frame is modified.
        selected_indices : window ids passed to `get_df_selected`.
        meaningful_features_subset_ids : column positions shown initially.
        w, stride : window length and stride passed to `get_df_selected`.
        print_flag : when truthy, print debugging information.
        num_points : stored and printed only; not used to subsample.
        dateformat : date format used by the x axis and the move buttons.
        delta_x, delta_y : step sizes of the movement buttons.
        """
        self.df = df
        self.selected_indices = selected_indices
        self.meaningful_features_subset_ids = meaningful_features_subset_ids
        self.w = w
        self.stride = stride
        self.print_flag = print_flag
        self.num_points = num_points
        self.dateformat = dateformat
        self.buttons = []
        self.delta_x = delta_x
        self.delta_y = delta_y
        self.window_ranges, self.n_windows, self.df_selected = get_df_selected(
            self.df, self.selected_indices, self.w, self.stride
        )
        # Ensure the smallest possible number of windows to plot (like in the R Shiny App)
        self.window_ranges = merge_overlapping_windows(self.window_ranges)
        if self.print_flag:
            print("windows: ", self.n_windows, self.window_ranges)
            print("selected id: ", self.df_selected.index)
            print("points: ", self.num_points)
        self.df.index = self.df.index.astype(str)
        # Built once (the original constructed the widget twice and
        # discarded the first instance).
        self.fig = go.FigureWidget()
        # One random colour per selected window; merging can only reduce
        # the number of windows, so there are always enough colours.
        self.colors = [
            f'rgb({np.random.randint(0, 256)}, {np.random.randint(0, 256)}, {np.random.randint(0, 256)})'
            for _ in range(self.n_windows)
        ]
        ##############################
        # Outputs for debug printing #
        ##############################
        self.output_windows = Output()
        self.output_move = Output()
        self.output_delta_x = Output()
        self.output_delta_y = Output()
# %% ../nbs/xai.ipynb 30
def add_selected_features(self: InteractiveTSPlot):
    """Add one line trace per column of `self.df` to `self.fig`.

    A trace starts visible only when the position of its column is in
    `self.meaningful_features_subset_ids`.
    """
    columns = self.df.columns
    timestamps = self.df.index
    for column_name in columns:
        starts_visible = columns.get_loc(column_name) in self.meaningful_features_subset_ids
        self.fig.add_trace(
            go.Scatter(
                x = timestamps,
                y = self.df[column_name],
                mode='lines',
                name=column_name,
                visible=starts_visible,
                text=timestamps
            )
        )
InteractiveTSPlot.add_selected_features = add_selected_features
# %% ../nbs/xai.ipynb 31
def add_windows(self: InteractiveTSPlot):
    """Draw one shaded rectangle per entry of `self.window_ranges`.

    Each rectangle spans the full plot height (`yref="paper"`, 0..1) and
    goes from `self.df.index[start]` to `self.df.index[end]`. Positions
    are clamped to the last row: a window's `end` can be equal to
    `len(self.df)`, which previously raised `IndexError`.

    The limits of every window are also printed into
    `self.output_windows`.
    """
    index = self.df.index
    last_position = len(index) - 1
    for i, (start, end) in enumerate(self.window_ranges):
        start = min(start, last_position)
        end = min(end, last_position)
        self.fig.add_shape(
            type="rect",
            x0=index[start],
            x1=index[end],
            y0= 0,
            y1= 1,
            yref = "paper",
            fillcolor=self.colors[i],
            opacity=0.25,
            layer="below",
            line=dict(color=self.colors[i], width=1),
            name = f"w_{i}"
        )
        with self.output_windows:
            # NOTE(review): `i` counts merged windows, so after merging the
            # id printed here may not be the id of this exact window.
            print("w[" + str( self.selected_indices[i] )+ "]="+str(index[start])+", "+str(index[end])+")")
InteractiveTSPlot.add_windows = add_windows
# %% ../nbs/xai.ipynb 32
def setup_style(self: InteractiveTSPlot):
    """Apply titles, margins, tick format and background to `self.fig`.

    The x tick format is `'%d-'` followed by `self.dateformat`; the y
    axes are set to a fixed range.
    """
    x_axis_options = dict(tickformat = '%d-' + self.dateformat)
    plot_margins = dict(l=10, r=10, t=30, b=10)
    self.fig.update_layout(
        title='Time Series with time window plot',
        xaxis_title='Datetime',
        yaxis_title='Value',
        legend_title='Variables',
        margin=plot_margins,
        xaxis=x_axis_options,
        paper_bgcolor='#f0f0f0'
    )
    self.fig.update_yaxes(fixedrange=True)
InteractiveTSPlot.setup_style = setup_style
# %% ../nbs/xai.ipynb 34
def toggle_trace(self : InteractiveTSPlot, button : Button):
    """Toggle the visibility of the trace whose button was clicked.

    The button description is the string form of a column label (see
    `set_features_buttons`). The label is looked up directly first; when
    that fails (e.g. non-string column labels) the position is found by
    comparing the string form of every column, so the lookup no longer
    raises `KeyError` for such frames.
    """
    label = button.description
    try:
        position = self.df.columns.get_loc(label)
    except KeyError:
        position = [str(column) for column in self.df.columns].index(label)
    trace = self.fig.data[position]
    trace.visible = not trace.visible
InteractiveTSPlot.toggle_trace = toggle_trace
# %% ../nbs/xai.ipynb 35
def set_features_buttons(self):
    """Create one toggle button per column and store them in `self.buttons`.

    Buttons of columns whose position is in
    `self.meaningful_features_subset_ids` use the 'success' style; every
    button calls `self.toggle_trace` when clicked.
    """
    columns = self.df.columns
    feature_buttons = []
    for column_name in columns:
        highlighted = columns.get_loc(column_name) in self.meaningful_features_subset_ids
        feature_buttons.append(
            Button(
                description=str(column_name),
                button_style='success' if highlighted else ''
            )
        )
    self.buttons = feature_buttons
    for feature_button in self.buttons:
        feature_button.on_click(self.toggle_trace)
InteractiveTSPlot.set_features_buttons = set_features_buttons
# %% ../nbs/xai.ipynb 36
def move_left(self : InteractiveTSPlot, button : Button):
    """Shift the visible x range `self.delta_x` seconds into the past."""
    with self.output_move:
        self.output_move.clear_output(wait=True)
        range_start, range_end = self.fig.layout.xaxis.range
        shifted_start = shift_datetime(
            range_start, self.delta_x, '-',
            dateformat=self.dateformat, print_flag=self.print_flag
        )
        shifted_end = shift_datetime(
            range_end, self.delta_x, '-',
            dateformat=self.dateformat, print_flag=self.print_flag
        )
        with self.fig.batch_update():
            self.fig.layout.xaxis.range = [shifted_start, shifted_end]
def move_right(self : InteractiveTSPlot, button : Button):
    """Shift the visible x range `self.delta_x` seconds into the future."""
    self.output_move.clear_output(wait=True)
    with self.output_move:
        range_start, range_end = self.fig.layout.xaxis.range
        shifted_start = shift_datetime(
            range_start, self.delta_x, '+',
            dateformat=self.dateformat, print_flag=self.print_flag
        )
        shifted_end = shift_datetime(
            range_end, self.delta_x, '+',
            dateformat=self.dateformat, print_flag=self.print_flag
        )
        with self.fig.batch_update():
            self.fig.layout.xaxis.range = [shifted_start, shifted_end]
def move_down(self: InteractiveTSPlot, button : Button):
    """Shift the visible y range down by `self.delta_y`.

    Fixed: the new range was assigned through the misspelt attribute
    `self.ig`, which raised `AttributeError` on every click.
    """
    with self.output_move:
        self.output_move.clear_output(wait=True)
        start_y, end_y = self.fig.layout.yaxis.range
        with self.fig.batch_update():
            self.fig.layout.yaxis.range = [start_y-self.delta_y, end_y-self.delta_y]
def move_up(self: InteractiveTSPlot, button : Button):
    """Shift the visible y range up by `self.delta_y`."""
    with self.output_move:
        self.output_move.clear_output(wait=True)
        lower, upper = self.fig.layout.yaxis.range
        step = self.delta_y
        with self.fig.batch_update():
            self.fig.layout.yaxis.range = [lower+step, upper+step]
InteractiveTSPlot.move_left = move_left
InteractiveTSPlot.move_right = move_right
InteractiveTSPlot.move_down = move_down
InteractiveTSPlot.move_up = move_up
# %% ../nbs/xai.ipynb 37
def delta_x_bigger(self: InteractiveTSPlot, button=None):
    """Multiply the horizontal step `self.delta_x` by 10.

    `button` is the widget passed by `Button.on_click`; it is optional so
    the method can still be called directly. Without this parameter the
    click callback failed with a `TypeError` and the step never changed.
    """
    with self.output_delta_x:
        self.output_delta_x.clear_output(wait = True)
        if self.print_flag: print("Delta before", self.delta_x)
        self.delta_x *= 10
        if self.print_flag: print("delta_x:", self.delta_x)
def delta_y_bigger(self: InteractiveTSPlot, button=None):
    """Multiply the vertical step `self.delta_y` by 10.

    `button` is the widget passed by `Button.on_click`; it is optional so
    the method can still be called directly. Without this parameter the
    click callback failed with a `TypeError` and the step never changed.
    """
    with self.output_delta_y:
        self.output_delta_y.clear_output(wait = True)
        if self.print_flag: print("Delta before", self.delta_y)
        self.delta_y *= 10
        if self.print_flag: print("delta_y:", self.delta_y)
def delta_x_lower(self:InteractiveTSPlot, button=None):
    """Divide the horizontal step `self.delta_x` by 10.

    `button` is the widget passed by `Button.on_click`; it is optional so
    the method can still be called directly. Without this parameter the
    click callback failed with a `TypeError` and the step never changed.
    """
    with self.output_delta_x:
        self.output_delta_x.clear_output(wait = True)
        if self.print_flag: print("Delta before", self.delta_x)
        self.delta_x /= 10
        if self.print_flag: print("delta_x:", self.delta_x)
def delta_y_lower(self:InteractiveTSPlot, button=None):
    """Divide the vertical step `self.delta_y` by 10.

    Fixes relative to the original:
    - the step was multiplied by 10, i.e. "lower" made it bigger;
    - the debug prints ignored `self.print_flag`;
    - the callback did not accept the `button` argument passed by
      `Button.on_click`.
    """
    with self.output_delta_y:
        self.output_delta_y.clear_output(wait = True)
        if self.print_flag: print("Delta before", self.delta_y)
        self.delta_y /= 10
        if self.print_flag: print("delta_y:", self.delta_y)
InteractiveTSPlot.delta_x_bigger = delta_x_bigger
InteractiveTSPlot.delta_y_bigger = delta_y_bigger
InteractiveTSPlot.delta_x_lower = delta_x_lower
InteractiveTSPlot.delta_y_lower = delta_y_lower
# %% ../nbs/xai.ipynb 38
def add_movement_buttons(self: InteractiveTSPlot):
    """Create the arrow and step-size buttons and wire their callbacks.

    The buttons are stored as `self.button_left`, `self.button_right`,
    `self.button_up`, `self.button_down`, `self.button_step_x_up`,
    `self.button_step_x_down`, `self.button_step_y_up` and
    `self.button_step_y_down`.
    """
    # Arrow buttons: pan the visible range
    self.button_left = Button(description="←")
    self.button_left.on_click(self.move_left)
    self.button_right = Button(description="→")
    self.button_right.on_click(self.move_right)
    self.button_up = Button(description="↑")
    self.button_up.on_click(self.move_up)
    self.button_down = Button(description="↓")
    self.button_down.on_click(self.move_down)

    # Step buttons: change how far each arrow click moves.
    # TODO: fix changing the step size used when moving. The output is
    # not shown and the value is not modified.
    self.button_step_x_up = Button(description="dx ↑")
    self.button_step_x_up.on_click(self.delta_x_bigger)
    self.button_step_x_down = Button(description="dx ↓")
    self.button_step_x_down.on_click(self.delta_x_lower)
    self.button_step_y_up = Button(description="dy↑")
    self.button_step_y_up.on_click(self.delta_y_bigger)
    self.button_step_y_down = Button(description="dy↓")
    self.button_step_y_down.on_click(self.delta_y_lower)
InteractiveTSPlot.add_movement_buttons = add_movement_buttons
# %% ../nbs/xai.ipynb 40
def setup_boxes(self: InteractiveTSPlot):
    """Assemble every widget into `self.box`.

    The debug outputs (`output_move`, `output_delta_x`, `output_delta_y`)
    are only included when `self.print_flag` is truthy.
    """
    self.steps_x = VBox([self.button_step_x_up, self.button_step_x_down])
    self.steps_y = VBox([self.button_step_y_up, self.button_step_y_down])
    navigation_row = HBox([
        self.button_left, self.button_right,
        self.button_up, self.button_down,
        self.steps_x, self.steps_y,
    ])
    feature_row = HBox(
        self.buttons,
        layout=widgets.Layout(display='flex', flex_flow='row wrap', align_items='flex-start')
    )
    column_layout = widgets.Layout(
        display='flex',
        flex_flow='column',
        align_items='center',
        width='100%'
    )
    children = [feature_row, navigation_row]
    if self.print_flag:
        children += [self.output_move, self.output_delta_x, self.output_delta_y]
    children += [self.fig, self.output_windows]
    self.box = VBox(children, layout=column_layout)
InteractiveTSPlot.setup_boxes = setup_boxes
# %% ../nbs/xai.ipynb 41
def initial_plot(self: InteractiveTSPlot):
    """Run every construction step of the plot, in order."""
    build_steps = (
        self.add_selected_features,
        self.add_windows,
        self.setup_style,
        self.set_features_buttons,
        self.add_movement_buttons,
        self.setup_boxes,
    )
    for build_step in build_steps:
        build_step()
InteractiveTSPlot.initial_plot = initial_plot
# %% ../nbs/xai.ipynb 42
def show(self : InteractiveTSPlot):
    """Build the plot with `initial_plot` and display `self.box`."""
    self.initial_plot()
    assembled_box = self.box
    display(assembled_box)
InteractiveTSPlot.show = show