from torch import nn
from transformers import PretrainedConfig, PreTrainedModel, AutoConfig, AutoModel
# from huggingface_hub import notebook_login
# notebook_login()
# AutoEncoderConfig.register_for_auto_class()
# AutoEncoder.register_for_auto_class("AutoModel")
# AutoConfig.register("autoencoder", AutoEncoderConfig)
# AutoModel.register(AutoEncoderConfig, AutoEncoder)
# autoencoder.push_to_hub("autoencoder")
class AutoEncoderConfig(PretrainedConfig):
    """Configuration for the AutoEncoder: layer type, input/latent sizes, depth, dropout and compression rate."""

    model_type = "autoencoder"

    def __init__(
        self,
        input_dim=None,
        latent_dim=None,
        layer_types=None,
        dropout_rate=None,
        num_layers=None,
        compression_rate=None,
        bidirectional=False,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.input_dim = input_dim
        self.latent_dim = latent_dim
        self.layer_types = layer_types
        self.dropout_rate = dropout_rate
        self.num_layers = num_layers
        self.compression_rate = compression_rate
        self.bidirectional = bidirectional

def create_layers(model_section, layer_types, input_dim, latent_dim, num_layers, dropout_rate, compression_rate, bidirectional):
    """Build a stack of layers whose widths shrink from input_dim towards latent_dim (encoder) or grow back (decoder)."""
    layers = []
    current_dim = input_dim
    input_dimensions = []
    output_dimensions = []

    # Compute each layer's width by repeatedly applying the compression rate,
    # never dropping below the latent dimension.
    for _ in range(num_layers):
        input_dimensions.append(current_dim)
        next_dim = max(int(current_dim * compression_rate), latent_dim)
        current_dim = next_dim
        output_dimensions.append(current_dim)

    # Force the last encoder layer to end exactly at the latent dimension.
    output_dimensions[num_layers - 1] = latent_dim

    if model_section == "decoder":
        # The decoder mirrors the encoder: swap the dimension lists and reverse them.
        input_dimensions, output_dimensions = output_dimensions, input_dimensions
        input_dimensions.reverse()
        output_dimensions.reverse()

    if bidirectional and layer_types in ['lstm', 'rnn', 'gru']:
        # Bidirectional recurrent layers concatenate both directions, so double the target widths.
        output_dimensions = [2 * value for value in output_dimensions]

    for idx, (input_dim, output_dim) in enumerate(zip(input_dimensions, output_dimensions)):
        if layer_types == 'linear':
            layers.append(nn.Linear(input_dim, output_dim))
        elif layer_types == 'lstm':
            layers.append(nn.LSTM(input_dim, output_dim // (2 if bidirectional else 1), batch_first=True, bidirectional=bidirectional))
        elif layer_types == 'rnn':
            layers.append(nn.RNN(input_dim, output_dim // (2 if bidirectional else 1), batch_first=True, bidirectional=bidirectional))
        elif layer_types == 'gru':
            layers.append(nn.GRU(input_dim, output_dim // (2 if bidirectional else 1), batch_first=True, bidirectional=bidirectional))
        if idx != num_layers - 1 and dropout_rate is not None:
            layers.append(nn.Dropout(dropout_rate))
    return nn.Sequential(*layers)
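
# Illustration of the schedule create_layers produces (example values, not taken from
# the original file): with input_dim=128, latent_dim=16, num_layers=3 and
# compression_rate=0.5, the encoder is built as 128 -> 64 -> 32 -> 16 and the decoder
# mirrors it as 16 -> 32 -> 64 -> 128.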

class AutoEncoder(PreTrainedModel):
    config_class = AutoEncoderConfig

    def __init__(self, config):
        super().__init__(config)
        self.encoder = create_layers(
            "encoder",
            config.layer_types, config.input_dim, config.latent_dim,
            config.num_layers, config.dropout_rate, config.compression_rate,
            config.bidirectional,
        )
        # Assuming symmetry between encoder and decoder.
        self.decoder = create_layers(
            "decoder",
            config.layer_types, config.input_dim, config.latent_dim,
            config.num_layers, config.dropout_rate, config.compression_rate,
            config.bidirectional,
        )

    def forward(self, x):
        # Recurrent layers return (output, hidden_state) tuples, so they cannot be run
        # through nn.Sequential directly and are unrolled layer by layer instead.
        if self.config.layer_types in ['lstm', 'rnn', 'gru']:
            for layer in self.encoder:
                if isinstance(layer, nn.LSTM):
                    x, (h_n, c_n) = layer(x)
                elif isinstance(layer, (nn.RNN, nn.GRU)):
                    x, h_n = layer(x)
                else:
                    x = layer(x)
            for layer in self.decoder:
                if isinstance(layer, nn.LSTM):
                    x, (h_n, c_n) = layer(x)
                elif isinstance(layer, (nn.RNN, nn.GRU)):
                    x, h_n = layer(x)
                else:
                    x = layer(x)
        else:
            x = self.encoder(x)
            x = self.decoder(x)
        return x
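

# Minimal usage sketch (illustrative values, not part of the original file): builds a
# small linear autoencoder from the config above and runs a single forward pass.
if __name__ == "__main__":
    import torch

    config = AutoEncoderConfig(
        input_dim=128,
        latent_dim=16,
        layer_types="linear",
        dropout_rate=0.1,
        num_layers=3,
        compression_rate=0.5,
    )
    autoencoder = AutoEncoder(config)

    batch = torch.randn(4, 128)          # (batch_size, input_dim)
    reconstruction = autoencoder(batch)  # same shape as the input
    print(reconstruction.shape)          # torch.Size([4, 128])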