Spaces:
Sleeping
Sleeping
File size: 4,835 Bytes
7399708 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 |
# AUTOGENERATED! DO NOT EDIT! File to edit: nbs/utils.ipynb (unless otherwise specified).
__all__ = ['generate_TS_df', 'normalize_columns', 'remove_constant_columns', 'ReferenceArtifact', 'PrintLayer',
'get_wandb_artifacts', 'get_pickle_artifact']
# Cell
from .imports import *
from fastcore.all import *
import wandb
import hashlib
import pickle
import pandas as pd
import numpy as np
#import tensorflow as tf
import torch.nn as nn
from fastai.basics import *
# Cell
def generate_TS_df(rows, cols):
    """Generate a DataFrame containing a multivariate time series, where each
    column represents a variable and each row a time point (sample). The
    timestamp is in the index of the dataframe, created with an even spacing
    of 1 second between samples.

    Input:
        - `rows` (int): number of time points (samples)
        - `cols` (int): number of variables (columns)
    Output: pd.DataFrame of shape (rows, cols) with a DatetimeIndex and
        standard-normal random values.
    """
    # date_range with `periods` yields exactly `rows` timestamps. The previous
    # np.arange(start, stop) construction had an exclusive stop and only
    # produced `rows` points by accident, relying on the sub-second skew
    # between two separate pd.Timestamp.now() calls.
    index = pd.date_range(start=pd.Timestamp.now(), periods=rows, freq='s')
    data = np.random.randn(rows, cols)
    return pd.DataFrame(data, index=index)
# Cell
def normalize_columns(df:pd.DataFrame):
    "Scale every column of `df` to zero mean and unit standard deviation"
    # Small epsilon keeps the division finite for constant columns (std == 0).
    eps = 1e-7
    centered = df.sub(df.mean())
    return centered.div(df.std() + eps)
# Cell
def remove_constant_columns(df:pd.DataFrame):
    "Drop every column of `df` whose values are all identical, returning the filtered frame"
    # An empty frame has no first row to compare against (df.iloc[0] would
    # raise IndexError); there is nothing to drop, so return it unchanged.
    if len(df) == 0:
        return df
    # Keep only columns where at least one value differs from the first row.
    return df.loc[:, (df != df.iloc[0]).any()]
# Cell
class ReferenceArtifact(wandb.Artifact):
    """Artifact holding a single `file://` reference to a pickled object.

    The object passed to the constructor is pickled, content-hashed and
    stored in `folder` (defaults to `default_storage_path` under
    `Path.home()`); the artifact then references that file and records the
    hash and the object's type in its metadata so `to_obj` can restore it.
    """
    default_storage_path = Path('data/wandb_artifacts/') # * this path is relative to Path.home()

    @delegates(wandb.Artifact.__init__)
    def __init__(self, obj, name, type='object', folder=None, **kwargs):
        super().__init__(type=type, name=name, **kwargs)
        # Content-based filename: sha1 of the pickle is stable across runs.
        # The builtin hash() previously used here is salted per process
        # (PYTHONHASHSEED), so the same object got a different file name on
        # every run, defeating the content-addressing intent.
        payload = pickle.dumps(obj)
        hash_code = hashlib.sha1(payload).hexdigest()
        folder = Path(ifnone(folder, Path.home()/self.default_storage_path))
        folder.mkdir(parents=True, exist_ok=True)  # ensure the storage dir exists
        with open(f'{folder}/{hash_code}', 'wb') as f:
            f.write(payload)  # reuse the bytes already pickled above
        self.add_reference(f'file://{folder}/{hash_code}')
        if self.metadata is None:
            self.metadata = dict()
        self.metadata['ref'] = dict()
        self.metadata['ref']['hash'] = hash_code
        self.metadata['ref']['type'] = str(obj.__class__)
# Cell
@patch
def to_obj(self:wandb.apis.public.Artifact):
    """Download the files of a saved ReferenceArtifact and get the referenced object. The artifact must \
    come from a call to `run.use_artifact` with a proper wandb run."""
    if self.metadata.get('ref') is None:
        print(f'ERROR:{self} does not come from a saved ReferenceArtifact')
        return None
    # Match the storage location used by ReferenceArtifact.__init__:
    # default_storage_path is relative to Path.home(), so resolve it against
    # home rather than the CWD (the previous lookup only found the local copy
    # when the process happened to run from the home directory).
    original_path = Path.home()/ReferenceArtifact.default_storage_path/self.metadata['ref']['hash']
    # Fall back to downloading the artifact's files when no local copy exists.
    path = original_path if original_path.exists() else Path(self.download()).ls()[0]
    # NOTE(review): pickle.load executes arbitrary code — only restore
    # artifacts from trusted sources.
    with open(path, 'rb') as f:
        obj = pickle.load(f)
    return obj
# Cell
import torch.nn as nn
class PrintLayer(nn.Module):
    """Identity module that prints the shape of its input.

    Useful for debugging tensor sizes when inserted into an nn.Sequential;
    the input tensor passes through completely untouched.
    """
    def __init__(self):
        super().__init__()

    def forward(self, x):
        print(x.shape)  # debug output only
        return x
# Cell
@patch
def export_and_get(self:Learner, keep_exported_file=False):
    """
    Export the learner into an auxiliary file, load it and return it back.

    Input:
        - `keep_exported_file` (bool): if False (default), delete the
          auxiliary 'aux.pkl' file after reloading.
    Output: the re-loaded Learner.
    """
    aux_path = Path('aux.pkl')
    self.export(fname=aux_path.name)
    try:
        aux_learn = load_learner(aux_path.name)
    finally:
        # Clean up even when load_learner raises, so no stale export is left
        # behind; the exists() guard avoids a secondary FileNotFoundError.
        if not keep_exported_file and aux_path.exists():
            aux_path.unlink()
    return aux_learn
# Cell
def get_wandb_artifacts(project_path, type=None, name=None, last_version=True):
    """
    Get the artifacts logged in a wandb project.
    Input:
        - `project_path` (str): entity/project_name
        - `type` (str): whether to return only one type of artifacts
        - `name` (str): Leave none to have all artifact names
        - `last_version`: whether to return only the last version of each artifact or not
    Output: List of artifacts
    """
    public_api = wandb.Api()
    if type is not None:
        types = [public_api.artifact_type(type, project_path)]
    else:
        types = public_api.artifact_types(project_path)
    res = L()
    for kind in types:
        for collection in kind.collections():
            if name is None or name == collection.name:
                versions = public_api.artifact_versions(
                    kind.type,
                    "/".join([kind.entity, kind.project, collection.name]),
                    per_page=1,
                )
                if last_version:
                    # A collection can have no versions; next(versions) would
                    # then escape as a raw StopIteration — skip it instead.
                    first = next(versions, None)
                    if first is not None: res += first
                else:
                    res += L(versions)
    return list(res)
# Cell
def get_pickle_artifact(filename):
    "Load and return the pickled object stored in `filename`"
    # NOTE: pickle can execute arbitrary code — only open trusted files.
    with open(filename, "rb") as fh:
        return pickle.load(fh)