# AUTOGENERATED! DO NOT EDIT! File to edit: nbs/utils.ipynb (unless otherwise specified).

__all__ = ['generate_TS_df', 'normalize_columns', 'remove_constant_columns', 'ReferenceArtifact', 'PrintLayer',
           'get_wandb_artifacts', 'get_pickle_artifact']

# Cell
from .imports import *
from fastcore.all import *
import wandb
import pickle
import hashlib
import pandas as pd
import numpy as np
#import tensorflow as tf
import torch.nn as nn
from fastai.basics import *

# Cell
def generate_TS_df(rows, cols):
    """Generate a dataframe containing a random multivariate time series.

    Each column represents a variable and each row a time point (sample). The
    timestamps are stored in the index of the dataframe; they start at the current
    time and are evenly spaced 1 second apart.

    Args:
        rows: number of time points (rows). Values <= 0 produce an empty frame.
        cols: number of variables (columns).

    Returns:
        A `pd.DataFrame` of shape `(max(rows, 0), cols)` with standard-normal values.
    """
    # `periods` fixes the length to exactly `rows`. An `np.arange` over a half-open
    # interval would make the length depend on clock timing at the endpoint.
    index = pd.date_range(start=pd.Timestamp.now(), periods=max(rows, 0),
                          freq=pd.Timedelta(1, 'seconds'))
    data = np.random.randn(len(index), cols)
    return pd.DataFrame(data, index=index)

# Cell
def normalize_columns(df:pd.DataFrame):
    """Normalize columns from `df` to have 0 mean and 1 standard deviation.

    A small epsilon is added to the standard deviation so that constant columns
    do not cause a division by zero.
    """
    mean = df.mean()
    std = df.std() + 1e-7
    return (df-mean)/std

# Cell
def remove_constant_columns(df:pd.DataFrame):
    """Return `df` without the columns whose values all equal their first-row value.

    `df` must have at least one row (`df.iloc[0]` is used as the reference).
    NaN never compares equal, so any column containing NaN is kept.
    """
    return df.loc[:, (df != df.iloc[0]).any()]

# Cell
class ReferenceArtifact(wandb.Artifact):
    """This class is meant to create an artifact with a single reference to an object
    passed as argument in the constructor. The object will be pickled, hashed and stored
    in a specified folder."""
    default_storage_path = Path('data/wandb_artifacts/') # * this path is relative to Path.home()

    @delegates(wandb.Artifact.__init__)
    def __init__(self, obj, name, type='object', folder=None, **kwargs):
        """Pickle `obj` into `folder` and create an artifact that references that file.

        Args:
            obj: any picklable object.
            name, type, **kwargs: forwarded to `wandb.Artifact.__init__`.
            folder: where the pickle is written. Defaults to
                `Path.home()/default_storage_path`. Created if missing.
        """
        super().__init__(type=type, name=name, **kwargs)
        # Pickle once; the same bytes are hashed and written to disk.
        data = pickle.dumps(obj)
        # Use a content digest rather than built-in `hash()`: hashing of bytes is salted
        # per process (PYTHONHASHSEED), so it would not be stable across sessions.
        hash_code = hashlib.sha1(data).hexdigest()
        folder = Path(ifnone(folder, Path.home()/self.default_storage_path))
        folder.mkdir(parents=True, exist_ok=True)
        with open(folder/hash_code, 'wb') as f:
            f.write(data)
        self.add_reference(f'file://{folder}/{hash_code}')
        if self.metadata is None:
            self.metadata = dict()
        self.metadata['ref'] = dict()
        self.metadata['ref']['hash'] = hash_code
        self.metadata['ref']['type'] = str(obj.__class__)

# Cell
@patch
def to_obj(self:wandb.apis.public.Artifact):
    """Download the files of a saved ReferenceArtifact and get the referenced object. The artifact must
    come from a call to `run.use_artifact` with a proper wandb run.

    The local copy written by `ReferenceArtifact` (under `Path.home()`) is used when it
    exists; otherwise the artifact is downloaded. Returns None (after printing an error)
    if the artifact has no `ref` metadata.
    """
    ref = (self.metadata or {}).get('ref')
    if ref is None:
        print(f'ERROR:{self} does not come from a saved ReferenceArtifact')
        return None
    hash_code = ref['hash']
    # `ReferenceArtifact.__init__` stores files under Path.home()/default_storage_path.
    # The CWD-relative location is also checked for backward compatibility.
    candidates = [Path.home()/ReferenceArtifact.default_storage_path/hash_code,
                  ReferenceArtifact.default_storage_path/hash_code]
    path = next((p for p in candidates if p.exists()), None)
    if path is None:
        path = Path(self.download()).ls()[0]
    # NOTE(security): unpickling executes arbitrary code; only load trusted artifacts.
    with open(path, 'rb') as f:
        obj = pickle.load(f)
    return obj

# Cell
import torch.nn as nn
class PrintLayer(nn.Module):
    """Pass-through module that prints the shape of its input, for debugging."""
    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Do your print / debug stuff here
        print(x.shape)
        return x

# Cell
@patch
def export_and_get(self:Learner, keep_exported_file=False):
    """Export the learner into an auxiliary file, load it and return it back.

    `Learner.export` writes to `self.path/fname`, so the file is loaded from (and,
    unless `keep_exported_file` is True, removed from) that same location. The file
    is removed even if loading fails.
    """
    aux_path = Path(self.path)/'aux.pkl'
    self.export(fname='aux.pkl')
    try:
        aux_learn = load_learner(aux_path)
    finally:
        if not keep_exported_file and aux_path.exists():
            aux_path.unlink()
    return aux_learn

# Cell
def get_wandb_artifacts(project_path, type=None, name=None, last_version=True):
    """Get the artifacts logged in a wandb project.

    Input:
    - `project_path` (str): entity/project_name
    - `type` (str): whether to return only one type of artifacts
    - `name` (str): Leave none to have all artifact names
    - `last_version`: whether to return only the last version of each artifact or not

    Output: List of artifacts. Collections without any version are skipped.
    """
    public_api = wandb.Api()
    if type is not None:
        types = [public_api.artifact_type(type, project_path)]
    else:
        types = public_api.artifact_types(project_path)
    res = []
    for kind in types:
        for collection in kind.collections():
            if name is not None and name != collection.name:
                continue
            versions = public_api.artifact_versions(
                kind.type,
                "/".join([kind.entity, kind.project, collection.name]),
                # A single-item page suffices when only the first version is wanted.
                # Otherwise larger pages avoid one API round trip per version.
                per_page=1 if last_version else 50,
            )
            if last_version:
                # Assumes the API yields the most recent version first — TODO confirm.
                latest = next(versions, None)
                if latest is not None:
                    res.append(latest)
            else:
                res.extend(versions)
    return res

# Cell
def get_pickle_artifact(filename):
    """Load and return the object pickled in `filename`.

    NOTE(security): unpickling executes arbitrary code; only open trusted files.
    """
    with open(filename, "rb") as f:
        obj = pickle.load(f)
    return obj