"""PyTorch LSTM/GRU sequence model for music generation."""
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
class MusicLSTM(nn.Module):
    """Token-level recurrent model: embedding -> LSTM/GRU -> linear head.

    Each input token id is embedded into a ``hidden_size`` vector, run
    through a (multi-layer) LSTM or GRU, then projected to
    ``output_size`` logits per time step.

    Args:
        input_size: vocabulary size (number of distinct token ids).
        hidden_size: embedding and recurrent hidden dimension.
        output_size: number of output classes per time step.
        model: recurrent cell type, ``'lstm'`` or ``'gru'``.
        num_layers: number of stacked recurrent layers.
        dropout_p: dropout probability applied to the RNN output.

    Raises:
        NotImplementedError: if ``model`` is not ``'lstm'`` or ``'gru'``.
    """

    def __init__(self, input_size, hidden_size, output_size, model='lstm',
                 num_layers=1, dropout_p=0):
        super(MusicLSTM, self).__init__()
        self.model = model
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.num_layers = num_layers
        self.embeddings = nn.Embedding(input_size, hidden_size)
        if self.model == 'lstm':
            self.rnn = nn.LSTM(hidden_size, hidden_size, num_layers)
        elif self.model == 'gru':
            self.rnn = nn.GRU(hidden_size, hidden_size, num_layers)
        else:
            raise NotImplementedError(
                "model must be 'lstm' or 'gru', got {!r}".format(model))
        self.out = nn.Linear(self.hidden_size, self.output_size)
        self.drop = nn.Dropout(p=dropout_p)

    def init_hidden(self, batch_size=1):
        """Create and store fresh zero hidden state(s).

        Returns an ``(h0, c0)`` tuple for LSTM or a single ``h0`` tensor
        for GRU, each of shape ``(num_layers, batch_size, hidden_size)``.
        """
        # Allocate on the same device/dtype as the model parameters so the
        # module keeps working after .to(device) / .half() etc.
        ref = next(self.parameters())
        shape = (self.num_layers, batch_size, self.hidden_size)
        if self.model == 'lstm':
            self.hidden = (ref.new_zeros(shape), ref.new_zeros(shape))
        else:
            # Constructor guarantees the only other option is 'gru'.
            self.hidden = ref.new_zeros(shape)
        return self.hidden

    def forward(self, x):
        """Run the model over a sequence of token ids.

        ``x`` is a 1-D tensor of token ids (treated as a single sequence,
        batch size 1) or a 2-D tensor. Returns per-step logits; for 1-D
        input the shape is ``(seq_len, output_size)``.
        """
        # Collapse any singleton extra dimensions down to at most 2-D.
        if x.dim() > 2:
            x = x.squeeze()
        # NOTE(review): for 2-D input, dim 0 is interpreted as the batch
        # here, but the RNN below uses batch_first=False (dim 1 is the
        # batch) — confirm intended 2-D layout with callers.
        batch_size = 1 if x.dim() == 1 else x.size(0)
        x = x.long()  # nn.Embedding requires integer indices
        embeds = self.embeddings(x)
        # Lazily create the hidden state on the first forward call; it is
        # then carried across calls until init_hidden() is invoked again.
        if not hasattr(self, 'hidden'):
            self.init_hidden(batch_size)
        # RNN expects 3-D input (seq_len, batch, embed); add batch dim 1.
        if embeds.dim() == 2:
            embeds = embeds.unsqueeze(1)
        rnn_out, self.hidden = self.rnn(embeds, self.hidden)
        # Drop the singleton batch dim, apply dropout, project to logits.
        rnn_out = self.drop(rnn_out.squeeze(1))
        output = self.out(rnn_out)
        return output