import torch
import torch.nn as nn

class MusicLSTM(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, model='lstm', num_layers=1, dropout_p=0):
        super(MusicLSTM, self).__init__()
        self.model = model
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.num_layers = num_layers

        self.embeddings = nn.Embedding(input_size, hidden_size)
        if self.model == 'lstm':
            self.rnn = nn.LSTM(hidden_size, hidden_size, num_layers)
        elif self.model == 'gru':
            self.rnn = nn.GRU(hidden_size, hidden_size, num_layers)
        else:
            raise NotImplementedError(f"Unsupported model type: {model!r}")

        self.out = nn.Linear(self.hidden_size, self.output_size)
        self.drop = nn.Dropout(p=dropout_p)

    def init_hidden(self, batch_size=1):
        """Initialize hidden (and, for LSTM, cell) states on the model's device."""
        device = next(self.parameters()).device
        if self.model == 'lstm':
            self.hidden = (
                torch.zeros(self.num_layers, batch_size, self.hidden_size, device=device),
                torch.zeros(self.num_layers, batch_size, self.hidden_size, device=device)
            )
        elif self.model == 'gru':
            self.hidden = torch.zeros(self.num_layers, batch_size, self.hidden_size, device=device)

        return self.hidden

    def forward(self, x):
        """Forward pass."""
        # Ensure x is 2D (sequence length, batch size)
        if x.dim() > 2:
            x = x.squeeze()

        # For a 2D (seq_len, batch_size) input, the batch dimension is dim 1
        batch_size = 1 if x.dim() == 1 else x.size(1)
        x = x.long()

        # Embed the input
        embeds = self.embeddings(x)

        # Initialize the hidden state if it is missing or sized for a different batch
        hidden = getattr(self, 'hidden', None)
        if hidden is None or (hidden[0] if isinstance(hidden, tuple) else hidden).size(1) != batch_size:
            self.init_hidden(batch_size)

        # Ensure embeds is 3D for RNN input (sequence length, batch size, embedding size)
        if embeds.dim() == 2:
            embeds = embeds.unsqueeze(1)

        # RNN processing
        rnn_out, self.hidden = self.rnn(embeds, self.hidden)

        # Dropout and output layer
        rnn_out = self.drop(rnn_out.squeeze(1))
        output = self.out(rnn_out)

        return output
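

# Minimal usage sketch (illustrative only): the vocabulary size, hidden size,
# and the random input sequence below are assumed values, not part of the
# original model definition.
if __name__ == "__main__":
    vocab_size = 128  # assumed token vocabulary, e.g. a MIDI pitch range
    model = MusicLSTM(input_size=vocab_size, hidden_size=64, output_size=vocab_size,
                      model='lstm', num_layers=2, dropout_p=0.1)
    model.init_hidden(batch_size=1)

    # A dummy sequence of 16 token indices (shape: seq_len)
    seq = torch.randint(0, vocab_size, (16,))
    logits = model(seq)  # shape: (seq_len, vocab_size) for a single sequence
    print(logits.shape)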