import torch
import torch.nn as nn
class MusicLSTM(nn.Module):
    """Embedding -> LSTM/GRU -> linear model for music token sequences."""

    def __init__(self, input_size, hidden_size, output_size, model='lstm', num_layers=1, dropout_p=0):
        super(MusicLSTM, self).__init__()
        self.model = model
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.num_layers = num_layers
        self.embeddings = nn.Embedding(input_size, hidden_size)
        if self.model == 'lstm':
            self.rnn = nn.LSTM(hidden_size, hidden_size, num_layers)
        elif self.model == 'gru':
            self.rnn = nn.GRU(hidden_size, hidden_size, num_layers)
        else:
            raise NotImplementedError(f"Unsupported model type: {model!r}")
        self.out = nn.Linear(self.hidden_size, self.output_size)
        self.drop = nn.Dropout(p=dropout_p)
    def init_hidden(self, batch_size=1):
        """Initialize hidden states, shaped (num_layers, batch_size, hidden_size)."""
        if self.model == 'lstm':
            # An LSTM carries both a hidden state and a cell state
            self.hidden = (
                torch.zeros(self.num_layers, batch_size, self.hidden_size),
                torch.zeros(self.num_layers, batch_size, self.hidden_size)
            )
        elif self.model == 'gru':
            self.hidden = torch.zeros(self.num_layers, batch_size, self.hidden_size)
        return self.hidden
    def forward(self, x):
        """Forward pass."""
        # Ensure x is at most 2D (sequence length, batch size)
        if x.dim() > 2:
            x = x.squeeze()
        # With the default batch_first=False layout, the batch dimension is dim 1
        batch_size = 1 if x.dim() == 1 else x.size(1)
        # Embedding indices must be integer typed
        x = x.long()
        # Embed the input tokens
        embeds = self.embeddings(x)
        # Initialize the hidden state if it has not been set up yet
        # (callers can also reset it explicitly via init_hidden)
        if not hasattr(self, 'hidden'):
            self.init_hidden(batch_size)
        # Ensure embeds is 3D for RNN input (sequence length, batch size, embedding size)
        if embeds.dim() == 2:
            embeds = embeds.unsqueeze(1)
        # RNN processing; the updated hidden state is kept for the next call
        rnn_out, self.hidden = self.rnn(embeds, self.hidden)
        # Dropout, then project to the output vocabulary
        rnn_out = self.drop(rnn_out.squeeze(1))
        output = self.out(rnn_out)
        return output
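
# A minimal usage sketch (not from the original source): instantiates the model
# with assumed sizes and runs one dummy note sequence through it. The vocabulary
# size of 128 (MIDI-style pitches) and all other numbers here are hypothetical.
if __name__ == "__main__":
    vocab_size = 128  # assumed token vocabulary, e.g. MIDI note numbers 0-127
    model = MusicLSTM(input_size=vocab_size, hidden_size=64,
                      output_size=vocab_size, model='lstm',
                      num_layers=2, dropout_p=0.1)
    model.eval()  # disable dropout for this smoke test
    model.init_hidden(batch_size=1)
    seq = torch.randint(0, vocab_size, (32,))  # one 32-step token sequence
    with torch.no_grad():
        logits = model(seq)  # -> (32, vocab_size), one score row per step
    probs = torch.softmax(logits, dim=-1)  # next-token distribution per step
    print(logits.shape, probs.shape)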