nullHawk committed
Commit e8ca4ee · verified · 1 parent: 9d0c993
.gitignore ADDED
@@ -0,0 +1,4 @@
+ __pycache__
+ output/*
+ temp.ipynb
+ output*
Loss_per_epoch.png ADDED
app.py ADDED
@@ -0,0 +1,46 @@
+ import torch
+ import gradio as gr
+ from model import MusicLSTM
+ from train import DataLoader, Config, generate_song as generate_ABC_notation
+ from utils import load_vocab
+ from convert import abc_to_audio
+
+ class GradioApp():
+     def __init__(self):
+         # Set up configuration and data
+         self.config = Config()
+         self.CHECKPOINT_FILE = "checkpoint/model.pth"
+         self.data_loader = DataLoader(self.config.INPUT_FILE, self.config)
+         self.checkpoint = torch.load(self.CHECKPOINT_FILE, weights_only=False)
+         char_idx, _ = load_vocab()
+         self.model = MusicLSTM(
+             input_size=len(char_idx),
+             hidden_size=self.config.HIDDEN_SIZE,
+             output_size=len(char_idx),
+         )
+         self.model.load_state_dict(self.checkpoint)
+         self.model.eval()
+
+         # Set up the interface
+         self.input = gr.Button("Generate")
+         self.output = gr.Audio(label="Generated Music")
+         self.interface = gr.Interface(
+             fn=self.generate_music,
+             inputs=self.input,
+             outputs=self.output,
+             title="AI Music Generator",
+             description="Generate a new song using a trained RNN model.",
+         )
+
+     def launch(self):
+         self.interface.launch()
+
+     def generate_music(self, _btn):
+         """Generate a new song using the trained model."""
+         abc_notation = generate_ABC_notation(self.model, self.data_loader)
+         # str.strip() treats its argument as a character set, so strip("<start>")
+         # would also eat leading note letters; remove the exact markers instead.
+         abc_notation = abc_notation.removeprefix('<start>').removesuffix('<end>').strip()
+         audio = abc_to_audio(abc_notation)
+         return audio
+
+ if __name__ == '__main__':
+     app = GradioApp()
+     app.launch()
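Usage note: assuming checkpoint/model.pth and config/vocab.json exist (both are added in this commit), the demo can be launched locally with python app.py; each button press then samples fresh ABC notation and renders it to WAV.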
app.txt ADDED
@@ -0,0 +1,3 @@
+ libfluidsynth-dev
+ libsndfile1
+ abc2midi
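These appear to be Debian system packages for the audio toolchain: the FluidSynth development headers, libsndfile, and the abc2midi converter. A hedged install sketch for a Debian/Ubuntu host (the package names are an assumption; on Debian-based systems the abc2midi binary is typically shipped in the abcmidi package):

    sudo apt-get install libfluidsynth-dev libsndfile1 abcmidi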
checkpoint/ckpt_mdl_lstm_ep_100_hsize_150_dout_0.t0 ADDED
Binary file (844 kB)
checkpoint/ckpt_mdl_lstm_ep_100_hsize_150_dout_0.t20 ADDED
Binary file (844 kB)
checkpoint/ckpt_mdl_lstm_ep_100_hsize_150_dout_0.t40 ADDED
Binary file (844 kB)
checkpoint/ckpt_mdl_lstm_ep_100_hsize_150_dout_0.t60 ADDED
Binary file (844 kB)
checkpoint/ckpt_mdl_lstm_ep_100_hsize_150_dout_0.t80 ADDED
Binary file (844 kB)
checkpoint/ckpt_mdl_lstm_ep_100_hsize_150_dout_0.t99 ADDED
Binary file (844 kB)
checkpoint/model.pth ADDED
Binary file (839 kB)
config/vocab.json ADDED
@@ -0,0 +1 @@
+ {"char_idx": "<]=_4Xl)uq5CBw#d(~H}scntZ!hIF6p'\\E/g&?fTW{^-v9MA710+oK\tJS[\n,Q \"G2a:L|mxVNbPRk*jYyD3e.8Oi>Uzr@", "char_list": ["<", "]", "=", "_", "4", "X", "l", ")", "u", "q", "5", "C", "B", "w", "#", "d", "(", "~", "H", "}", "s", "c", "n", "t", "Z", "!", "h", "I", "F", "6", "p", "'", "\\", "E", "/", "g", "&", "?", "f", "T", "W", "{", "^", "-", "v", "9", "M", "A", "7", "1", "0", "+", "o", "K", "\t", "J", "S", "[", "\n", ",", "Q", " ", "\"", "G", "2", "a", ":", "L", "|", "m", "x", "V", "N", "b", "P", "R", "k", "*", "j", "Y", "y", "D", "3", "e", ".", "8", "O", "i", ">", "U", "z", "r", "@"]}
convert.py ADDED
@@ -0,0 +1,33 @@
+ import subprocess
+ from midi2audio import FluidSynth
+
+ def abc_to_audio(abc_notation, output_format='wav', sound_font="FluidR3_GM.sf2"):
+     """Convert ABC notation to an audio file via abc2midi and FluidSynth."""
+     abc_file = 'output.abc'
+     with open(abc_file, 'w') as f:
+         f.write(abc_notation)
+     subprocess.run(['abc2midi', abc_file, '-o', 'output.midi'], check=True)
+     output_file = f'output.{output_format}'
+     fs = FluidSynth(sound_font)  # pass the soundfont; it was previously accepted but unused
+     fs.midi_to_audio('output.midi', output_file)
+     return output_file
+
+
+ if __name__ == '__main__':
+     abc_to_audio("""X:12
+ T:Byrne: Triop
+ C:Trad Figne
+ Z:id:hn-hornpipe-53
+ M:C|
+ K:G
+ (3DFB d2dc | def2 edef | e2a2 df | g4- gdBG | A4G | A4 :|
+ |: ae edc | edcB A2B2 | A2G2 | G6 d2 | e4^c4 | d4 d4 | ed e2 | d4 ||
+ P:variations:
+ |: ABA AGE|F2A d2A|d2g d2:|
+ a2f fef aba|a2f g2e fed|c2A GBd|f2g g2a|bgb aag|dcB B2G|A2G A2G:|
+ |:F2A A2G|AGE G2d||
+ P:variations
+ |: AGF GBd | cde d2B | c2c c2A :|
+ |: de fe | fdfe dFAd | A2AG A2f2 | g2ag e2B2 | A2AB ^cdce | d2d>c | B4z2 | B4 | A4G2 | ^F4G4 | G4 :|
+ |: G^F G2 | c4 ||
+ GBdB | c2 ded2 | c2B2c2 | d2c2B2 | c2d2 | c2B2 | A4 :|""")
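Note that abc_to_audio shells out to the external abc2midi binary and then synthesizes audio with FluidSynth, so it assumes abc2midi is on PATH (see app.txt above) and that a General MIDI soundfont such as FluidR3_GM.sf2 is available where FluidSynth can find it. The __main__ block above doubles as a usage example with a sample hornpipe.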
data/.DS_Store ADDED
Binary file (30.7 kB)
data/music.txt ADDED
The diff for this file is too large to render.
data/pop.txt ADDED
The diff for this file is too large to render.
data/sample-music.txt ADDED
@@ -0,0 +1,25 @@
+ X:3
+ T:Badine
+ O:France
+ A:Provence
+ Z:Transcrit et/ou corrigé par Michel BELLON - 2005-04-01
+ Z:Pour toute observation mailto:[email protected]
+ M:C|
+ L:1/8
+ Q: "Allegro"
+ K:Bb
+ V:1 name=G
+ d2 (cB) | d2 (cB) f2 ed | f4 g2 g2 | feed eddc | d2B2 d2cB | d2cB f2ed |
+ f4 g2g2 | fedc d2!+!c2 | B4 :: FBcB | B2AB cd ec | d2 B2 df dB |
+ cf cA Bd cB | B2 A2 f2 f2 | (ABcd) e2 d2 | e2 dc dcde | fefg fefg |
+ Te4 d2cB | d2cB f2ed | f4 g2g2 | feed eddc | d2B2 d2cB | d2cB f2ed |
+ f4 g2g2 | fedc d2!+!c2 | B4 !fine! :: [K:Bbm] c2de | d2ef B2cd |
+ c2F2 dc Bc | !+!=A2 B2 c2 d2 | d2 c2 d2ef | g2g2 c2de | f2f2 B=ABc | F2Bc d2c2 |
+ B4:| fefg | f2e=d e2f2 | {f2}g4 edef | e2dc d2e2 | {e2}f4 Bc=Ac |
+ B2F2 dece | d2c2 dcde | f6 ed |c4 !D.C.! |]
+ V:2 name=V
+ z4 | z4 d2cB | d2cB e4 | B4 f2f2 | b2B2 z4 | z4 d2cB | d2cB e4 | d2B2 f2F2 | B4 ::
+ z4 | f4f2f2 | B2b2 b2b2 | a2f2 g2=e2 | f2F2 z4 | f2f2 A2B2 | c2f2 B2B2 | B2b2 b2b2 |
+ a4 b2B2 | z4 d2cB | d2cB e4 | B4 f2f2 | b2B2 z4 | z4 d2cB | d2cB e4 | d2B2 f2F2 | B4 ::
+ [K:Bbm] b2b2 | b4 g2e2 | f2F2 B2B2 | e2d2 c2B2 | f4 b2b2 | e2fg a2a2 | d4 g2g2 | fede f2F2 | B4 :|
+ b2b2 | b3a g2f2 | e4 a2a2 | a3g f2e2 | d4 z4 | z4 (bc')(=ac') | b2f2 B2Bc | dcde d2e2 | f4 |]
model.py ADDED
@@ -0,0 +1,66 @@
+ import torch
+ import torch.nn as nn
+
+ class MusicLSTM(nn.Module):
+     def __init__(self, input_size, hidden_size, output_size, model='lstm', num_layers=1, dropout_p=0):
+         super().__init__()
+         self.model = model
+         self.input_size = input_size
+         self.hidden_size = hidden_size
+         self.output_size = output_size
+         self.num_layers = num_layers
+
+         self.embeddings = nn.Embedding(input_size, hidden_size)
+         if self.model == 'lstm':
+             self.rnn = nn.LSTM(hidden_size, hidden_size, num_layers)
+         elif self.model == 'gru':
+             self.rnn = nn.GRU(hidden_size, hidden_size, num_layers)
+         else:
+             raise NotImplementedError
+
+         self.out = nn.Linear(self.hidden_size, self.output_size)
+         self.drop = nn.Dropout(p=dropout_p)
+
+     def init_hidden(self, batch_size=1):
+         """Initialize hidden states (an (h, c) tuple for LSTM, a single tensor for GRU)."""
+         if self.model == 'lstm':
+             self.hidden = (
+                 torch.zeros(self.num_layers, batch_size, self.hidden_size),
+                 torch.zeros(self.num_layers, batch_size, self.hidden_size)
+             )
+         elif self.model == 'gru':
+             self.hidden = torch.zeros(self.num_layers, batch_size, self.hidden_size)
+
+         return self.hidden
+
+     def forward(self, x):
+         """Forward pass."""
+         # Ensure x is at most 2D (sequence length, batch size)
+         if x.dim() > 2:
+             x = x.squeeze()
+
+         batch_size = 1 if x.dim() == 1 else x.size(0)
+         x = x.long()
+
+         # Embed the input
+         embeds = self.embeddings(x)
+
+         # Initialize hidden state if not already done
+         if not hasattr(self, 'hidden'):
+             self.init_hidden(batch_size)
+
+         # Ensure embeds is 3D for RNN input (sequence length, batch size, embedding size)
+         if embeds.dim() == 2:
+             embeds = embeds.unsqueeze(1)
+
+         # RNN processing
+         rnn_out, self.hidden = self.rnn(embeds, self.hidden)
+
+         # Dropout and output layer
+         rnn_out = self.drop(rnn_out.squeeze(1))
+         output = self.out(rnn_out)
+
+         return output
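A minimal smoke test of the network as defined above, assuming HIDDEN_SIZE=150 and SEQ_SIZE=25 from train.py's Config; the vocabulary size here is illustrative:

    import torch
    from model import MusicLSTM

    vocab_size = 93                          # illustrative; in practice len(char_idx)
    model = MusicLSTM(vocab_size, 150, vocab_size)
    model.init_hidden()
    x = torch.randint(0, vocab_size, (25,))  # one SEQ_SIZE-length character sequence
    logits = model(x)                        # shape (25, vocab_size): next-char scores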
requirments.txt ADDED
@@ -0,0 +1,3 @@
+ music21
+ midi2audio
+ pyfluidsynth
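Note the filename spelling ("requirments.txt"), which any install command must match. The file lists only the audio-conversion stack; the code in this commit also imports torch, gradio, numpy, and matplotlib, which would need to be installed separately, e.g. pip install -r requirments.txt torch gradio numpy matplotlib.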
train.py ADDED
@@ -0,0 +1,294 @@
+ import os
+ import sys
+ import time
+ import random
+ import numpy as np
+ import matplotlib.pyplot as plt
+ from model import MusicLSTM as MusicRNN
+ import torch
+ import torch.nn as nn
+ from utils import seq_to_tensor, load_vocab, save_vocab
+
+
+ def logger(active=True):
+     """Simple logging utility."""
+     def log(*args, **kwargs):
+         if active:
+             print(*args, **kwargs)
+     return log
+
+ # Configuration
+ class Config:
+     SAVE_EVERY = 20
+     SEQ_SIZE = 25
+     RANDOM_SEED = 11
+     VALIDATION_SIZE = 0.15
+     LR = 1e-3
+     N_EPOCHS = 100
+     NUM_LAYERS = 1
+     HIDDEN_SIZE = 150
+     DROPOUT_P = 0
+     MODEL_TYPE = 'lstm'
+     INPUT_FILE = 'data/music.txt'
+     RESUME = False
+     BATCH_SIZE = 1
+
+ # Utility functions
+ def tic():
+     """Start a timer."""
+     return time.time()
+
+ def toc(start_time, msg=None):
+     """Format the elapsed time since start_time."""
+     s = time.time() - start_time
+     m = int(s / 60)
+     if msg:
+         return f'{m}m {int(s - (m * 60))}s {msg}'
+     return f'{m}m {int(s - (m * 60))}s'
+
+ class DataLoader:
+     def __init__(self, input_file, config):
+         self.config = config
+         self.char_idx, self.char_list = self._load_chars(input_file)
+         self.data = self._load_data(input_file)
+         self.train_idxs, self.valid_idxs = self._split_data()
+         log = logger(True)
+         log(f"Total songs: {len(self.data)}")
+         log(f"Training songs: {len(self.train_idxs)}")
+         log(f"Validation songs: {len(self.valid_idxs)}")
+
+     def _load_chars(self, input_file):
+         """Load the unique characters from the input file."""
+         with open(input_file, 'r') as f:
+             char_idx = ''.join(set(f.read()))
+         return char_idx, list(char_idx)
+
+     def _load_data(self, input_file):
+         """Load songs, each delimited by <start> and <end> marker lines."""
+         with open(input_file, "r") as f:
+             data, buffer = [], ''
+             for line in f:
+                 if line == '<start>\n':
+                     buffer += line
+                 elif line == '<end>\n':
+                     buffer += line
+                     data.append(buffer)
+                     buffer = ''
+                 else:
+                     buffer += line
+
+         # Filter out songs shorter than the training sequence size
+         data = [song for song in data if len(song) > self.config.SEQ_SIZE + 10]
+         return data
+
+     def _split_data(self):
+         """Split data into training and validation sets."""
+         num_train = len(self.data)
+         indices = list(range(num_train))
+
+         np.random.seed(self.config.RANDOM_SEED)
+         np.random.shuffle(indices)
+
+         split_idx = int(np.floor(self.config.VALIDATION_SIZE * num_train))
+         train_idxs = indices[split_idx:]
+         valid_idxs = indices[:split_idx]
+
+         return train_idxs, valid_idxs
+
+     def rand_slice(self, data, slice_len=None):
+         """Take a random slice of a song."""
+         if slice_len is None:
+             slice_len = self.config.SEQ_SIZE
+
+         d_len = len(data)
+         s_idx = random.randint(0, d_len - slice_len)
+         e_idx = s_idx + slice_len + 1
+         return data[s_idx:e_idx]
+
+     def seq_to_tensor(self, seq):
+         """Convert a character sequence to a tensor of vocabulary indices."""
+         out = torch.zeros(len(seq)).long()
+         for i, c in enumerate(seq):
+             out[i] = self.char_idx.index(c)
+         return out
+
+     def song_to_seq_target(self, song):
+         """Convert a song into an input sequence and its next-character target."""
+         try:
+             a_slice = self.rand_slice(song)
+             seq = self.seq_to_tensor(a_slice[:-1])
+             target = self.seq_to_tensor(a_slice[1:])
+             return seq, target
+         except Exception as e:
+             print(f"Error in song_to_seq_target: {e}")
+             print(f"Song length: {len(song)}")
+             raise
+
+ def train_model(config, data_loader, model, optimizer, loss_function):
+     """Training loop for the model."""
+     log = logger(True)
+     time_since = tic()
+     losses, v_losses = [], []
+
+     for epoch in range(config.N_EPOCHS):
+         # Training phase
+         epoch_loss = 0
+         model.train()
+
+         for i, song_idx in enumerate(data_loader.train_idxs):
+             try:
+                 seq, target = data_loader.song_to_seq_target(data_loader.data[song_idx])
+
+                 # Reset hidden state and gradients
+                 model.init_hidden()
+                 optimizer.zero_grad()
+
+                 # Forward pass
+                 outputs = model(seq)
+                 loss = loss_function(outputs, target)
+
+                 # Backward pass and optimization
+                 loss.backward()
+                 optimizer.step()
+
+                 epoch_loss += loss.item()
+
+                 msg = f'\rTraining Epoch: {epoch}, {(i+1)/len(data_loader.train_idxs)*100:.2f}% iter: {i} Time: {toc(time_since)} Loss: {loss.item():.4f}'
+                 sys.stdout.write(msg)
+                 sys.stdout.flush()
+
+             except Exception as e:
+                 log(f"Error processing song {song_idx}: {e}")
+                 continue
+
+         print()
+         losses.append(epoch_loss / len(data_loader.train_idxs))
+
+         # Validation phase
+         model.eval()
+         val_loss = 0
+         with torch.no_grad():
+             for i, song_idx in enumerate(data_loader.valid_idxs):
+                 try:
+                     seq, target = data_loader.song_to_seq_target(data_loader.data[song_idx])
+
+                     # Reset hidden state
+                     model.init_hidden()
+
+                     # Forward pass
+                     outputs = model(seq)
+                     loss = loss_function(outputs, target)
+
+                     val_loss += loss.item()
+
+                     msg = f'\rValidation Epoch: {epoch}, {(i+1)/len(data_loader.valid_idxs)*100:.2f}% iter: {i} Time: {toc(time_since)} Loss: {loss.item():.4f}'
+                     sys.stdout.write(msg)
+                     sys.stdout.flush()
+
+                 except Exception as e:
+                     log(f"Error processing validation song {song_idx}: {e}")
+                     continue
+
+         print()
+         v_losses.append(val_loss / len(data_loader.valid_idxs))
+
+         # Checkpoint saving
+         if epoch % config.SAVE_EVERY == 0 or epoch == config.N_EPOCHS - 1:
+             log('=======> Saving..')
+             state = {
+                 'model': model.state_dict(),
+                 'optimizer': optimizer.state_dict(),
+                 'loss': losses[-1],
+                 'v_loss': v_losses[-1],
+                 'losses': losses,
+                 'v_losses': v_losses,
+                 'epoch': epoch,
+             }
+             os.makedirs('checkpoint', exist_ok=True)
+             # Save the checkpoint dict built above; it was previously constructed
+             # but never used (the whole model object was saved instead).
+             torch.save(state, f'checkpoint/ckpt_mdl_{config.MODEL_TYPE}_ep_{config.N_EPOCHS}_hsize_{config.HIDDEN_SIZE}_dout_{config.DROPOUT_P}.t{epoch}')
+
+     return losses, v_losses
+
+ def plot_losses(losses, v_losses):
+     """Plot training and validation losses."""
+     plt.figure(figsize=(10, 5))
+     plt.plot(losses, label='Training Loss')
+     plt.plot(v_losses, label='Validation Loss')
+     plt.xlabel('Epoch')
+     plt.ylabel('Loss')
+     plt.title('Loss per Epoch')
+     plt.legend()
+     plt.show()
+
+ def generate_song(model, data_loader, prime_str='<start>', max_len=1000, temp=0.8):
+     """Generate a new song using the trained model."""
+     model.eval()
+     model.init_hidden()
+     creation = prime_str
+     char_idx, _ = load_vocab()
+
+     # Build up the hidden state by feeding the prime string one character at a time
+     prime = seq_to_tensor(creation, char_idx)
+
+     with torch.no_grad():
+         for i in range(len(prime) - 1):
+             _ = model(prime[i:i+1])
+
+         # Generate the rest of the sequence
+         for _ in range(max_len):
+             last_char = prime[-1:]
+             out = model(last_char).squeeze()
+
+             # Temperature-scaled softmax over next-character logits
+             out = torch.exp(out / temp)
+             dist = out / torch.sum(out)
+
+             # Sample from the distribution
+             next_char_idx = torch.multinomial(dist, 1).item()
+             next_char = char_idx[next_char_idx]
+
+             creation += next_char
+             prime = torch.cat([prime, torch.tensor([next_char_idx])], dim=0)
+
+             if creation[-5:] == '<end>':
+                 break
+
+     return creation
+
+ def main():
+     """Main execution function."""
+     # Set up configuration and data
+     global model, data_loader
+     config = Config()
+     data_loader = DataLoader(config.INPUT_FILE, config)
+
+     # Model setup
+     in_size = out_size = len(data_loader.char_idx)
+     model = MusicRNN(
+         in_size,
+         config.HIDDEN_SIZE,
+         out_size,
+         config.MODEL_TYPE,
+         config.NUM_LAYERS,
+         config.DROPOUT_P
+     )
+
+     # Optimizer and loss
+     optimizer = torch.optim.Adam(model.parameters(), lr=config.LR)
+     loss_function = nn.CrossEntropyLoss()
+
+     # Train the model
+     losses, v_losses = train_model(config, data_loader, model, optimizer, loss_function)
+
+     # Plot losses
+     plot_losses(losses, v_losses)
+     save_vocab(data_loader)
+
+     # Generate a song
+     generated_song = generate_song(model, data_loader)
+     print("Generated Song:", generated_song)
+
+ if __name__ == "__main__":
+     main()
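The sampling loop in generate_song is softmax sampling with temperature: dividing the logits by temp before exponentiating and normalizing sharpens the distribution for temp < 1 and flattens it for temp > 1. A standalone sketch of the same computation:

    import torch

    logits = torch.tensor([2.0, 1.0, 0.1])        # toy next-character scores
    temp = 0.8
    probs = torch.softmax(logits / temp, dim=0)   # equals exp(x/temp) / sum(exp(x/temp))
    next_id = torch.multinomial(probs, 1).item()  # sample one character id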
utils.py ADDED
@@ -0,0 +1,91 @@
+ import os
+ import sys
+ import time
+ import json
+ import torch
+
+ # Returns a log function that only prints when verbose is set
+ def logger(verbose):
+     def log(*msg):
+         if verbose:
+             print(*msg)
+     return log
+
+
+ last_time = time.time()
+ begin_time = last_time
+
+ def progress_bar(current, total, msg=None):
+     global last_time, begin_time
+     if current == 0:
+         begin_time = time.time()  # Reset for a new bar.
+     cur_time = time.time()
+     step_time = cur_time - last_time
+     last_time = cur_time
+     tot_time = cur_time - begin_time
+
+     L = []
+     L.append(' Step: %s' % format_time(step_time))
+     L.append(' | Tot: %s' % format_time(tot_time))
+     if msg:
+         L.append(' | ' + msg)
+     msg = ''.join(L)
+     sys.stdout.write(msg)
+     sys.stdout.write('\r')
+     sys.stdout.flush()
+
+ def format_time(seconds):
+     """Format a duration using at most its two largest units (e.g. '2h3m')."""
+     days = int(seconds / 3600 / 24)
+     seconds = seconds - days * 3600 * 24
+     hours = int(seconds / 3600)
+     seconds = seconds - hours * 3600
+     minutes = int(seconds / 60)
+     seconds = seconds - minutes * 60
+     secondsf = int(seconds)
+     seconds = seconds - secondsf
+     millis = int(seconds * 1000)
+
+     f = ''
+     i = 1
+     if days > 0:
+         f += str(days) + 'D'
+         i += 1
+     if hours > 0 and i <= 2:
+         f += str(hours) + 'h'
+         i += 1
+     if minutes > 0 and i <= 2:
+         f += str(minutes) + 'm'
+         i += 1
+     if secondsf > 0 and i <= 2:
+         f += str(secondsf) + 's'
+         i += 1
+     if millis > 0 and i <= 2:
+         f += str(millis) + 'ms'
+         i += 1
+     if f == '':
+         f = '0ms'
+     return f
+
+ def save_vocab(data_loader, vocab_filename="config/vocab.json"):
+     """Save the vocabulary to a JSON file."""
+     vocab = {
+         'char_idx': data_loader.char_idx,
+         'char_list': data_loader.char_list
+     }
+     os.makedirs(os.path.dirname(vocab_filename), exist_ok=True)  # ensure config/ exists
+     with open(vocab_filename, 'w') as f:
+         json.dump(vocab, f)
+
+ def load_vocab(vocab_filename='config/vocab.json'):
+     """Load the vocabulary saved by save_vocab."""
+     with open(vocab_filename, 'r') as f:
+         vocab = json.load(f)
+     return vocab['char_idx'], vocab['char_list']
+
+ def seq_to_tensor(seq, char_idx):
+     """Convert a character sequence to a tensor of vocabulary indices."""
+     out = torch.zeros(len(seq)).long()
+     for i, c in enumerate(seq):
+         out[i] = char_idx.index(c)
+     return out