# deepvats / dvats_xai/utils.py
# Last change: commit 6d51833 by misantamaria ("trying to fix cuda error"), 9.42 kB
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/utils.ipynb.
# %% auto 0
# Public names exported by `from <this module> import *`.
__all__ = [
    'generate_TS_df',
    'normalize_columns',
    'remove_constant_columns',
    'ReferenceArtifact',
    'PrintLayer',
    'get_wandb_artifacts',
    'get_pickle_artifact',
    'exec_with_feather',
    'py_function',
    'exec_with_feather_k_output',
    'exec_with_and_feather_k_output',
    'learner_module_leaves',
    'learner_module_leaves_subtables',
]
# %% ../nbs/utils.ipynb 3
from .imports import *
from fastcore.all import *
import wandb
import pickle
import pandas as pd
import numpy as np
#import tensorflow as tf
import torch.nn as nn
from fastai.basics import *
# %% ../nbs/utils.ipynb 5
def generate_TS_df(rows, cols):
    """Generate a random multivariate time series as a DataFrame.

    Each column represents a variable and each row a time point (sample); the
    values are drawn with `np.random.randn`. The timestamps are in the index,
    starting at the current time and evenly spaced by 1 second.

    - `rows`: number of time points (rows) in the result.
    - `cols`: number of variables (columns).
    """
    # One `now()` call plus an explicit `periods` gives exactly `rows` timestamps.
    # Deriving the end from a second `now()` call with a half-open range made the
    # length `rows` or `rows-1` depending on the time elapsed between the calls.
    index = pd.date_range(start=pd.Timestamp.now(), periods=rows,
                          freq=pd.Timedelta(1, 'seconds'))
    data = np.random.randn(len(index), cols)
    return pd.DataFrame(data, index=index)
# %% ../nbs/utils.ipynb 10
def normalize_columns(df:pd.DataFrame):
    """Standardize every column of `df` to zero mean and unit standard deviation.

    Uses pandas' default (sample) standard deviation. A small epsilon (1e-7) is
    added to it so that constant columns map to 0 instead of dividing by zero.
    """
    centered = df.sub(df.mean())
    return centered.div(df.std() + 1e-7)
# %% ../nbs/utils.ipynb 16
def remove_constant_columns(df:pd.DataFrame):
    """Return a new DataFrame with the columns of `df` whose values never change removed.

    A column is kept when at least one of its values differs from its value in
    the first row. A DataFrame with no rows is returned as an unchanged copy,
    because there is no first row to compare against (indexing it would raise).

    NOTE: NaN compares unequal to everything, so any column containing NaN is
    kept, even an all-NaN one.
    """
    if len(df) == 0:
        return df.copy()
    first_row = df.iloc[0]
    varies = (df != first_row).any()
    return df.loc[:, varies]
# %% ../nbs/utils.ipynb 21
class ReferenceArtifact(wandb.Artifact):
    """Artifact holding a single file reference to a pickled object.

    The object passed to the constructor is pickled, named after a hash of its
    pickled bytes, and stored in `folder` (by default
    `Path.home()/default_storage_path`). The artifact only records a `file://`
    reference to that file, plus `metadata['ref']['hash']` (the file name) and
    `metadata['ref']['type']` (the object's class).
    """
    # Relative path; resolved against Path.home() when no `folder` is given.
    default_storage_path = Path('data/wandb_artifacts/')

    @delegates(wandb.Artifact.__init__)
    def __init__(self, obj, name, type='object', folder=None, **kwargs):
        import hashlib
        super().__init__(type=type, name=name, **kwargs)
        payload = pickle.dumps(obj)
        # Content hash that is stable across interpreter sessions: the builtin
        # hash() of bytes is salted per process (PYTHONHASHSEED), so the same
        # object would get a different file name on every run.
        hash_code = hashlib.sha256(payload).hexdigest()
        # Absolute path so the file:// URI below is well-formed for relative folders too.
        folder = Path(ifnone(folder, Path.home()/self.default_storage_path)).absolute()
        folder.mkdir(parents=True, exist_ok=True)
        file_path = folder/hash_code
        # Write the exact bytes that were hashed (no second pickling pass).
        with open(file_path, 'wb') as f:
            f.write(payload)
        self.add_reference(f'file://{file_path}')
        if self.metadata is None:
            self.metadata = dict()
        self.metadata['ref'] = dict()
        self.metadata['ref']['hash'] = hash_code
        self.metadata['ref']['type'] = str(obj.__class__)
# %% ../nbs/utils.ipynb 24
@patch
def to_obj(self:wandb.apis.public.Artifact):
    """Download the files of a saved ReferenceArtifact and get the referenced object.

    The artifact must come from a call to `run.use_artifact` with a proper wandb run.
    If the pickled file is already present locally it is read directly; otherwise
    the artifact is downloaded. Prints an error and returns None when the artifact
    has no `'ref'` metadata (i.e. it was not created by `ReferenceArtifact`).
    """
    metadata = self.metadata or {}
    if metadata.get('ref') is None:
        print(f'ERROR:{self} does not come from a saved ReferenceArtifact')
        return None
    relative_path = ReferenceArtifact.default_storage_path/metadata['ref']['hash']
    # ReferenceArtifact stores files under Path.home() by default; the
    # cwd-relative location is also checked for backward compatibility.
    candidates = [Path.home()/relative_path, relative_path]
    path = next((p for p in candidates if p.exists()), None)
    if path is None:
        # Not available locally: fetch the referenced file through wandb.
        path = Path(self.download()).ls()[0]
    # NOTE: unpickling can execute arbitrary code; only load artifacts from trusted projects.
    with open(path, 'rb') as f:
        return pickle.load(f)
# %% ../nbs/utils.ipynb 33
import torch.nn as nn
class PrintLayer(nn.Module):
    """Identity layer that prints the shape of its input, for debugging models.

    The input is returned unchanged, so it can be placed between other layers
    to inspect the shape of the tensor flowing through that point.
    """
    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Debug hook: report the incoming shape, then pass the input through untouched.
        print(x.shape)
        return x
# %% ../nbs/utils.ipynb 34
@patch
def export_and_get(self:Learner, keep_exported_file=False, fname='aux.pkl'):
    """
    Export the learner into an auxiliary file, load it and return it back.

    - `keep_exported_file`: leave the exported file on disk instead of deleting it.
    - `fname`: name of the auxiliary file. It can be changed, e.g. on Windows,
      where `aux` is a reserved device name.

    `Learner.export` writes to `self.path/fname`, so the file is loaded from (and
    removed from) that same location rather than from the current directory.
    """
    aux_path = Path(self.path)/fname
    self.export(fname=fname)
    try:
        aux_learn = load_learner(aux_path)
    finally:
        # Clean up even when loading fails, unless the caller wants to keep the file.
        if not keep_exported_file and aux_path.exists(): aux_path.unlink()
    return aux_learn
# %% ../nbs/utils.ipynb 35
def get_wandb_artifacts(project_path, type=None, name=None, last_version=True):
    """
    Get the artifacts logged in a wandb project.
    Input:
    - `project_path` (str): entity/project_name
    - `type` (str): whether to return only one type of artifacts
    - `name` (str): Leave none to have all artifact names
    - `last_version`: whether to return only the last version of each artifact or not
    Output: List of artifacts (collections without any version are skipped)
    """
    public_api = wandb.Api()
    if type is not None:
        types = [public_api.artifact_type(type, project_path)]
    else:
        types = public_api.artifact_types(project_path)
    res = []
    for kind in types:
        for collection in kind.collections():
            if name is not None and name != collection.name:
                continue
            versions = public_api.artifact_versions(
                kind.type,
                "/".join([kind.entity, kind.project, collection.name]),
                # One entry suffices for last_version; otherwise use larger pages
                # to avoid one request per version.
                per_page=1 if last_version else 50,
            )
            if last_version:
                # NOTE(review): assumes the API lists the newest version first — confirm.
                latest = next(versions, None)  # None for a collection with no versions
                if latest is not None:
                    res.append(latest)
            else:
                res.extend(versions)
    return res
# %% ../nbs/utils.ipynb 39
def get_pickle_artifact(filename):
    """Load and return the object pickled in the file `filename`.

    NOTE: `pickle.load` can execute arbitrary code; only use it on trusted files.
    """
    with open(filename, "rb") as pickled_file:
        return pickle.load(pickled_file)
# %% ../nbs/utils.ipynb 41
import pyarrow.feather as ft
import pickle
# %% ../nbs/utils.ipynb 42
def exec_with_feather(function, path = None, print_flag = False, *args, **kwargs):
    """Read the feather file at `path` and return `function(data, *args, **kwargs)`.

    - `function`: callable applied to the data read with `ft.read_feather`.
    - `path`: feather file; when None, `function` is not called and None is returned.
    - `print_flag`: print progress messages.
    """
    result = None
    if path is not None:
        if print_flag: print("--> Exec with feather | reading input from ", path)
        data = ft.read_feather(path)
        if print_flag: print("--> Exec with feather | Apply function ", getattr(function, '__name__', function))
        result = function(data, *args, **kwargs)
    if print_flag: print("Exec with feather --> ", path)
    return result
# %% ../nbs/utils.ipynb 43
def py_function(module_name, function_name, print_flag = False):
    """Look up a function by name, first in `__main__`, then in `module_name`.

    - `module_name`: dotted module path imported when `__main__` has no `function_name`.
    - `function_name`: attribute name of the function to retrieve.
    - `print_flag`: print the resolved function.
    Raises ImportError if `module_name` cannot be imported, or AttributeError if
    the function is found in neither place.
    """
    try:
        function = getattr(__import__('__main__'), function_name)
    except AttributeError:
        # Not defined in the running script/notebook: import it from `module_name`.
        module = __import__(module_name, fromlist=[''])
        function = getattr(module, function_name)
    if print_flag: print("py function: ", function_name, ": ", function)
    return function
# %% ../nbs/utils.ipynb 46
import time
def exec_with_feather_k_output(function_name, module_name = "main", path = None, k_output = 0, print_flag = False, time_flag = False, *args, **kwargs):
    """Apply a function, found by name, to a feather file and return one of its outputs.

    - `function_name` / `module_name`: resolved with `py_function`.
    - `path`: feather file to read; when None the function is not called and None is returned.
    - `k_output`: index of the element of the function's result to return.
    - `print_flag`: print progress messages.
    - `time_flag`: print the elapsed time.
    """
    result = None
    function = py_function(module_name, function_name, print_flag)
    # perf_counter is monotonic, so durations are unaffected by system clock changes.
    if time_flag: t_start = time.perf_counter()
    if path is not None:
        if print_flag: print("--> Exec with feather | reading input from ", path)
        data = ft.read_feather(path)
        if print_flag: print("--> Exec with feather | Apply function ", function_name)
        result = function(data, *args, **kwargs)[k_output]
    if time_flag:
        t_end = time.perf_counter()
        print("Exec with feather | time: ", t_end-t_start)
    if print_flag: print("Exec with feather --> ", path)
    return result
# %% ../nbs/utils.ipynb 48
def exec_with_and_feather_k_output(function_name, module_name = "main", path_input = None, path_output = None, k_output = 0, print_flag = False, time_flag = False, *args, **kwargs):
    """Read a feather file, apply a function to it and write one of its outputs as feather.

    - `function_name` / `module_name`: resolved with `py_function`.
    - `path_input`: feather file to read; when None nothing is executed or written.
    - `path_output`: destination of output `k_output` (lz4-compressed feather); when
      None, the result is computed but not written.
    - `k_output`: index of the element of the function's result to write.
    - `print_flag`: print progress messages.
    - `time_flag`: print the elapsed time.
    Returns `path_output`.
    """
    function = py_function(module_name, function_name, print_flag)
    # perf_counter is monotonic, so durations are unaffected by system clock changes.
    if time_flag: t_start = time.perf_counter()
    if path_input is not None:
        if print_flag: print("--> Exec with feather | reading input from ", path_input)
        data = ft.read_feather(path_input)
        if print_flag:
            print("--> Exec with feather | Apply function ", function_name, "input type: ", type(data))
        result = function(data, *args, **kwargs)[k_output]
        if path_output is not None:
            # NOTE(review): write_feather needs a DataFrame / pyarrow Table — confirm
            # the selected output always is one.
            ft.write_feather(result, path_output, compression = 'lz4')
    if time_flag:
        t_end = time.perf_counter()
        print("Exec with feather | time: ", t_end-t_start)
    if print_flag: print("Exec with feather --> ", path_output)
    return path_output
# %% ../nbs/utils.ipynb 52
def learner_module_leaves(learner):
    """Return a DataFrame with one row per leaf module (a module without children).

    Columns:
    - `Path`: class names from the root's direct child down to the leaf, joined with ' -> '.
    - `Module_type`: class name of the leaf.
    - `Module_name`: name of the leaf inside its parent module.
    - `Module`: `str()` of the leaf (stripped), which includes its parameters.
    `learner` must expose `.modules()`; the first module it yields is the root.
    """
    # Get the root module (the first one yielded) without materializing the full list.
    root = next(iter(learner.modules()))
    rows = []

    def find_leave_modules(module, path=()):
        for name, sub_module in module.named_children():
            sub_type = type(sub_module).__name__
            current_path = (*path, sub_type)
            if next(sub_module.children(), None) is None:  # no children -> leaf
                rows.append([
                    ' -> '.join(current_path),
                    sub_type,
                    name,
                    str(sub_module).strip(),
                ])
            else:
                find_leave_modules(sub_module, current_path)

    find_leave_modules(root)
    return pd.DataFrame(rows, columns=['Path', 'Module_type', 'Module_name', 'Module'])
# %% ../nbs/utils.ipynb 56
def learner_module_leaves_subtables(learner, print_flag = False):
    """Summarize the leaf modules of `learner` as computed by `learner_module_leaves`.

    Returns `(md_types, md_modules)`: single-column DataFrames holding the distinct
    `Module_type` values and the distinct `Module` strings, with leaves sorted by
    `Module_type`. With `print_flag`, both tables are shown via `display`.
    """
    md = learner_module_leaves(learner).drop(
        'Path', axis = 1
    ).sort_values(
        by = 'Module_type'
    )
    if print_flag: print("The layers are of this types:")
    md_types = pd.DataFrame(md['Module_type'].drop_duplicates())
    if print_flag:
        # NOTE(review): `display` is not imported here; presumably provided by
        # `.imports` or IPython — confirm.
        display(md_types)
        print("And they are called with this parameters:")
    md_modules = pd.DataFrame(md['Module'].drop_duplicates())
    if print_flag: display(md_modules)
    return md_types, md_modules