Spaces:
Build error
Build error
import argparse | |
from typing import Any | |
import tensorflow as tf | |
class EasyDict(dict): | |
def __getattr__(self, name: str) -> Any: | |
try: | |
return self[name] | |
except KeyError: | |
raise AttributeError(name) | |
def __setattr__(self, name: str, value: Any) -> None: | |
self[name] = value | |
def __delattr__(self, name: str) -> None: | |
del self[name] | |
def params_args(args): | |
parser = argparse.ArgumentParser() | |
parser.add_argument( | |
"--hop", | |
type=int, | |
default=256, | |
help="Hop size (window size = 4*hop)", | |
) | |
parser.add_argument( | |
"--mel_bins", | |
type=int, | |
default=256, | |
help="Mel bins in mel-spectrograms", | |
) | |
parser.add_argument( | |
"--sr", | |
type=int, | |
default=22050, | |
help="Sampling Rate", | |
) | |
parser.add_argument( | |
"--latlen", | |
type=int, | |
default=256, | |
help="Length of generated latent vectors", | |
) | |
parser.add_argument( | |
"--latdepth", | |
type=int, | |
default=64, | |
help="Depth of generated latent vectors", | |
) | |
parser.add_argument( | |
"--shape", | |
type=int, | |
default=128, | |
help="Length of spectrograms time axis", | |
) | |
parser.add_argument( | |
"--window", | |
type=int, | |
default=64, | |
help="Generator spectrogram window (must divide shape)", | |
) | |
parser.add_argument( | |
"--mu_rescale", | |
type=int, | |
default=-25.0, | |
help="Spectrogram mu used to normalize", | |
) | |
parser.add_argument( | |
"--sigma_rescale", | |
type=int, | |
default=75.0, | |
help="Spectrogram sigma used to normalize", | |
) | |
parser.add_argument( | |
"--load_path_techno", | |
type=str, | |
default="checkpoints/techno/", | |
help="Path of pretrained networks weights (techno)", | |
) | |
parser.add_argument( | |
"--load_path_classical", | |
type=str, | |
default="checkpoints/classical/", | |
help="Path of pretrained networks weights (classical)", | |
) | |
parser.add_argument( | |
"--dec_path_techno", | |
type=str, | |
default="checkpoints/techno/", | |
help="Path of pretrained decoders weights (techno)", | |
) | |
parser.add_argument( | |
"--dec_path_classical", | |
type=str, | |
default="checkpoints/classical/", | |
help="Path of pretrained decoders weights (classical)", | |
) | |
parser.add_argument( | |
"--testing", | |
type=bool, | |
default=True, | |
help="True if optimizers weight do not need to be loaded", | |
) | |
parser.add_argument( | |
"--cpu", | |
type=bool, | |
default=False, | |
help="True if you wish to use cpu", | |
) | |
parser.add_argument( | |
"--mixed_precision", | |
type=bool, | |
default=True, | |
help="True if your GPU supports mixed precision", | |
) | |
tmp_args = parser.parse_args() | |
args.hop = tmp_args.hop | |
args.mel_bins = tmp_args.mel_bins | |
args.sr = tmp_args.sr | |
args.latlen = tmp_args.latlen | |
args.latdepth = tmp_args.latdepth | |
args.shape = tmp_args.shape | |
args.window = tmp_args.window | |
args.mu_rescale = tmp_args.mu_rescale | |
args.sigma_rescale = tmp_args.sigma_rescale | |
args.load_path_techno = tmp_args.load_path_techno | |
args.load_path_classical = tmp_args.load_path_classical | |
args.dec_path_techno = tmp_args.dec_path_techno | |
args.dec_path_classical = tmp_args.dec_path_classical | |
args.testing = tmp_args.testing | |
args.cpu = tmp_args.cpu | |
args.mixed_precision = tmp_args.mixed_precision | |
print() | |
args.datatype = tf.float32 | |
gpuls = tf.config.list_physical_devices("GPU") | |
if len(gpuls) == 0 or args.cpu: | |
args.cpu = True | |
args.mixed_precision = False | |
tf.config.set_visible_devices([], "GPU") | |
print() | |
print("Using CPU...") | |
print() | |
if args.mixed_precision: | |
args.datatype = tf.float16 | |
print() | |
print("Using GPU with mixed precision enabled...") | |
print() | |
if not args.mixed_precision and not args.cpu: | |
print() | |
print("Using GPU without mixed precision...") | |
print() | |
return args | |
def parse_args(): | |
args = EasyDict() | |
return params_args(args) | |