XavierJiezou's picture
Upload folder using huggingface_hub
3c8ff2e verified
import torch.nn as nn
import torch.nn.init as init
def weight_init(m, spread=1.0):
    """Initialize the parameters of a single module in place.

    Credits to: https://gist.github.com/jeasinema

    Usage:
        model = Model()
        model.apply(weight_init)

    ``Module.apply`` calls this function with the module only, so ``spread``
    keeps its default there; bind another value with ``functools.partial``
    if needed.

    Initialization scheme by module type:

    * ``Conv1d`` / ``ConvTranspose1d``: weight ~ N(0, spread).
    * ``Conv2d`` / ``Conv3d`` / ``ConvTranspose2d`` / ``ConvTranspose3d`` /
      ``Linear``: weight ~ Xavier-normal with ``gain=spread``.
    * For both groups above, bias ~ N(0, spread) when the layer has a bias.
    * ``BatchNorm1d`` / ``BatchNorm2d`` / ``BatchNorm3d``: weight ~ N(0, spread),
      bias = 0; skipped for layers built with ``affine=False``.
    * ``LSTM`` / ``LSTMCell`` / ``GRU`` / ``GRUCell``: parameters with two or
      more dimensions are made orthogonal, the rest ~ N(0, spread).

    Modules of any other type are left untouched.

    Args:
        m: The module whose parameters should be (re)initialized.
        spread: Standard deviation for the normal initializers and gain for
            the Xavier initializer.

    Returns:
        None. Parameters are modified in place.
    """
    normal_weight_types = (nn.Conv1d, nn.ConvTranspose1d)
    xavier_weight_types = (
        nn.Conv2d,
        nn.Conv3d,
        nn.ConvTranspose2d,
        nn.ConvTranspose3d,
        nn.Linear,
    )
    batch_norm_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)
    recurrent_types = (nn.LSTM, nn.LSTMCell, nn.GRU, nn.GRUCell)

    if isinstance(m, normal_weight_types + xavier_weight_types):
        if isinstance(m, normal_weight_types):
            init.normal_(m.weight, mean=0, std=spread)
        else:
            init.xavier_normal_(m.weight, gain=spread)
        # Layers created with bias=False expose ``bias`` as None.
        if m.bias is not None:
            init.normal_(m.bias, mean=0, std=spread)
    elif isinstance(m, batch_norm_types):
        # With affine=False both ``weight`` and ``bias`` are None; there is
        # nothing to initialize, and dereferencing them would raise.
        # NOTE(review): the scale is drawn around 0 rather than 1; kept as in
        # the original implementation -- confirm this is intended.
        if m.weight is not None:
            init.normal_(m.weight, mean=0, std=spread)
        if m.bias is not None:
            init.constant_(m.bias, 0)
    elif isinstance(m, recurrent_types):
        for param in m.parameters():
            # Weight matrices are 2-D; bias vectors are 1-D.
            if param.dim() >= 2:
                init.orthogonal_(param)
            else:
                init.normal_(param, mean=0, std=spread)