# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/xai.ipynb. # %% auto 0 __all__ = ['get_embeddings', 'get_dataset', 'umap_parameters', 'get_prjs', 'plot_projections', 'plot_projections_clusters', 'calculate_cluster_stats', 'anomaly_score', 'detector', 'plot_anomaly_scores_distribution', 'plot_clusters_with_anomalies', 'update_plot', 'plot_clusters_with_anomalies_interactive_plot', 'get_df_selected', 'shift_datetime', 'get_dateformat', 'get_anomalies', 'get_anomaly_styles', 'InteractiveAnomalyPlot', 'plot_save', 'plot_initial_config', 'merge_overlapping_windows', 'InteractiveTSPlot', 'add_selected_features', 'add_windows', 'setup_style', 'toggle_trace', 'set_features_buttons', 'move_left', 'move_right', 'move_down', 'move_up', 'delta_x_bigger', 'delta_y_bigger', 'delta_x_lower', 'delta_y_lower', 'add_movement_buttons', 'setup_boxes', 'initial_plot', 'show'] # %% ../nbs/xai.ipynb 1 #Weight & Biases import wandb #Yaml from yaml import load, FullLoader #Embeddings from .all import * from tsai.data.preparation import prepare_forecasting_data from tsai.data.validation import get_forecasting_splits from fastcore.all import * #Dimensionality reduction from tsai.imports import * #Clustering import hdbscan import time from .dr import get_PCA_prjs, get_UMAP_prjs, get_TSNE_prjs import seaborn as sns import matplotlib.pyplot as plt import pandas as pd import ipywidgets as widgets from IPython.display import display from functools import partial from IPython.display import display, clear_output, HTML as IPHTML from ipywidgets import Button, Output, VBox, HBox, HTML, Layout, FloatSlider import plotly.graph_objs as go import plotly.offline as py import plotly.io as pio #! 
# Static PNG export in `plot_save` (plotly.io.to_image) needs kaleido:
#   pip install kaleido
import kaleido

# %% ../nbs/xai.ipynb 4
def get_embeddings(config_lrp, run_lrp, api, print_flag = False):
    """Fetch the embeddings artifact named by `config_lrp.emb_artifact`.

    The artifact is fetched through the active run (`run_lrp.use_artifact`)
    when `config_lrp.use_wandb` is truthy, otherwise through `api.artifact`
    (offline use).

    Returns `(emb_artifact.to_obj(), emb_artifact, emb_config)`, where
    `emb_config` is the config of the run that logged the artifact.
    """
    artifacts_gettr = run_lrp.use_artifact if config_lrp.use_wandb else api.artifact
    emb_artifact = artifacts_gettr(config_lrp.emb_artifact, type='embeddings')
    if print_flag: print(emb_artifact.name)
    emb_config = emb_artifact.logged_by().config
    return emb_artifact.to_obj(), emb_artifact, emb_config

# %% ../nbs/xai.ipynb 5
def get_dataset(
    config_lrp,
    config_emb,
    config_dr,
    run_lrp,
    api,
    print_flag = False
):
    """Restore the encoder learner and the dataset used for dimensionality reduction.

    Parameters
    ----------
    config_lrp : object with attribute `use_wandb` (selects how artifacts are fetched).
    config_emb : mapping with key `'enc_artifact'` (encoder learner artifact name).
    config_dr : mapping with key `'dr_artifact'`. When it is not None, that artifact
        is used as the DR dataset; otherwise the encoder's training dataset is used.
    run_lrp, api : objects providing `use_artifact` / `artifact` respectively.
    print_flag : print debugging information.

    Returns
    -------
    (df, dr_artifact, enc_artifact, enc_learner), where `df = dr_artifact.to_df()`.
    """
    # Workaround so artifacts can also be fetched offline (through `api`).
    artifacts_gettr = run_lrp.use_artifact if config_lrp.use_wandb else api.artifact
    enc_artifact = artifacts_gettr(config_emb['enc_artifact'], type='learner')
    if print_flag: print(enc_artifact.name)
    # TODO: the first `to_obj()` call has been observed to fail while a second
    # attempt succeeds, hence the single retry.
    try:
        enc_learner = enc_artifact.to_obj()
    except Exception:
        enc_learner = enc_artifact.to_obj()
    # Restore the dataset artifacts recorded in the encoder's run config.
    enc_logger = enc_artifact.logged_by()
    enc_artifact_train = artifacts_gettr(enc_logger.config['train_artifact'], type='dataset')
    if enc_logger.config['valid_artifact'] is not None:
        enc_artifact_valid = artifacts_gettr(enc_logger.config['valid_artifact'], type='dataset')
        if print_flag: print("enc_artifact_valid:", enc_artifact_valid.name)
    if print_flag: print("enc_artifact_train: ", enc_artifact_train.name)
    if config_dr['dr_artifact'] is not None:
        if print_flag: print("Is not none")
        # Fetch the artifact named by 'dr_artifact' (previously 'enc_artifact' was used by mistake).
        dr_artifact = artifacts_gettr(config_dr['dr_artifact'])
    else:
        dr_artifact = enc_artifact_train
    if print_flag: print("DR artifact train: ", dr_artifact.name)
    if print_flag: print("--> DR artifact name", dr_artifact.name)
    df = dr_artifact.to_df()
    if print_flag: print("--> DR After to df", df.shape)
    if print_flag: display(df.head())
    return df, dr_artifact, enc_artifact, enc_learner

# %% ../nbs/xai.ipynb 6
def umap_parameters(config_dr, config):
    """Return the keyword arguments used for the UMAP projection.

    The CPU set is returned when `config_dr.cpu_flag` is truthy, otherwise the
    GPU set (its keys such as 'hash_input' suggest cuML - TODO confirm).
    `config` is currently unused and kept for API compatibility.
    """
    umap_params_cpu = {
        'n_neighbors' : config_dr.n_neighbors,
        'min_dist' : config_dr.min_dist,
        'random_state': np.uint64(822569775),
        'metric': config_dr.metric,
        #'a': 1.5769434601962196,
        #'b': 0.8950608779914887,
        #'metric_kwds': {'p': 2}, # Should not be necessary, just in case
        #'output_metric': 'euclidean',
        'verbose': 4,
        #'n_epochs': 200
    }
    umap_params_gpu = {
        'n_neighbors' : config_dr.n_neighbors,
        'min_dist' : config_dr.min_dist,
        'random_state': np.uint64(1234),
        'metric': config_dr.metric,
        'a': 1.5769434601962196,
        'b': 0.8950608779914887,
        'target_metric': 'euclidean',
        'target_n_neighbors': config_dr.n_neighbors,
        'verbose': 4, #6, #CUML_LEVEL_TRACE
        'n_epochs': 200*3*2,
        'init': 'random',
        'hash_input': True
    }
    if config_dr.cpu_flag:
        umap_params = umap_params_cpu
    else:
        umap_params = umap_params_gpu
    return umap_params

# %% ../nbs/xai.ipynb 7
def get_prjs(embs_no_nan, config_dr, config, print_flag = False):
    """Project `embs_no_nan`: PCA first, then UMAP over the PCA output.

    Both steps receive the parameters from `umap_parameters`.
    NOTE(review): the PCA step is always called with `cpu=False`, regardless of
    `config_dr.cpu_flag` - confirm this is intended.
    """
    umap_params = umap_parameters(config_dr, config)
    prjs_pca = get_PCA_prjs(
        X = embs_no_nan,
        cpu = False,
        print_flag = print_flag,
        **umap_params
    )
    if print_flag:
        print(prjs_pca.shape)
    prjs_umap = get_UMAP_prjs(
        input_data = prjs_pca,
        cpu = config_dr.cpu_flag,
        print_flag = print_flag,
        **umap_params
    )
    if print_flag:
        print(prjs_umap.shape)
    return prjs_umap

# %% ../nbs/xai.ipynb 9
def plot_projections(prjs, umap_params, fig_size = (25,25)):
    "Plot 2D projections through a connected scatter plot"
    df_prjs = pd.DataFrame(prjs, columns = ['x1', 'x2'])
    fig = plt.figure(figsize=(fig_size[0],fig_size[1]))
    ax = fig.add_subplot(111)
    ax.scatter(df_prjs['x1'], df_prjs['x2'], marker='o', facecolors='none', edgecolors='b', alpha=0.1)
    # Connect consecutive points to show the trajectory in the projected space.
    ax.plot(df_prjs['x1'], df_prjs['x2'], alpha=0.5, picker=1)
    plt.title('DR params -  n_neighbors:{:d} min_dist:{:f}'.format(
        umap_params['n_neighbors'],umap_params['min_dist']))
    return ax

# %% ../nbs/xai.ipynb 10
def plot_projections_clusters(prjs, clusters_labels, umap_params, fig_size = (25,25)):
    "Plot 2D projections as a scatter plot, one colour per cluster label"
    df_prjs = pd.DataFrame(prjs, columns = ['x1', 'x2'])
    df_prjs['cluster'] = clusters_labels
    fig = plt.figure(figsize=(fig_size[0],fig_size[1]))
    ax = fig.add_subplot(111)
    # One scatter call per cluster so each gets its own colour.
    unique_labels = df_prjs['cluster'].unique()
    print(unique_labels)
    for label in unique_labels:
        cluster_data = df_prjs[df_prjs['cluster'] == label]
        ax.scatter(cluster_data['x1'], cluster_data['x2'], label=f'Cluster {label}')
    plt.title('DR params -  n_neighbors:{:d} min_dist:{:f}'.format(
        umap_params['n_neighbors'],umap_params['min_dist']))
    return ax

# %% ../nbs/xai.ipynb 11
def calculate_cluster_stats(data, labels):
    """Map every distinct label to a `(mean, std)` pair computed along axis 0.

    NOTE(review): per-cluster filtering (`data[labels == label]`) is disabled in
    the original code, so every label receives the *global* statistics of
    `data`. Confirm whether this is intended before relying on per-cluster scores.
    """
    # Global statistics (see NOTE): identical for every label, so compute them once.
    mean = np.mean(data, axis = 0)
    std = np.std(data, axis = 0)
    return {label: (mean, std) for label in np.unique(labels)}

# %% ../nbs/xai.ipynb 12
def anomaly_score(point, cluster_stats, label):
    """Return the L2 norm of the standardized `point` w.r.t. the stats of `label`."""
    mean, std = cluster_stats[label]
    return np.linalg.norm((point - mean) / std)

# %% ../nbs/xai.ipynb 13
def detector(data, labels):
    """Return an array with the `anomaly_score` of every point in `data`."""
    cluster_stats = calculate_cluster_stats(data, labels)
    return np.array([anomaly_score(point, cluster_stats, label) for point, label in zip(data, labels)])

# %% ../nbs/xai.ipynb 15
def plot_anomaly_scores_distribution(anomaly_scores):
    "Plot the distribution of anomaly scores (histogram + KDE) to check for normality"
    plt.figure(figsize=(10, 6))
    sns.histplot(anomaly_scores, kde=True, bins=30)
    plt.title("Distribución de Anomaly Scores")
    plt.xlabel("Anomaly Score")
    plt.ylabel("Frecuencia")
    plt.show()

# %% ../nbs/xai.ipynb 16
def plot_clusters_with_anomalies(prjs, clusters_labels, anomaly_scores, threshold, fig_size=(25, 25)):
    "Plot 2D projections of clusters and superimpose the points whose score exceeds `threshold`"
    df_prjs = pd.DataFrame(prjs, columns=['x1', 'x2'])
    df_prjs['cluster'] = clusters_labels
    df_prjs['anomaly'] = anomaly_scores > threshold
    fig = plt.figure(figsize=(fig_size[0], fig_size[1]))
    ax = fig.add_subplot(111)
    # Plot each cluster with a different colour
    unique_labels = df_prjs['cluster'].unique()
    for label in unique_labels:
        cluster_data = df_prjs[df_prjs['cluster'] == label]
        ax.scatter(cluster_data['x1'], cluster_data['x2'], label=f'Cluster {label}', alpha=0.7)
    # Superimpose anomalies
    anomalies = df_prjs[df_prjs['anomaly']]
    ax.scatter(anomalies['x1'], anomalies['x2'], color='red', label='Anomalies', edgecolor='k', s=50)
    plt.title('Clusters and anomalies')
    plt.legend()
    plt.show()

def update_plot(threshold, prjs_umap, clusters_labels, anomaly_scores, fig_size):
    """Callback for the threshold slider: redraw `plot_clusters_with_anomalies`."""
    plot_clusters_with_anomalies(prjs_umap, clusters_labels, anomaly_scores, threshold, fig_size)

def plot_clusters_with_anomalies_interactive_plot(threshold, prjs_umap, clusters_labels, anomaly_scores, fig_size):
    """Display `plot_clusters_with_anomalies` with a slider controlling the threshold."""
    threshold_slider = widgets.FloatSlider(value=threshold, min=0.001, max=3, step=0.001, description='Threshold')
    interactive_plot = widgets.interactive(
        update_plot,
        threshold = threshold_slider,
        prjs_umap = widgets.fixed(prjs_umap),
        clusters_labels = widgets.fixed(clusters_labels),
        anomaly_scores = widgets.fixed(anomaly_scores),
        # Honour the caller's figure size (it was hard-coded to (25, 25)).
        fig_size = widgets.fixed(fig_size)
    )
    display(interactive_plot)

# %% ../nbs/xai.ipynb 18
import plotly.express as px
from datetime import timedelta

# %% ../nbs/xai.ipynb 19
def get_df_selected(df, selected_indices, w, stride = 1):
    '''Link the selected points back to the original dataframe.

    Window `i` is mapped to the row range `(i*stride, i*stride + w)`; its rows are
    taken with `df.iloc[start:end+1]` (both ends included). Be careful that
    `stride` matches the one used to build the windows.

    Returns `(window_ranges, n_windows, df_selected)`. With no selected indices,
    `df_selected` is an empty frame with the columns of `df`.
    '''
    n_windows = len(selected_indices)
    window_ranges = [(idx * stride, idx * stride + w) for idx in selected_indices]
    window_values = [df.iloc[start:end + 1] for start, end in window_ranges]
    # pd.concat raises on an empty list, so return an empty slice instead.
    df_selected = pd.concat(window_values, ignore_index=False) if window_values else df.iloc[0:0]
    return window_ranges, n_windows, df_selected

# %% ../nbs/xai.ipynb 20
def shift_datetime(dt, seconds, sign, dateformat="%Y-%m-%d %H:%M:%S.%f", print_flag = False):
    """
    Move the date string `dt` `seconds` seconds to the future if `sign` is '+'
    and to the past otherwise.

    `dt` is parsed with `dateformat` first; if that fails, "%Y-%m-%d %H:%M:%S.%f",
    "%Y-%m-%d %H:%M:%S" and "%Y-%m-%d" are tried in that order. When the parsed
    (and again the shifted) time is exactly hh:mm:ss = 00:00:00, microseconds
    are dropped. The result is formatted as "%Y-%m-%d %H:%M:%S.%f".

    Returns the shifted date string, or the error message as a string if the
    shift/formatting raises ValueError.

    Raises ValueError if `dt` matches none of the accepted formats.
    """
    output_format = "%Y-%m-%d %H:%M:%S.%f"
    candidate_formats = [dateformat, output_format, "%Y-%m-%d %H:%M:%S", "%Y-%m-%d"]
    if print_flag: print(dateformat)
    if print_flag: print("dt ", dt, "seconds", seconds, "sign", sign)
    new_dt = None
    for fmt in candidate_formats:
        try:
            new_dt = datetime.strptime(dt, fmt)
            break
        except ValueError as e:
            if print_flag: print("Error: ", e)
    if new_dt is None:
        # Previously this path crashed later with an UnboundLocalError.
        raise ValueError(f"shift_datetime: cannot parse {dt!r} with any of {candidate_formats}")
    if print_flag: print("ndt", new_dt)
    try:
        # At exactly midnight, drop sub-second noise before shifting.
        if new_dt.hour == 0 and new_dt.minute == 0 and new_dt.second == 0:
            new_dt = new_dt.replace(microsecond=0)
        delta = timedelta(seconds = seconds)
        new_dt = new_dt + delta if sign == '+' else new_dt - delta
        if print_flag: print(sign, new_dt)
        # ...and again after shifting, so day boundaries stay clean.
        if new_dt.hour == 0 and new_dt.minute == 0 and new_dt.second == 0:
            new_dt = new_dt.replace(microsecond=0)
        new_dt_str = new_dt.strftime(output_format)
        if print_flag: print("new dt ", new_dt)
    except ValueError as e:
        return str(e)
    return new_dt_str

# %% ../nbs/xai.ipynb 21
def get_dateformat(text_date):
    """Guess the strptime format of `text_date` from its shape.

    Returns "%Y-%m-%d %H:%M:%S.%f" (time with fractional seconds),
    "%Y-%m-%d %H:%M:%S" (time with whole seconds), "%Y-%m-%d" (single token),
    or the sentinel strings "unknown format 1" / "unknown format 2".
    Only the structure is inspected; the date itself is not validated.
    """
    parts = text_date.split()
    if len(parts) == 1:
        return "%Y-%m-%d"
    if len(parts) != 2:
        return "unknown format 2"
    time_parts = parts[1].split(':')
    if len(time_parts) != 3:
        return "unknown format 1"
    has_fraction = len(time_parts[2].split('.')) == 2
    return "%Y-%m-%d %H:%M:%S.%f" if has_fraction else "%Y-%m-%d %H:%M:%S"

# %% ../nbs/xai.ipynb 23
def get_anomalies(df, threshold, flag):
    """Set `df['anomaly']` in place: truthy where `df['anomaly_score'] > threshold` and `flag`."""
    df['anomaly'] = [ (score > threshold) and flag for score in df['anomaly_score']]

def get_anomaly_styles(df, threshold, anomaly_scores, flag = False, print_flag = False):
    """Return `(symbols, line_colors)` marker styles for every row of `df`.

    When `flag` is truthy, points whose score exceeds `threshold` get an 'x'
    symbol with a black outline. The scores come from `anomaly_scores`, or from
    `df['anomaly_score']` if `anomaly_scores` is empty/None. Otherwise every
    point is a 'circle' with a transparent outline.
    Side effect: sets `df['anomaly']`.
    """
    if print_flag: print("Threshold: ", threshold)
    if print_flag: print("Flag", flag)
    if print_flag: print("df ~", df.shape)
    get_anomalies(df, threshold, flag)
    if print_flag: print(df)
    anomalies = df[df['anomaly']]
    if flag:
        use_df_scores = anomaly_scores is None or len(anomaly_scores) == 0
        scores = df['anomaly_score'] if use_df_scores else anomaly_scores
        df['anomaly'] = [ (score > threshold) and flag for score in scores ]
        symbols = [ 'x' if is_anomaly else 'circle' for is_anomaly in df['anomaly'] ]
        line_colors = [ 'black' if is_anomaly else 'rgba(0,0,0,0)' for is_anomaly in df['anomaly'] ]
    else:
        symbols = ['circle'] * len(df)
        line_colors = ['rgba(0,0,0,0)'] * len(df)
    if print_flag: print(anomalies)
    return symbols, line_colors

### Example of use
#prjs_df = pd.DataFrame(prjs_umap, columns = ['x1', 'x2'])
#prjs_df['anomaly_score'] = anomaly_scores
#s, l = get_anomaly_styles(prjs_df, 1, anomaly_scores, True)

# %% ../nbs/xai.ipynb 24
class InteractiveAnomalyPlot():
    """Interactive (plotly FigureWidget + ipywidgets) view of 2D projections.

    Points can be lasso-selected. "Update selected_indices" copies the current
    selection into `self.selected_indices`. Anomalies can be toggled, and the
    width of the line joining consecutive points can be adjusted.

    Parameters
    ----------
    selected_indices : initial selection (defaults to an empty list).
    threshold : initial anomaly threshold.
    anomaly_flag : whether anomalies are shown initially (also controls whether
        the threshold slider is displayed).
    path : directory where the PNG snapshot is written.
    w : window size; used in the image file name `w={w}.png`.
    """
    def __init__(
        self,
        selected_indices = None,
        threshold = 0.15,
        anomaly_flag = False,
        path = "../imgs",
        w = 0
    ):
        import os
        # Avoid a shared mutable default list between instances.
        selected_indices = [] if selected_indices is None else selected_indices
        self.selected_indices = selected_indices
        self.selected_indices_tmp = selected_indices
        self.threshold = threshold
        self.threshold_ = threshold
        self.anomaly_flag = anomaly_flag
        self.w = w
        self.name = f"w={self.w}"
        self.img_dir = path
        # Previously f"{path}{self.name}.png" (missing separator: "../imgsw=0.png").
        self.path = os.path.join(path, f"{self.name}.png")
        self.interaction_enabled = True

    def plot_projections_clusters_interactive(
        self, prjs, cluster_labels, umap_params, anomaly_scores=[], fig_size=(7,7), print_flag = False
    ):
        """Build, save (PNG) and display the interactive projection plot.

        `prjs` holds the 2D projections, `cluster_labels` one label per point,
        `umap_params` supplies 'n_neighbors'/'min_dist' for the title, and
        `anomaly_scores` holds one score per point (may be empty).
        NOTE: `fig_size` is currently unused; the figure is 700x500.
        """
        self.selected_indices_tmp = self.selected_indices
        py.init_notebook_mode()

        prjs_df, cluster_colors = plot_initial_config(prjs, cluster_labels, anomaly_scores)
        marker_colors = prjs_df['cluster'].map(cluster_colors)
        symbols, line_colors = get_anomaly_styles(prjs_df, self.threshold_, anomaly_scores, self.anomaly_flag, print_flag)

        # Trace 0: the points (colour = cluster, symbol/outline = anomaly state).
        fig = go.FigureWidget(
            [
                go.Scatter(
                    x=prjs_df['x1'],
                    y=prjs_df['x2'],
                    mode="markers",
                    marker= {
                        'color': marker_colors,
                        'line': { 'color': line_colors, 'width': 1 },
                        'symbol': symbols
                    },
                    text = prjs_df.index
                )
            ]
        )
        # Trace 1: grey line joining consecutive points (width driven by the slider below).
        line_trace = go.Scatter(
            x=prjs_df['x1'],
            y=prjs_df['x2'],
            mode="lines",
            line=dict(color='rgba(128, 128, 128, 0.5)', width=1)
        )
        fig.add_trace(line_trace)
        sca = fig.data[0]

        fig.update_layout(
            dragmode='lasso',
            width=700,
            height=500,
            title={
                'text': 'DR params - n_neighbors:{:d} min_dist:{:f}'.format(
                    umap_params['n_neighbors'], umap_params['min_dist']),
                'y':0.98,
                'x':0.5,
                'xanchor': 'center',
                'yanchor': 'top'
            },
            plot_bgcolor='white',
            paper_bgcolor='#f0f0f0',
            xaxis=dict(gridcolor='lightgray', zerolinecolor='black', title = 'x'),
            yaxis=dict(gridcolor='lightgray', zerolinecolor='black', title = 'y'),
            margin=dict(l=10, r=20, t=30, b=10)
        )

        output_tmp = Output()
        output_button = Output()
        output_anomaly = Output()
        output_threshold = Output()
        output_width = Output()

        def select_action(trace, points, selector):
            # Ignore selections while interaction is paused.
            if not self.interaction_enabled:
                return
            self.selected_indices_tmp = points.point_inds
            with output_tmp:
                output_tmp.clear_output(wait=True)
                if print_flag: print("Selected indices tmp:", self.selected_indices_tmp)

        def button_action(b):
            self.selected_indices = self.selected_indices_tmp
            with output_button:
                output_button.clear_output(wait = True)
                if print_flag: print("Selected indices:", self.selected_indices)

        def update_anomalies():
            if print_flag: print("About to update anomalies")
            symbols, line_colors = get_anomaly_styles(prjs_df, self.threshold_, anomaly_scores, self.anomaly_flag, print_flag)
            if print_flag: print("Anomaly styles got")
            with fig.batch_update():
                fig.data[0].marker.symbol = symbols
                fig.data[0].marker.line.color = line_colors
            if print_flag: print("Anomalies updated")
            if print_flag: print("Threshold: ", self.threshold_)
            if print_flag: print("Scores: ", anomaly_scores)

        def anomaly_action(b):
            with output_anomaly:
                output_anomaly.clear_output(wait=True)
                # Was `print_fllag` (NameError on every click).
                if print_flag: print("Negate anomaly flag")
                self.anomaly_flag = not self.anomaly_flag
                if print_flag: print("Show anomalies:", self.anomaly_flag)
                update_anomalies()

        sca.on_selection(select_action)

        layout = widgets.Layout(width='auto', height='40px')

        def make_button(description, color):
            # Only recognised Button traits are passed (stray flex kwargs only triggered traitlets warnings).
            return Button(description = description, style = {'button_color': color}, layout = layout)

        button = make_button("Update selected_indices", 'lightblue')
        anomaly_button = make_button("Show anomalies", 'lightgray')
        button.on_click(button_action)
        anomaly_button.on_click(anomaly_action)

        ##### Reactivity buttons
        pause_button = make_button("Pause interactiveness", 'pink')
        resume_button = make_button("Resume interactiveness", 'lightgreen')

        threshold_slider = FloatSlider(
            value=self.threshold_,
            min=0.0,
            max=float(np.ceil(self.threshold+5)),
            step=0.0001,
            description='Anomaly threshold:',
            continuous_update=False
        )

        def pause_interaction(b):
            self.interaction_enabled = False
            fig.update_layout(dragmode='pan')

        def resume_interaction(b):
            self.interaction_enabled = True
            fig.update_layout(dragmode='lasso')

        def update_threshold(change):
            with output_threshold:
                output_threshold.clear_output(wait = True)
                if print_flag: print("Update threshold")
                self.threshold_ = change.new
                if print_flag: print("Update anomalies threshold = ", self.threshold_)
                update_anomalies()

        #### Width
        width_slider = FloatSlider(
            value = 0.5,
            min = 0.0,
            max = 1.0,
            step = 0.0001,
            description = 'Line width:',
            continuous_update = False
        )

        def update_width(change):
            with output_width:
                try:
                    output_width.clear_output(wait = True)
                    if print_flag:
                        print("Change line width")
                        print("Trace to update:", fig.data[1])
                    with fig.batch_update():
                        # Update the width of the connecting line (trace 1).
                        fig.data[1].line.width = change.new
                    if print_flag: print("ChangeD line width")
                except Exception as e:
                    print("Error updating line width:", e)

        pause_button.on_click(pause_interaction)
        resume_button.on_click(resume_interaction)
        threshold_slider.observe(update_threshold, 'value')
        width_slider.observe(update_width, names = 'value')

        # Horizontal spacer: six non-breaking spaces.
        space = HTML("&nbsp;" * 6)

        vbox = VBox((output_tmp, output_button, output_anomaly, output_threshold, fig))
        hbox = HBox((space, button, space, pause_button, space, resume_button, anomaly_button))

        # Centre both boxes horizontally inside the outer VBox.
        box_layout = widgets.Layout(display='flex', flex_flow='column', align_items='center', width='100%')

        if self.anomaly_flag:
            box = VBox((hbox, threshold_slider, width_slider, output_width, vbox), layout = box_layout)
        else:
            box = VBox((hbox, width_slider, output_width, vbox), layout = box_layout)
        box.add_class("layout")
        plot_save(fig, self.w, self.img_dir)

        display(box)

# %% ../nbs/xai.ipynb 25
def plot_save(fig, w, path = "../imgs"):
    """Write `fig` as a PNG to `<path>/w=<w>.png` (static export needs kaleido).

    `path` is created if it does not exist. The default keeps the previous
    hard-coded location.
    """
    import os
    os.makedirs(path, exist_ok=True)
    image_bytes = pio.to_image(fig, format='png')
    with open(os.path.join(path, f"w={w}.png"), 'wb') as f:
        f.write(image_bytes)

# %% ../nbs/xai.ipynb 26
def plot_initial_config(prjs, cluster_labels, anomaly_scores):
    """Build the projection dataframe and a `{cluster: colour}` mapping.

    Returns `(prjs_df, cluster_colors)`. `prjs_df` has columns x1, x2, cluster
    and anomaly_score. When `anomaly_scores` is None or empty, anomaly_score is
    0.0. Colours come from plotly's qualitative Set1 palette in order of first
    appearance, cycling when there are more clusters than palette entries.
    """
    prjs_df = pd.DataFrame(prjs, columns = ['x1', 'x2'])
    prjs_df['cluster'] = cluster_labels
    if anomaly_scores is None or len(anomaly_scores) == 0:
        prjs_df['anomaly_score'] = 0.0
    else:
        prjs_df['anomaly_score'] = anomaly_scores
    cluster_colors_df = pd.DataFrame({'cluster': cluster_labels}).drop_duplicates()
    palette = px.colors.qualitative.Set1
    # Cycle the palette: slicing it failed with more clusters than colours.
    cluster_colors_df['color'] = [palette[i % len(palette)] for i in range(len(cluster_colors_df))]
    cluster_colors = dict(zip(cluster_colors_df['cluster'], cluster_colors_df['color']))
    return prjs_df, cluster_colors

# %% ../nbs/xai.ipynb 27
def merge_overlapping_windows(windows):
    """Merge `(start, end)` windows that overlap or touch; returns them sorted by start."""
    if not windows:
        return []
    merged = []
    for start, end in sorted(windows, key=lambda win: win[0]):
        if merged and start <= merged[-1][1]:
            # Overlap: extend the previous window.
            merged[-1] = (merged[-1][0], max(end, merged[-1][1]))
        else:
            merged.append((start, end))
    return merged

# %% ../nbs/xai.ipynb 29
class InteractiveTSPlot:
    """Interactive time-series viewer highlighting the windows selected in a projection.

    Parameters
    ----------
    df : one column per feature; its index is the x axis.
        NOTE: `df.index` is converted to `str` in place (the frame is not copied).
    selected_indices : window indices to highlight (mapped to rows by `get_df_selected`).
    meaningful_features_subset_ids : column positions that are visible initially.
    w, stride : window length and stride used to map indices to rows.
    print_flag : show the debugging output widgets.
    num_points : stored but currently unused.
    dateformat : passed to `shift_datetime` and used in the x tick format.
    delta_x : step in seconds for the left/right buttons.
    delta_y : step for the up/down buttons.
    """
    def __init__(
        self,
        df,
        selected_indices,
        meaningful_features_subset_ids,
        w,
        stride=1,
        print_flag=False,
        num_points=10000,
        dateformat='%Y-%m-%d %H:%M:%S',
        delta_x = 10,
        delta_y = 0.1
    ):
        self.df = df
        self.selected_indices = selected_indices
        self.meaningful_features_subset_ids = meaningful_features_subset_ids
        self.w = w
        self.stride = stride
        self.print_flag = print_flag
        self.num_points = num_points
        self.dateformat = dateformat
        self.buttons = []
        self.delta_x = delta_x
        self.delta_y = delta_y

        self.window_ranges, self.n_windows, self.df_selected = get_df_selected(
            self.df, self.selected_indices, self.w, self.stride
        )
        # Draw the smallest possible number of windows (as in the R Shiny app).
        self.window_ranges = merge_overlapping_windows(self.window_ranges)

        # num_points does not work well yet...
        #num_points = min(df_selected.shape[0], num_points)

        if self.print_flag:
            print("windows: ", self.n_windows, self.window_ranges)
            print("selected id: ", self.df_selected.index)
            print("points: ", self.num_points)

        self.df.index = self.df.index.astype(str)
        self.fig = go.FigureWidget()
        # One random colour per selected window (there are at least as many as merged windows).
        self.colors = [
            f'rgb({np.random.randint(0, 256)}, {np.random.randint(0, 256)}, {np.random.randint(0, 256)})'
            for _ in range(self.n_windows)
        ]

        ##############################
        # Outputs for debug printing #
        ##############################
        self.output_windows = Output()
        self.output_move = Output()
        self.output_delta_x = Output()
        self.output_delta_y = Output()

# %% ../nbs/xai.ipynb 30
def add_selected_features(self: InteractiveTSPlot):
    """Add one line trace per column of `df`, in column order.

    Only the columns whose position is in `meaningful_features_subset_ids` start visible.
    """
    for feature_pos, feature_id in enumerate(self.df.columns):
        trace = go.Scatter(
            x = self.df.index,
            y = self.df[feature_id],
            mode='lines',
            name=feature_id,
            visible=feature_pos in self.meaningful_features_subset_ids,
            text=self.df.index
        )
        self.fig.add_trace(trace)

InteractiveTSPlot.add_selected_features = add_selected_features

# %% ../nbs/xai.ipynb 31
def add_windows(self: InteractiveTSPlot):
    """Shade every (merged) window on the figure and log its bounds to `output_windows`."""
    last_pos = len(self.df.index) - 1
    for i, (start, end) in enumerate(self.window_ranges):
        if start > last_pos:
            continue
        # A window touching the end of `df` can point one row past the data.
        end = min(end, last_pos)
        self.fig.add_shape(
            type="rect",
            x0=self.df.index[start],
            x1=self.df.index[end],
            y0= 0,
            y1= 1,
            yref = "paper",
            fillcolor=self.colors[i],
            opacity=0.25,
            layer="below",
            line=dict(color=self.colors[i], width=1),
            name = f"w_{i}"
        )
        with self.output_windows:
            # Windows are merged, so label by merged position rather than selected_indices[i].
            print(f"w_{i}=[{self.df.index[start]}, {self.df.index[end]}]")

InteractiveTSPlot.add_windows = add_windows

# %% ../nbs/xai.ipynb 32
def setup_style(self: InteractiveTSPlot):
    """Set titles, margins and x tick format; the y axis is made non-zoomable."""
    self.fig.update_layout(
        title='Time Series with time window plot',
        xaxis_title='Datetime',
        yaxis_title='Value',
        legend_title='Variables',
        margin=dict(l=10, r=10, t=30, b=10),
        xaxis=dict(
            tickformat = '%d-' + self.dateformat,
            zerolinecolor='black',
            title = 'x'
        ),
        paper_bgcolor='#f0f0f0'
    )
    self.fig.update_yaxes(fixedrange=True)

InteractiveTSPlot.setup_style = setup_style

# %% ../nbs/xai.ipynb 34
def toggle_trace(self : InteractiveTSPlot, button : Button):
    """Toggle the visibility of the trace associated with `button` and restyle the button.

    Buttons created by `set_features_buttons` are matched by position, so
    non-string column labels work too. Other buttons are looked up by their description.
    """
    if button in self.buttons:
        pos = self.buttons.index(button)
    else:
        pos = self.df.columns.get_loc(button.description)
    trace = self.fig.data[pos]
    trace.visible = not trace.visible
    button.button_style = 'success' if trace.visible else ''

InteractiveTSPlot.toggle_trace = toggle_trace

# %% ../nbs/xai.ipynb 35
def set_features_buttons(self: InteractiveTSPlot):
    """Create one toggle button per column; initially visible features are highlighted."""
    self.buttons = [
        Button(
            description=str(feature_id),
            button_style='success' if self.df.columns.get_loc(feature_id) in self.meaningful_features_subset_ids else ''
        )
        for feature_id in self.df.columns
    ]
    for button in self.buttons:
        button.on_click(self.toggle_trace)

InteractiveTSPlot.set_features_buttons = set_features_buttons

# %% ../nbs/xai.ipynb 36
def _shift_x_range(self: InteractiveTSPlot, sign):
    """Shift the x-axis (date) range by `self.delta_x` seconds; `sign` is '+' or '-'."""
    with self.output_move:
        self.output_move.clear_output(wait=True)
        x_range = self.fig.layout.xaxis.range
        if x_range is None:
            print("x-axis range not available yet")
            return
        start_date, end_date = x_range
        new_start_date = shift_datetime(start_date, self.delta_x, sign, self.dateformat, self.print_flag)
        new_end_date = shift_datetime(end_date, self.delta_x, sign, self.dateformat, self.print_flag)
        with self.fig.batch_update():
            self.fig.layout.xaxis.range = [new_start_date, new_end_date]

def _shift_y_range(self: InteractiveTSPlot, delta):
    """Add `delta` to both ends of the y-axis range."""
    with self.output_move:
        self.output_move.clear_output(wait=True)
        y_range = self.fig.layout.yaxis.range
        if y_range is None:
            print("y-axis range not available yet")
            return
        start_y, end_y = y_range
        with self.fig.batch_update():
            self.fig.layout.yaxis.range = [start_y + delta, end_y + delta]

def move_left(self : InteractiveTSPlot, button : Button):
    """Button callback: move the x window `delta_x` seconds to the past."""
    _shift_x_range(self, '-')

def move_right(self : InteractiveTSPlot, button : Button):
    """Button callback: move the x window `delta_x` seconds to the future."""
    _shift_x_range(self, '+')

def move_down(self: InteractiveTSPlot, button : Button):
    """Button callback: move the y window down by `delta_y` (was broken by a `self.ig` typo)."""
    _shift_y_range(self, -self.delta_y)

def move_up(self: InteractiveTSPlot, button : Button):
    """Button callback: move the y window up by `delta_y`."""
    _shift_y_range(self, self.delta_y)

InteractiveTSPlot.move_left = move_left
InteractiveTSPlot.move_right = move_right
InteractiveTSPlot.move_down = move_down
InteractiveTSPlot.move_up = move_up

# %% ../nbs/xai.ipynb 37
# The `button` argument is optional so these work both as direct calls and as
# `Button.on_click` callbacks (which pass the clicked button).
def delta_x_bigger(self: InteractiveTSPlot, button = None):
    """Multiply the horizontal step `delta_x` by 10."""
    with self.output_delta_x:
        self.output_delta_x.clear_output(wait = True)
        if self.print_flag: print("Delta before", self.delta_x)
        self.delta_x *= 10
        if self.print_flag: print("delta_x:", self.delta_x)

def delta_y_bigger(self: InteractiveTSPlot, button = None):
    """Multiply the vertical step `delta_y` by 10."""
    with self.output_delta_y:
        self.output_delta_y.clear_output(wait = True)
        if self.print_flag: print("Delta before", self.delta_y)
        self.delta_y *= 10
        if self.print_flag: print("delta_y:", self.delta_y)

def delta_x_lower(self: InteractiveTSPlot, button = None):
    """Divide the horizontal step `delta_x` by 10."""
    with self.output_delta_x:
        self.output_delta_x.clear_output(wait = True)
        if self.print_flag: print("Delta before", self.delta_x)
        self.delta_x /= 10
        if self.print_flag: print("delta_x:", self.delta_x)

def delta_y_lower(self: InteractiveTSPlot, button = None):
    """Divide the vertical step `delta_y` by 10 (it was multiplied by mistake)."""
    with self.output_delta_y:
        self.output_delta_y.clear_output(wait = True)
        if self.print_flag: print("Delta before", self.delta_y)
        self.delta_y /= 10
        if self.print_flag: print("delta_y:", self.delta_y)

InteractiveTSPlot.delta_x_bigger = delta_x_bigger
InteractiveTSPlot.delta_y_bigger = delta_y_bigger
InteractiveTSPlot.delta_x_lower = delta_x_lower
InteractiveTSPlot.delta_y_lower = delta_y_lower

# %% ../nbs/xai.ipynb 38
def add_movement_buttons(self: InteractiveTSPlot):
    """Create the arrow and step-size buttons and wire them to their callbacks."""
    self.button_left = Button(description="←")
    self.button_right = Button(description="→")
    self.button_up = Button(description="↑")
    self.button_down = Button(description="↓")

    self.button_step_x_up = Button(description="dx ↑")
    self.button_step_x_down = Button(description="dx ↓")
    self.button_step_y_up = Button(description="dy↑")
    self.button_step_y_down = Button(description="dy↓")

    # The delta_* callbacks accept the clicked button, so the step can now be changed.
    # Their debug output is only shown when print_flag is set (see setup_boxes).
    self.button_step_x_up.on_click(self.delta_x_bigger)
    self.button_step_x_down.on_click(self.delta_x_lower)
    self.button_step_y_up.on_click(self.delta_y_bigger)
    self.button_step_y_down.on_click(self.delta_y_lower)

    self.button_left.on_click(self.move_left)
    self.button_right.on_click(self.move_right)
    self.button_up.on_click(self.move_up)
    self.button_down.on_click(self.move_down)

InteractiveTSPlot.add_movement_buttons = add_movement_buttons

# %% ../nbs/xai.ipynb 40
def setup_boxes(self: InteractiveTSPlot):
    """Lay out feature buttons, navigation buttons, the figure and outputs in `self.box`."""
    self.steps_x = VBox([self.button_step_x_up, self.button_step_x_down])
    self.steps_y = VBox([self.button_step_y_up, self.button_step_y_down])
    arrow_buttons = HBox([self.button_left, self.button_right, self.button_up, self.button_down, self.steps_x, self.steps_y])
    hbox_layout = widgets.Layout(display='flex', flex_flow='row wrap', align_items='flex-start')
    hbox = HBox(self.buttons, layout=hbox_layout)
    box_layout = widgets.Layout(
        display='flex',
        flex_flow='column',
        align_items='center',
        width='100%'
    )
    if self.print_flag:
        self.box = VBox([hbox, arrow_buttons, self.output_move, self.output_delta_x, self.output_delta_y, self.fig, self.output_windows], layout=box_layout)
    else:
        self.box = VBox([hbox, arrow_buttons, self.fig, self.output_windows], layout=box_layout)

InteractiveTSPlot.setup_boxes = setup_boxes

# %% ../nbs/xai.ipynb 41
def initial_plot(self: InteractiveTSPlot):
    """Build traces, windows, style, buttons and layout (call once per instance)."""
    self.add_selected_features()
    self.add_windows()
    self.setup_style()
    self.set_features_buttons()
    self.add_movement_buttons()
    self.setup_boxes()

InteractiveTSPlot.initial_plot = initial_plot

# %% ../nbs/xai.ipynb 42
def show(self : InteractiveTSPlot):
    """Build the plot and display it.

    NOTE(review): calling this twice adds the traces/shapes again.
    """
    self.initial_plot()
    display(self.box)

InteractiveTSPlot.show = show