# Music-Generator / model.py
# (source header: nullHawk, v0, commit e8ca4ee, 2.2 kB)
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
class MusicLSTM(nn.Module):
    """Token-level recurrent model for music generation.

    Embeds integer token ids, runs them through a stacked LSTM or GRU
    (sequence-first layout, not batch-first), and projects the recurrent
    outputs to logits over the output vocabulary.

    Args:
        input_size: embedding vocabulary size.
        hidden_size: embedding dimension and recurrent hidden size.
        output_size: number of output classes (logit dimension).
        model: 'lstm' or 'gru'; anything else raises NotImplementedError.
        num_layers: number of stacked recurrent layers.
        dropout_p: dropout probability applied to the RNN output.
    """

    def __init__(self, input_size, hidden_size, output_size, model='lstm',
                 num_layers=1, dropout_p=0):
        super(MusicLSTM, self).__init__()
        self.model = model
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.num_layers = num_layers
        self.embeddings = nn.Embedding(input_size, hidden_size)
        if self.model == 'lstm':
            self.rnn = nn.LSTM(hidden_size, hidden_size, num_layers)
        elif self.model == 'gru':
            self.rnn = nn.GRU(hidden_size, hidden_size, num_layers)
        else:
            raise NotImplementedError
        self.out = nn.Linear(self.hidden_size, self.output_size)
        self.drop = nn.Dropout(p=dropout_p)

    def init_hidden(self, batch_size=1):
        """Create zeroed hidden state(s) for a fresh sequence.

        Returns an (h0, c0) tuple for LSTM or a single tensor for GRU,
        and caches it on ``self.hidden``.
        """
        if self.model == 'lstm':
            self.hidden = (
                torch.zeros(self.num_layers, batch_size, self.hidden_size),
                torch.zeros(self.num_layers, batch_size, self.hidden_size),
            )
        elif self.model == 'gru':
            self.hidden = torch.zeros(self.num_layers, batch_size, self.hidden_size)
        return self.hidden

    def _hidden_batch_size(self):
        """Batch dimension of the cached hidden state (dim 1)."""
        h = self.hidden[0] if isinstance(self.hidden, tuple) else self.hidden
        return h.size(1)

    def _detach_hidden(self):
        """Detach the cached hidden state so backprop stops at this step.

        Without this, every forward() call extends the autograd graph of
        all previous calls: a second backward() fails and memory grows
        without bound across batches.
        """
        if isinstance(self.hidden, tuple):
            self.hidden = tuple(h.detach() for h in self.hidden)
        else:
            self.hidden = self.hidden.detach()

    def forward(self, x):
        """Run a sequence of token ids through the network.

        Accepts a 1-D tensor ``(seq_len,)`` (treated as batch size 1) or a
        2-D tensor laid out ``(seq_len, batch)`` to match the seq-first RNN.
        Returns logits of shape ``(seq_len, output_size)`` when the batch
        dimension is 1, else ``(seq_len, batch, output_size)``.
        """
        # Collapse any extra singleton dims down to at most 2-D.
        if x.dim() > 2:
            x = x.squeeze()
        x = x.long()  # nn.Embedding requires integer indices
        embeds = self.embeddings(x)
        # Promote (seq_len, hidden) to (seq_len, 1, hidden) for the RNN.
        if embeds.dim() == 2:
            embeds = embeds.unsqueeze(1)
        # Bug fix: batch is dim 1 of the seq-first RNN input; the old code
        # used x.size(0) (the sequence length) to size the hidden state.
        batch_size = embeds.size(1)
        # Bug fix: the old code built self.hidden once and reused it
        # forever — a later call with a different batch size crashed, and
        # the autograd graph was never truncated. Re-init on mismatch,
        # detach otherwise.
        if not hasattr(self, 'hidden') or self._hidden_batch_size() != batch_size:
            self.init_hidden(batch_size)
        else:
            self._detach_hidden()
        rnn_out, self.hidden = self.rnn(embeds, self.hidden)
        # Squeeze the batch dim when it is 1 (keeps the original 2-D
        # output contract for unbatched input).
        rnn_out = self.drop(rnn_out.squeeze(1))
        return self.out(rnn_out)