v0
- .gitignore +4 -0
- Loss_per_epoch.png +0 -0
- app.py +46 -0
- app.txt +3 -0
- checkpoint/ckpt_mdl_lstm_ep_100_hsize_150_dout_0.t0 +0 -0
- checkpoint/ckpt_mdl_lstm_ep_100_hsize_150_dout_0.t20 +0 -0
- checkpoint/ckpt_mdl_lstm_ep_100_hsize_150_dout_0.t40 +0 -0
- checkpoint/ckpt_mdl_lstm_ep_100_hsize_150_dout_0.t60 +0 -0
- checkpoint/ckpt_mdl_lstm_ep_100_hsize_150_dout_0.t80 +0 -0
- checkpoint/ckpt_mdl_lstm_ep_100_hsize_150_dout_0.t99 +0 -0
- checkpoint/model.pth +0 -0
- config/vocab.json +1 -0
- convert.py +33 -0
- data/.DS_Store +0 -0
- data/music.txt +0 -0
- data/pop.txt +0 -0
- data/sample-music.txt +25 -0
- model.py +66 -0
- requirments.txt +3 -0
- train.py +294 -0
- utils.py +91 -0
.gitignore
ADDED
@@ -0,0 +1,4 @@
__pycache__
output/*
temp.ipynb
output*
Loss_per_epoch.png
ADDED
app.py
ADDED
@@ -0,0 +1,46 @@
import torch
import torch.nn as nn
import numpy as np
import gradio as gr
from model import MusicLSTM
from train import DataLoader, Config, generate_song as generate_ABC_notation
from utils import load_vocab
from convert import abc_to_audio

class GradioApp():
    def __init__(self):
        # Set up configuration and data
        self.config = Config()
        self.CHECKPOINT_FILE = "checkpoint/model.pth"
        self.data_loader = DataLoader(self.config.INPUT_FILE, self.config)
        self.checkpoint = torch.load(self.CHECKPOINT_FILE, weights_only=False)
        char_idx, char_list = load_vocab()
        self.model = MusicLSTM(
            input_size=len(char_idx),
            hidden_size=self.config.HIDDEN_SIZE,
            output_size=len(char_idx),
        )
        self.model.load_state_dict(self.checkpoint)
        self.model.eval()

        # Set up the interface
        self.input = gr.Button("")
        self.output = gr.Audio(label="Generated Music")
        # self.output = gr.Textbox("")
        self.interface = gr.Interface(
            fn=self.generate_music,
            inputs=self.input,
            outputs=self.output,
            title="AI Music Generator",
            description="Generate a new song using a trained RNN model.",
        )

    def launch(self):
        self.interface.launch()

    def generate_music(self, input):
        """Generate a new song using the trained model."""
        abc_notation = generate_ABC_notation(self.model, self.data_loader)
        # str.strip(chars) removes a *set* of characters, not a substring,
        # so strip("<start>") could also eat leading notes; removeprefix/
        # removesuffix drop the sentinel tokens safely.
        abc_notation = abc_notation.strip().removeprefix("<start>").removesuffix("<end>").strip()
        audio = abc_to_audio(abc_notation)
        return audio

if __name__ == '__main__':
    app = GradioApp()
    app.launch()
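For reference, a minimal sketch (not part of the commit) of driving the app without the web UI; generate_music ignores its input argument, so any placeholder works:

from app import GradioApp

app = GradioApp()                     # loads checkpoint/model.pth and the vocab
wav_path = app.generate_music(None)   # placeholder input; returns "output.wav"
print(wav_path)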
app.txt
ADDED
@@ -0,0 +1,3 @@
libfluidsynth-dev
libsndfile1
abc2midi
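These are presumably the system-level packages the Space needs for convert.py: the FluidSynth headers and libsndfile for MIDI-to-audio synthesis, and the abc2midi binary for ABC-to-MIDI conversion.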
checkpoint/ckpt_mdl_lstm_ep_100_hsize_150_dout_0.t0
ADDED
Binary file (844 kB).

checkpoint/ckpt_mdl_lstm_ep_100_hsize_150_dout_0.t20
ADDED
Binary file (844 kB).

checkpoint/ckpt_mdl_lstm_ep_100_hsize_150_dout_0.t40
ADDED
Binary file (844 kB).

checkpoint/ckpt_mdl_lstm_ep_100_hsize_150_dout_0.t60
ADDED
Binary file (844 kB).

checkpoint/ckpt_mdl_lstm_ep_100_hsize_150_dout_0.t80
ADDED
Binary file (844 kB).

checkpoint/ckpt_mdl_lstm_ep_100_hsize_150_dout_0.t99
ADDED
Binary file (844 kB).

checkpoint/model.pth
ADDED
Binary file (839 kB).
config/vocab.json
ADDED
@@ -0,0 +1 @@
{"char_idx": "<]=_4Xl)uq5CBw#d(~H}scntZ!hIF6p'\\E/g&?fTW{^-v9MA710+oK\tJS[\n,Q \"G2a:L|mxVNbPRk*jYyD3e.8Oi>Uzr@", "char_list": ["<", "]", "=", "_", "4", "X", "l", ")", "u", "q", "5", "C", "B", "w", "#", "d", "(", "~", "H", "}", "s", "c", "n", "t", "Z", "!", "h", "I", "F", "6", "p", "'", "\\", "E", "/", "g", "&", "?", "f", "T", "W", "{", "^", "-", "v", "9", "M", "A", "7", "1", "0", "+", "o", "K", "\t", "J", "S", "[", "\n", ",", "Q", " ", "\"", "G", "2", "a", ":", "L", "|", "m", "x", "V", "N", "b", "P", "R", "k", "*", "j", "Y", "y", "D", "3", "e", ".", "8", "O", "i", ">", "U", "z", "r", "@"]}
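A minimal sketch (not part of the commit) of how this vocabulary is consumed: char_idx is a single string whose character positions are the integer ids, and char_list is the same mapping as a list, so encoding uses str.index and decoding uses list indexing:

import json

with open("config/vocab.json") as f:
    vocab = json.load(f)

char_idx = vocab["char_idx"]    # position in this string = character id
char_list = vocab["char_list"]  # id -> character

i = char_idx.index("X")         # encode: character -> id
assert char_list[i] == "X"      # decode: id -> character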
convert.py
ADDED
@@ -0,0 +1,33 @@
from music21 import converter, stream
from midi2audio import FluidSynth
import subprocess

def abc_to_audio(abc_notation, output_format='wav', sound_font="FluidR3_GM.sf2"):
    """Convert ABC notation to a WAV file."""
    abc_file = 'output.abc'
    with open(abc_file, 'w') as f:
        f.write(abc_notation)
    # abc2midi writes the tune to a MIDI file, which FluidSynth then renders.
    subprocess.run(['abc2midi', abc_file, '-o', "output.midi"])
    fs = FluidSynth(sound_font)  # render with the requested sound font
    fs.midi_to_audio("output.midi", "output.wav")
    return "output.wav"


if __name__ == '__main__':
    abc_to_audio("""X:12
T:Byrne: Triop
C:Trad Figne
Z:id:hn-hornpipe-53
M:C|
K:G
(3DFB d2dc | def2 edef | e2a2 df | g4- gdBG | A4G | A4 :|
|: ae edc | edcB A2B2 | A2G2 | G6 d2 | e4^c4 | d4 d4 | ed e2 | d4 ||
P:variations:
|: ABA AGE|F2A d2A|d2g d2:|
a2f fef aba|a2f g2e fed|c2A GBd|f2g g2a|bgb aag|dcB B2G|A2G A2G:|
|:F2A A2G|AGE G2d||
P:variations
|: AGF GBd | cde d2B | c2c c2A :|
|: de fe | fdfe dFAd | A2AG A2f2 | g2ag e2B2 | A2AB ^cdce | d2d>c | B4z2 | B4 | A4G2 | ^F4G4 | G4 :|
|: G^F G2 | c4 ||
GBdB | c2 ded2 | c2B2c2 | d2c2B2 | c2d2 | c2B2 | A4 :|""")
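A minimal usage sketch (not part of the commit). abc_to_audio shells out to the abc2midi binary and synthesizes with FluidSynth, so both must be installed (see app.txt) along with the FluidR3_GM.sf2 sound font; the tune below is an illustrative one-line scale:

from convert import abc_to_audio

tune = "X:1\nT:Scale test\nM:4/4\nL:1/8\nK:C\nCDEF GABc | c4 z4 |]\n"
wav_path = abc_to_audio(tune)   # writes output.abc, output.midi, output.wav
print(wav_path)                 # "output.wav"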
data/.DS_Store
ADDED
Binary file (30.7 kB).

data/music.txt
ADDED
The diff for this file is too large to render. See raw diff.

data/pop.txt
ADDED
The diff for this file is too large to render. See raw diff.
data/sample-music.txt
ADDED
@@ -0,0 +1,25 @@
X:3
T:Badine
O:France
A:Provence
Z:Transcrit et/ou corrigé par Michel BELLON - 2005-04-01
Z:Pour toute observation mailto:[email protected]
M:C|
L:1/8
Q: "Allegro"
K:Bb
V:1 name=G
d2 (cB) | d2 (cB) f2 ed | f4 g2 g2 | feed eddc | d2B2 d2cB | d2cB f2ed |
f4 g2g2 | fedc d2!+!c2 | B4 :: FBcB | B2AB cd ec | d2 B2 df dB |
cf cA Bd cB | B2 A2 f2 f2 | (ABcd) e2 d2 | e2 dc dcde | fefg fefg |
Te4 d2cB | d2cB f2ed | f4 g2g2 | feed eddc | d2B2 d2cB | d2cB f2ed |
f4 g2g2 | fedc d2!+!c2 | B4 !fine! :: [K:Bbm] c2de | d2ef B2cd |
c2F2 dc Bc | !+!=A2 B2 c2 d2 | d2 c2 d2ef | g2g2 c2de | f2f2 B=ABc | F2Bc d2c2 |
B4:| fefg | f2e=d e2f2 | {f2}g4 edef | e2dc d2e2 | {e2}f4 Bc=Ac |
B2F2 dece | d2c2 dcde | f6 ed |c4 !D.C.! |]
V:2 name=V
z4 | z4 d2cB | d2cB e4 | B4 f2f2 | b2B2 z4 | z4 d2cB | d2cB e4 | d2B2 f2F2 | B4 ::
z4 | f4f2f2 | B2b2 b2b2 | a2f2 g2=e2 | f2F2 z4 | f2f2 A2B2 | c2f2 B2B2 | B2b2 b2b2 |
a4 b2B2 | z4 d2cB | d2cB e4 | B4 f2f2 | b2B2 z4 | z4 d2cB | d2cB e4 | d2B2 f2F2 | B4 ::
[K:Bbm] b2b2 | b4 g2e2 | f2F2 B2B2 | e2d2 c2B2 | f4 b2b2 | e2fg a2a2 | d4 g2g2 | fede f2F2 | B4 :|
b2b2 | b3a g2f2 | e4 a2a2 | a3g f2e2 | d4 z4 | z4 (bc')(=ac') | b2f2 B2Bc | dcde d2e2 | f4 |]
model.py
ADDED
@@ -0,0 +1,66 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable

class MusicLSTM(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, model='lstm', num_layers=1, dropout_p=0):
        super(MusicLSTM, self).__init__()
        self.model = model
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.num_layers = num_layers

        self.embeddings = nn.Embedding(input_size, hidden_size)
        if self.model == 'lstm':
            self.rnn = nn.LSTM(hidden_size, hidden_size, num_layers)
        elif self.model == 'gru':
            self.rnn = nn.GRU(hidden_size, hidden_size, num_layers)
        else:
            raise NotImplementedError

        self.out = nn.Linear(self.hidden_size, self.output_size)
        self.drop = nn.Dropout(p=dropout_p)

    def init_hidden(self, batch_size=1):
        """Initialize hidden states."""
        if self.model == 'lstm':
            self.hidden = (
                torch.zeros(self.num_layers, batch_size, self.hidden_size),
                torch.zeros(self.num_layers, batch_size, self.hidden_size)
            )
        elif self.model == 'gru':
            self.hidden = torch.zeros(self.num_layers, batch_size, self.hidden_size)

        return self.hidden

    def forward(self, x):
        """Forward pass."""
        # Ensure x is at most 2D (sequence length, batch size)
        if x.dim() > 2:
            x = x.squeeze()

        batch_size = 1 if x.dim() == 1 else x.size(0)
        x = x.long()

        # Embed the input
        embeds = self.embeddings(x)

        # Initialize the hidden state if not already done
        if not hasattr(self, 'hidden'):
            self.init_hidden(batch_size)

        # Ensure embeds is 3D for RNN input (sequence length, batch size, embedding size)
        if embeds.dim() == 2:
            embeds = embeds.unsqueeze(1)

        # RNN processing
        rnn_out, self.hidden = self.rnn(embeds, self.hidden)

        # Dropout and output layer
        rnn_out = self.drop(rnn_out.squeeze(1))
        output = self.out(rnn_out)

        return output
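A minimal sketch (not part of the commit) of a single forward step through MusicLSTM; the sizes are illustrative, and in practice the input/output size is len(char_idx) from config/vocab.json:

import torch
from model import MusicLSTM

vocab_size, hidden_size = 93, 150   # illustrative; use len(char_idx) in practice
model = MusicLSTM(vocab_size, hidden_size, vocab_size)
model.init_hidden()

seq = torch.tensor([3, 17, 42])     # three character ids
logits = model(seq)                 # one row of logits per input character
print(logits.shape)                 # torch.Size([3, 93])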
requirments.txt
ADDED
@@ -0,0 +1,3 @@
music21
midi2audio
pyfluidsynth
train.py
ADDED
@@ -0,0 +1,294 @@
import os
import sys
import time
import random
import json
import numpy as np
import matplotlib.pyplot as plt
from model import MusicLSTM as MusicRNN
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from utils import seq_to_tensor, load_vocab, save_vocab


def logger(active=True):
    """Simple logging utility."""
    def log(*args, **kwargs):
        if active:
            print(*args, **kwargs)
    return log

# Configuration
class Config:
    SAVE_EVERY = 20
    SEQ_SIZE = 25
    RANDOM_SEED = 11
    VALIDATION_SIZE = 0.15
    LR = 1e-3
    N_EPOCHS = 100
    NUM_LAYERS = 1
    HIDDEN_SIZE = 150
    DROPOUT_P = 0
    MODEL_TYPE = 'lstm'
    INPUT_FILE = 'data/music.txt'
    RESUME = False
    BATCH_SIZE = 1

# Utility functions
def tic():
    """Start timer."""
    return time.time()

def toc(start_time, msg=None):
    """Calculate elapsed time."""
    s = time.time() - start_time
    m = int(s / 60)
    if msg:
        return f'{m}m {int(s - (m * 60))}s {msg}'
    return f'{m}m {int(s - (m * 60))}s'

class DataLoader:
    def __init__(self, input_file, config):
        self.config = config
        self.char_idx, self.char_list = self._load_chars(input_file)
        self.data = self._load_data(input_file)
        self.train_idxs, self.valid_idxs = self._split_data()
        log = logger(True)
        log(f"Total songs: {len(self.data)}")
        log(f"Training songs: {len(self.train_idxs)}")
        log(f"Validation songs: {len(self.valid_idxs)}")

    def _load_chars(self, input_file):
        """Load the unique characters from the input file."""
        with open(input_file, 'r') as f:
            char_idx = ''.join(set(f.read()))
        return char_idx, list(char_idx)

    def _load_data(self, input_file):
        """Load song data from the input file."""
        with open(input_file, "r") as f:
            data, buffer = [], ''
            for line in f:
                if line == '<start>\n':
                    buffer += line
                elif line == '<end>\n':
                    buffer += line
                    data.append(buffer)
                    buffer = ''
                else:
                    buffer += line

        # Filter out songs shorter than the sequence size
        data = [song for song in data if len(song) > self.config.SEQ_SIZE + 10]
        return data

    def _split_data(self):
        """Split data into training and validation sets."""
        num_train = len(self.data)
        indices = list(range(num_train))

        np.random.seed(self.config.RANDOM_SEED)
        np.random.shuffle(indices)

        split_idx = int(np.floor(self.config.VALIDATION_SIZE * num_train))
        train_idxs = indices[split_idx:]
        valid_idxs = indices[:split_idx]

        return train_idxs, valid_idxs

    def rand_slice(self, data, slice_len=None):
        """Get a random slice of data."""
        if slice_len is None:
            slice_len = self.config.SEQ_SIZE

        d_len = len(data)
        s_idx = random.randint(0, d_len - slice_len)
        e_idx = s_idx + slice_len + 1
        return data[s_idx:e_idx]

    def seq_to_tensor(self, seq):
        """Convert a sequence to a tensor."""
        out = torch.zeros(len(seq)).long()
        for i, c in enumerate(seq):
            out[i] = self.char_idx.index(c)
        return out

    def song_to_seq_target(self, song):
        """Convert a song to a sequence and target (shifted by one character)."""
        try:
            a_slice = self.rand_slice(song)
            seq = self.seq_to_tensor(a_slice[:-1])
            target = self.seq_to_tensor(a_slice[1:])
            return seq, target
        except Exception as e:
            print(f"Error in song_to_seq_target: {e}")
            print(f"Song length: {len(song)}")
            raise

def train_model(config, data_loader, model, optimizer, loss_function):
    """Training loop for the model."""
    log = logger(True)
    time_since = tic()
    losses, v_losses = [], []

    for epoch in range(config.N_EPOCHS):
        # Training phase
        epoch_loss = 0
        model.train()

        for i, song_idx in enumerate(data_loader.train_idxs):
            try:
                seq, target = data_loader.song_to_seq_target(data_loader.data[song_idx])

                # Reset hidden state and gradients
                model.init_hidden()
                optimizer.zero_grad()

                # Forward pass
                outputs = model(seq)
                loss = loss_function(outputs, target)

                # Backward pass and optimization
                loss.backward()
                optimizer.step()

                epoch_loss += loss.item()

                msg = f'\rTraining Epoch: {epoch}, {(i+1)/len(data_loader.train_idxs)*100:.2f}% iter: {i} Time: {toc(time_since)} Loss: {loss.item():.4f}'
                sys.stdout.write(msg)
                sys.stdout.flush()

            except Exception as e:
                log(f"Error processing song {song_idx}: {e}")
                continue

        print()
        losses.append(epoch_loss / len(data_loader.train_idxs))

        # Validation phase
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for i, song_idx in enumerate(data_loader.valid_idxs):
                try:
                    seq, target = data_loader.song_to_seq_target(data_loader.data[song_idx])

                    # Reset hidden state
                    model.init_hidden()

                    # Forward pass
                    outputs = model(seq)
                    loss = loss_function(outputs, target)

                    val_loss += loss.item()

                    msg = f'\rValidation Epoch: {epoch}, {(i+1)/len(data_loader.valid_idxs)*100:.2f}% iter: {i} Time: {toc(time_since)} Loss: {loss.item():.4f}'
                    sys.stdout.write(msg)
                    sys.stdout.flush()

                except Exception as e:
                    log(f"Error processing validation song {song_idx}: {e}")
                    continue

        print()
        v_losses.append(val_loss / len(data_loader.valid_idxs))

        # Checkpoint saving
        if epoch % config.SAVE_EVERY == 0 or epoch == config.N_EPOCHS - 1:
            log('=======> Saving..')
            state = {
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'loss': losses[-1],
                'v_loss': v_losses[-1],
                'losses': losses,
                'v_losses': v_losses,
                'epoch': epoch,
            }
            os.makedirs('checkpoint', exist_ok=True)
            # Save the full checkpoint dict so the optimizer state and loss
            # history persist alongside the model weights.
            torch.save(state, f'checkpoint/ckpt_mdl_{config.MODEL_TYPE}_ep_{config.N_EPOCHS}_hsize_{config.HIDDEN_SIZE}_dout_{config.DROPOUT_P}.t{epoch}')

    return losses, v_losses

def plot_losses(losses, v_losses):
    """Plot training and validation losses."""
    plt.figure(figsize=(10, 5))
    plt.plot(losses, label='Training Loss')
    plt.plot(v_losses, label='Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Loss per Epoch')
    plt.legend()
    plt.show()

def generate_song(model, data_loader, prime_str='<start>', max_len=1000, temp=0.8):
    """Generate a new song using the trained model."""
    model.eval()
    model.init_hidden()
    creation = prime_str
    char_idx, char_list = load_vocab()

    # Build up the hidden state from the prime string
    prime = seq_to_tensor(creation, char_idx)

    with torch.no_grad():
        for i in range(len(prime) - 1):
            _ = model(prime[i:i+1])

        # Generate the rest of the sequence
        for _ in range(max_len):
            last_char = prime[-1:]
            out = model(last_char).squeeze()

            # Softmax with temperature: lower temp sharpens the distribution
            out = torch.exp(out / temp)
            dist = out / torch.sum(out)

            # Sample from the distribution
            next_char_idx = torch.multinomial(dist, 1).item()
            next_char = char_idx[next_char_idx]

            creation += next_char
            prime = torch.cat([prime, torch.tensor([next_char_idx])], dim=0)

            if creation[-5:] == '<end>':
                break

    return creation

def main():
    """Main execution function."""
    # Set up configuration and data
    global model, data_loader
    config = Config()
    data_loader = DataLoader(config.INPUT_FILE, config)

    # Model setup
    in_size = out_size = len(data_loader.char_idx)
    model = MusicRNN(
        in_size,
        config.HIDDEN_SIZE,
        out_size,
        config.MODEL_TYPE,
        config.NUM_LAYERS,
        config.DROPOUT_P
    )

    # Optimizer and loss
    optimizer = torch.optim.Adam(model.parameters(), lr=config.LR)
    loss_function = nn.CrossEntropyLoss()

    # Train the model
    losses, v_losses = train_model(config, data_loader, model, optimizer, loss_function)

    # Plot losses
    plot_losses(losses, v_losses)
    save_vocab(data_loader)

    # Generate a song
    generated_song = generate_song(model, data_loader)
    print("Generated Song:", generated_song)

if __name__ == "__main__":
    main()
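The sampling step in generate_song is a softmax with temperature: logits are divided by temp before exponentiating, so temp < 1 sharpens the distribution toward the most likely character and temp > 1 flattens it. A standalone sketch (not part of the commit):

import torch

logits = torch.tensor([2.0, 1.0, 0.1])
for temp in (0.5, 0.8, 1.5):
    weights = torch.exp(logits / temp)
    dist = weights / weights.sum()          # same normalization as generate_song
    next_idx = torch.multinomial(dist, 1).item()
    print(f"temp={temp}: dist={dist.tolist()} sampled={next_idx}")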
utils.py
ADDED
@@ -0,0 +1,91 @@
import os
import sys
import time
import json
import torch

# Only print when verbose is set
def logger(verbose):
    def log(*msg):
        if verbose: print(*msg)
    return log


last_time = time.time()
begin_time = last_time

def progress_bar(current, total, msg=None):
    global last_time, begin_time
    if current == 0:
        begin_time = time.time()  # Reset for a new bar.
    cur_time = time.time()
    step_time = cur_time - last_time
    last_time = cur_time
    tot_time = cur_time - begin_time

    L = []
    L.append(' Step: %s' % format_time(step_time))
    L.append(' | Tot: %s' % format_time(tot_time))
    if msg:
        L.append(' | ' + msg)
    msg = ''.join(L)
    sys.stdout.write(msg)
    sys.stdout.write('\r')
    sys.stdout.flush()

def format_time(seconds):
    """Format a duration in seconds as a compact string (at most two units)."""
    days = int(seconds / 3600 / 24)
    seconds = seconds - days * 3600 * 24
    hours = int(seconds / 3600)
    seconds = seconds - hours * 3600
    minutes = int(seconds / 60)
    seconds = seconds - minutes * 60
    secondsf = int(seconds)
    seconds = seconds - secondsf
    millis = int(seconds * 1000)

    f = ''
    i = 1
    if days > 0:
        f += str(days) + 'D'
        i += 1
    if hours > 0 and i <= 2:
        f += str(hours) + 'h'
        i += 1
    if minutes > 0 and i <= 2:
        f += str(minutes) + 'm'
        i += 1
    if secondsf > 0 and i <= 2:
        f += str(secondsf) + 's'
        i += 1
    if millis > 0 and i <= 2:
        f += str(millis) + 'ms'
        i += 1
    if f == '':
        f = '0ms'
    return f

def save_vocab(data_loader, vocab_filename="config/vocab.json"):
    """Save the vocabulary to a JSON file."""
    vocab = {
        'char_idx': data_loader.char_idx,
        'char_list': data_loader.char_list
    }
    with open(vocab_filename, 'w') as f:
        json.dump(vocab, f)

def load_vocab(vocab_filename='config/vocab.json'):
    """Load the vocabulary from a JSON file."""
    with open(vocab_filename, 'r') as f:
        vocab = json.load(f)
    return vocab['char_idx'], vocab['char_list']

def seq_to_tensor(seq, char_idx):
    """Convert a sequence to a tensor of character ids."""
    out = torch.zeros(len(seq)).long()
    for i, c in enumerate(seq):
        out[i] = char_idx.index(c)
    return out