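# Convert an old-format checkpoint whose WN modules still hold per-layer
# res_layers / skip_layers / cond_layers into the newer fused layout
# (res_skip_layers and a single cond_layer), preserving the learned weights.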
import sys
import copy
import torch
def _check_model_old_version(model):
    if hasattr(model.WN[0], 'res_layers') or hasattr(model.WN[0], 'cond_layers'):
        return True
    else:
        return False
def _update_model_res_skip(old_model, new_model):
    # Fuse each WN block's separate res_layers / skip_layers 1x1 convolutions
    # into a single res_skip_layers convolution used by the newer layout.
    for idx in range(0, len(new_model.WN)):
        wavenet = new_model.WN[idx]
        n_channels = wavenet.n_channels
        n_layers = wavenet.n_layers
        wavenet.res_skip_layers = torch.nn.ModuleList()
        for i in range(0, n_layers):
            # The last layer has no residual connection, so it only needs skip channels.
            if i < n_layers - 1:
                res_skip_channels = 2 * n_channels
            else:
                res_skip_channels = n_channels
            res_skip_layer = torch.nn.Conv1d(n_channels, res_skip_channels, 1)
            # Weight norm must be removed before the raw weights can be concatenated.
            skip_layer = torch.nn.utils.remove_weight_norm(wavenet.skip_layers[i])
            if i < n_layers - 1:
                res_layer = torch.nn.utils.remove_weight_norm(wavenet.res_layers[i])
                res_skip_layer.weight = torch.nn.Parameter(torch.cat([res_layer.weight, skip_layer.weight]))
                res_skip_layer.bias = torch.nn.Parameter(torch.cat([res_layer.bias, skip_layer.bias]))
            else:
                res_skip_layer.weight = torch.nn.Parameter(skip_layer.weight)
                res_skip_layer.bias = torch.nn.Parameter(skip_layer.bias)
            res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name='weight')
            wavenet.res_skip_layers.append(res_skip_layer)
        del wavenet.res_layers
        del wavenet.skip_layers
def _update_model_cond(old_model, new_model):
    # Merge each WN block's per-layer conditioning convolutions into one
    # cond_layer that projects the conditioning input for all layers at once.
    for idx in range(0, len(new_model.WN)):
        wavenet = new_model.WN[idx]
        n_channels = wavenet.n_channels
        n_layers = wavenet.n_layers
        n_mel_channels = wavenet.cond_layers[0].weight.shape[1]
        cond_layer = torch.nn.Conv1d(n_mel_channels, 2 * n_channels * n_layers, 1)
        cond_layer_weight = []
        cond_layer_bias = []
        for i in range(0, n_layers):
            _cond_layer = torch.nn.utils.remove_weight_norm(wavenet.cond_layers[i])
            cond_layer_weight.append(_cond_layer.weight)
            cond_layer_bias.append(_cond_layer.bias)
        cond_layer.weight = torch.nn.Parameter(torch.cat(cond_layer_weight))
        cond_layer.bias = torch.nn.Parameter(torch.cat(cond_layer_bias))
        cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight')
        wavenet.cond_layer = cond_layer
        del wavenet.cond_layers
def update_model(old_model):
    # Return the model unchanged if it already uses the fused layout.
    if not _check_model_old_version(old_model):
        return old_model
    new_model = copy.deepcopy(old_model)
    if hasattr(old_model.WN[0], 'res_layers'):
        _update_model_res_skip(old_model, new_model)
    if hasattr(old_model.WN[0], 'cond_layers'):
        _update_model_cond(old_model, new_model)
    # Conv modules pickled by older PyTorch lack the padding_mode attribute
    # that newer PyTorch versions expect; add it so the model runs after loading.
    for m in new_model.modules():
        if 'Conv' in str(type(m)) and not hasattr(m, 'padding_mode'):
            setattr(m, 'padding_mode', 'zeros')
    return new_model
if __name__ == '__main__':
    old_model_path = sys.argv[1]
    new_model_path = sys.argv[2]
    model = torch.load(old_model_path, map_location='cpu')
    model['model'] = update_model(model['model'])
    torch.save(model, new_model_path)
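# Usage sketch (the script filename is hypothetical; the checkpoint is assumed
# to be a dict with the full model under the 'model' key, as the __main__
# block above expects):
#   python convert_model.py old_checkpoint.pt new_checkpoint.pt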