kjysmu committed on
Commit 6ad6801 · verified · 1 Parent(s): 2a73849

Upload 22 files

utils/__init__.py ADDED
File without changes
utils/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (137 Bytes). View file
 
utils/__pycache__/btc_model.cpython-310.pyc ADDED
Binary file (5.19 kB). View file
 
utils/__pycache__/constants.cpython-310.pyc ADDED
Binary file (574 Bytes). View file
 
utils/__pycache__/custom_early_stopping.cpython-310.pyc ADDED
Binary file (1.77 kB). View file
 
utils/__pycache__/hparams.cpython-310.pyc ADDED
Binary file (1.69 kB). View file
 
utils/__pycache__/logger.cpython-310.pyc ADDED
Binary file (1.87 kB). View file
 
utils/__pycache__/mert.cpython-310.pyc ADDED
Binary file (1.56 kB). View file
 
utils/__pycache__/mir_eval_modules.cpython-310.pyc ADDED
Binary file (12.8 kB). View file
 
utils/__pycache__/transformer_modules.cpython-310.pyc ADDED
Binary file (9.98 kB). View file
 
utils/btc_model.py ADDED
@@ -0,0 +1,198 @@
1
+ from utils.transformer_modules import *
2
+ from utils.transformer_modules import _gen_timing_signal, _gen_bias_mask
3
+ from utils.hparams import HParams
4
+
5
+ use_cuda = torch.cuda.is_available()
6
+
7
+ class self_attention_block(nn.Module):
8
+ def __init__(self, hidden_size, total_key_depth, total_value_depth, filter_size, num_heads,
9
+ bias_mask=None, layer_dropout=0.0, attention_dropout=0.0, relu_dropout=0.0, attention_map=False):
10
+ super(self_attention_block, self).__init__()
11
+
12
+ self.attention_map = attention_map
13
+ self.multi_head_attention = MultiHeadAttention(hidden_size, total_key_depth, total_value_depth,hidden_size, num_heads, bias_mask, attention_dropout, attention_map)
14
+ self.positionwise_convolution = PositionwiseFeedForward(hidden_size, filter_size, hidden_size, layer_config='cc', padding='both', dropout=relu_dropout)
15
+ self.dropout = nn.Dropout(layer_dropout)
16
+ self.layer_norm_mha = LayerNorm(hidden_size)
17
+ self.layer_norm_ffn = LayerNorm(hidden_size)
18
+
19
+ def forward(self, inputs):
20
+ x = inputs
21
+
22
+ # Layer Normalization
23
+ x_norm = self.layer_norm_mha(x)
24
+
25
+ # Multi-head attention
26
+ if self.attention_map is True:
27
+ y, weights = self.multi_head_attention(x_norm, x_norm, x_norm)
28
+ else:
29
+ y = self.multi_head_attention(x_norm, x_norm, x_norm)
30
+
31
+ # Dropout and residual
32
+ x = self.dropout(x + y)
33
+
34
+ # Layer Normalization
35
+ x_norm = self.layer_norm_ffn(x)
36
+
37
+ # Positionwise Feedforward
38
+ y = self.positionwise_convolution(x_norm)
39
+
40
+ # Dropout and residual
41
+ y = self.dropout(x + y)
42
+
43
+ if self.attention_map is True:
44
+ return y, weights
45
+ return y
46
+
47
+ class bi_directional_self_attention(nn.Module):
48
+ def __init__(self, hidden_size, total_key_depth, total_value_depth, filter_size, num_heads, max_length,
49
+ layer_dropout=0.0, attention_dropout=0.0, relu_dropout=0.0):
50
+
51
+ super(bi_directional_self_attention, self).__init__()
52
+
53
+ self.weights_list = list()
54
+
55
+ params = (hidden_size,
56
+ total_key_depth or hidden_size,
57
+ total_value_depth or hidden_size,
58
+ filter_size,
59
+ num_heads,
60
+ _gen_bias_mask(max_length),
61
+ layer_dropout,
62
+ attention_dropout,
63
+ relu_dropout,
64
+ True)
65
+
66
+ self.attn_block = self_attention_block(*params)
67
+
68
+ params = (hidden_size,
69
+ total_key_depth or hidden_size,
70
+ total_value_depth or hidden_size,
71
+ filter_size,
72
+ num_heads,
73
+ torch.transpose(_gen_bias_mask(max_length), dim0=2, dim1=3),
74
+ layer_dropout,
75
+ attention_dropout,
76
+ relu_dropout,
77
+ True)
78
+
79
+ self.backward_attn_block = self_attention_block(*params)
80
+
81
+ self.linear = nn.Linear(hidden_size*2, hidden_size)
82
+
83
+ def forward(self, inputs):
84
+ x, list = inputs
85
+
86
+ # Forward Self-attention Block
87
+ encoder_outputs, weights = self.attn_block(x)
88
+ # Backward Self-attention Block
89
+ reverse_outputs, reverse_weights = self.backward_attn_block(x)
90
+ # Concatenation and Fully-connected Layer
91
+ outputs = torch.cat((encoder_outputs, reverse_outputs), dim=2)
92
+ y = self.linear(outputs)
93
+
94
+ # Attention weights for Visualization
95
+ self.weights_list = list
96
+ self.weights_list.append(weights)
97
+ self.weights_list.append(reverse_weights)
98
+ return y, self.weights_list
99
+
100
+ class bi_directional_self_attention_layers(nn.Module):
101
+ def __init__(self, embedding_size, hidden_size, num_layers, num_heads, total_key_depth, total_value_depth,
102
+ filter_size, max_length=100, input_dropout=0.0, layer_dropout=0.0,
103
+ attention_dropout=0.0, relu_dropout=0.0):
104
+ super(bi_directional_self_attention_layers, self).__init__()
105
+
106
+ self.timing_signal = _gen_timing_signal(max_length, hidden_size)
107
+ params = (hidden_size,
108
+ total_key_depth or hidden_size,
109
+ total_value_depth or hidden_size,
110
+ filter_size,
111
+ num_heads,
112
+ max_length,
113
+ layer_dropout,
114
+ attention_dropout,
115
+ relu_dropout)
116
+ self.embedding_proj = nn.Linear(embedding_size, hidden_size, bias=False)
117
+ self.self_attn_layers = nn.Sequential(*[bi_directional_self_attention(*params) for l in range(num_layers)])
118
+ self.layer_norm = LayerNorm(hidden_size)
119
+ self.input_dropout = nn.Dropout(input_dropout)
120
+
121
+ def forward(self, inputs):
122
+ # Add input dropout
123
+ x = self.input_dropout(inputs)
124
+
125
+ # Project to hidden size
126
+ x = self.embedding_proj(x)
127
+
128
+ # Add timing signal
129
+ x += self.timing_signal[:, :inputs.shape[1], :].type_as(inputs.data)
130
+
131
+ # A Stack of Bi-directional Self-attention Layers
132
+ y, weights_list = self.self_attn_layers((x, []))
133
+
134
+ # Layer Normalization
135
+ y = self.layer_norm(y)
136
+ return y, weights_list
137
+
138
+ class BTC_model(nn.Module):
139
+ def __init__(self, config):
140
+ super(BTC_model, self).__init__()
141
+
142
+ self.timestep = config['timestep']
143
+ self.probs_out = config['probs_out']
144
+
145
+ params = (config['feature_size'],
146
+ config['hidden_size'],
147
+ config['num_layers'],
148
+ config['num_heads'],
149
+ config['total_key_depth'],
150
+ config['total_value_depth'],
151
+ config['filter_size'],
152
+ config['timestep'],
153
+ config['input_dropout'],
154
+ config['layer_dropout'],
155
+ config['attention_dropout'],
156
+ config['relu_dropout'])
157
+
158
+ self.self_attn_layers = bi_directional_self_attention_layers(*params)
159
+ self.output_layer = SoftmaxOutputLayer(hidden_size=config['hidden_size'], output_size=config['num_chords'], probs_out=config['probs_out'])
160
+
161
+ def forward(self, x, labels):
162
+ labels = labels.view(-1, self.timestep)
163
+ # Output of Bi-directional Self-attention Layers
164
+ self_attn_output, weights_list = self.self_attn_layers(x)
165
+
166
+ # return logit values for CRF
167
+ if self.probs_out is True:
168
+ logits = self.output_layer(self_attn_output)
169
+ return logits
170
+
171
+ # Output layer and Soft-max
172
+ prediction,second = self.output_layer(self_attn_output)
173
+ prediction = prediction.view(-1)
174
+ second = second.view(-1)
175
+
176
+ # Loss Calculation
177
+ loss = self.output_layer.loss(self_attn_output, labels)
178
+ return prediction, loss, weights_list, second
179
+
180
+ if __name__ == "__main__":
181
+ config = HParams.load("run_config.yaml")
182
+ device = torch.device("cuda" if use_cuda else "cpu")
183
+
184
+ batch_size = 2
185
+ timestep = 108
186
+ feature_size = 144
187
+ num_chords = 25
188
+
189
+ features = torch.randn(batch_size,timestep,feature_size,requires_grad=True).to(device)
190
+ chords = torch.randint(25,(batch_size*timestep,)).to(device)
191
+
192
+ model = BTC_model(config=config.model).to(device)
193
+
194
+ prediction, loss, weights_list, second = model(features, chords)
195
+ print(prediction.size())
196
+ print(loss)
197
+
198
+
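The `__main__` block above exercises a training-style forward pass with labels. For inference without ground-truth labels, the encoder and the output layer can also be called separately, which is how `utils/mir_eval_modules.py` below uses the model. A minimal sketch, assuming the same `run_config.yaml` referenced above with `probs_out` set to False:

```python
import torch
from utils.btc_model import BTC_model
from utils.hparams import HParams

config = HParams.load("run_config.yaml")  # same config file as in the __main__ block
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = BTC_model(config=config.model).to(device)
model.eval()

# One segment of CQT features: (batch, timestep, feature_size)
features = torch.randn(1, config.model['timestep'], config.model['feature_size']).to(device)

with torch.no_grad():
    encoder_output, _ = model.self_attn_layers(features)   # bi-directional self-attention stack
    prediction, _ = model.output_layer(encoder_output)     # per-frame chord class indices
print(prediction.squeeze().shape)                          # (timestep,)
```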
utils/chords.py ADDED
@@ -0,0 +1,542 @@
1
+ # encoding: utf-8
2
+ """
3
+ This module contains chord evaluation functionality.
4
+
5
+ It provides the evaluation measures used for the MIREX ACE task, and
6
+ tries to follow [1]_ and [2]_ as closely as possible.
7
+
8
+ Notes
9
+ -----
10
+ This implementation tries to follow the references and their implementation
11
+ (e.g., https://github.com/jpauwels/MusOOEvaluator for [2]_). However, there
12
+ are some known (and possibly some unknown) differences. If you find one not
13
+ listed in the following, please file an issue:
14
+
15
+ - Detected chord segments are adjusted to fit the length of the annotations.
16
+ In particular, this means that, if necessary, filler segments of 'no chord'
17
+ are added at beginnings and ends. This can result in different segmentation
18
+ scores compared to the original implementation.
19
+
20
+ References
21
+ ----------
22
+ .. [1] Christopher Harte, "Towards Automatic Extraction of Harmony Information
23
+ from Music Signals." Dissertation,
24
+ Department for Electronic Engineering, Queen Mary University of London,
25
+ 2010.
26
+ .. [2] Johan Pauwels and Geoffroy Peeters.
27
+ "Evaluating Automatically Estimated Chord Sequences."
28
+ In Proceedings of ICASSP 2013, Vancouver, Canada, 2013.
29
+
30
+ """
31
+
32
+ import numpy as np
33
+ import pandas as pd
34
+ import mir_eval
35
+
36
+
37
+ CHORD_DTYPE = [('root', int),
38
+ ('bass', int),
39
+ ('intervals', int, (12,)),
40
+ ('is_major', bool)]
41
+
42
+ CHORD_ANN_DTYPE = [('start', float),
43
+ ('end', float),
44
+ ('chord', CHORD_DTYPE)]
45
+
46
+ NO_CHORD = (-1, -1, np.zeros(12, dtype=int), False)
47
+ UNKNOWN_CHORD = (-1, -1, np.ones(12, dtype=int) * -1, False)
48
+
49
+ PITCH_CLASS = ['C', 'C#', 'D', 'D#', 'E', 'F', 'F#', 'G', 'G#', 'A', 'A#', 'B']
50
+
51
+
52
+ def idx_to_chord(idx):
53
+ if idx == 24:
54
+ return "-"
55
+ elif idx == 25:
56
+ return u"\u03B5"
57
+
58
+ minmaj = idx % 2
59
+ root = idx // 2
60
+
61
+ return PITCH_CLASS[root] + ("M" if minmaj == 0 else "m")
62
+
63
+ class Chords:
64
+
65
+ def __init__(self):
66
+ self._shorthands = {
67
+ 'maj': self.interval_list('(1,3,5)'),
68
+ 'min': self.interval_list('(1,b3,5)'),
69
+ 'dim': self.interval_list('(1,b3,b5)'),
70
+ 'aug': self.interval_list('(1,3,#5)'),
71
+ 'maj7': self.interval_list('(1,3,5,7)'),
72
+ 'min7': self.interval_list('(1,b3,5,b7)'),
73
+ '7': self.interval_list('(1,3,5,b7)'),
74
+ '6': self.interval_list('(1,6)'), # custom
75
+ '5': self.interval_list('(1,5)'),
76
+ '4': self.interval_list('(1,4)'), # custom
77
+ '1': self.interval_list('(1)'),
78
+ 'dim7': self.interval_list('(1,b3,b5,bb7)'),
79
+ 'hdim7': self.interval_list('(1,b3,b5,b7)'),
80
+ 'minmaj7': self.interval_list('(1,b3,5,7)'),
81
+ 'maj6': self.interval_list('(1,3,5,6)'),
82
+ 'min6': self.interval_list('(1,b3,5,6)'),
83
+ '9': self.interval_list('(1,3,5,b7,9)'),
84
+ 'maj9': self.interval_list('(1,3,5,7,9)'),
85
+ 'min9': self.interval_list('(1,b3,5,b7,9)'),
86
+ 'sus2': self.interval_list('(1,2,5)'),
87
+ 'sus4': self.interval_list('(1,4,5)'),
88
+ '11': self.interval_list('(1,3,5,b7,9,11)'),
89
+ 'min11': self.interval_list('(1,b3,5,b7,9,11)'),
90
+ '13': self.interval_list('(1,3,5,b7,13)'),
91
+ 'maj13': self.interval_list('(1,3,5,7,13)'),
92
+ 'min13': self.interval_list('(1,b3,5,b7,13)')
93
+ }
94
+
95
+ def chords(self, labels):
96
+
97
+ """
98
+ Transform a list of chord labels into an array of internal numeric
99
+ representations.
100
+
101
+ Parameters
102
+ ----------
103
+ labels : list
104
+ List of chord labels (str).
105
+
106
+ Returns
107
+ -------
108
+ chords : numpy.array
109
+ Structured array with columns 'root', 'bass', and 'intervals',
110
+ containing a numeric representation of chords.
111
+
112
+ """
113
+ crds = np.zeros(len(labels), dtype=CHORD_DTYPE)
114
+ cache = {}
115
+ for i, lbl in enumerate(labels):
116
+ cv = cache.get(lbl, None)
117
+ if cv is None:
118
+ cv = self.chord(lbl)
119
+ cache[lbl] = cv
120
+ crds[i] = cv
121
+
122
+ return crds
123
+
124
+ def label_error_modify(self, label):
125
+ if label == 'Emin/4': label = 'E:min/4'
126
+ elif label == 'A7/3': label = 'A:7/3'
127
+ elif label == 'Bb7/3': label = 'Bb:7/3'
128
+ elif label == 'Bb7/5': label = 'Bb:7/5'
129
+ elif label.find(':') == -1:
130
+ if label.find('min') != -1:
131
+ label = label[:label.find('min')] + ':' + label[label.find('min'):]
132
+ return label
133
+
134
+ def chord(self, label):
135
+ """
136
+ Transform a chord label into the internal numeric representation of
137
+ (root, bass, intervals array).
138
+
139
+ Parameters
140
+ ----------
141
+ label : str
142
+ Chord label.
143
+
144
+ Returns
145
+ -------
146
+ chord : tuple
147
+ Numeric representation of the chord: (root, bass, intervals array).
148
+
149
+ """
150
+
151
+ try:
152
+ is_major = False
153
+
154
+ if label == 'N':
155
+ return NO_CHORD
156
+ if label == 'X':
157
+ return UNKNOWN_CHORD
158
+
159
+ label = self.label_error_modify(label)
160
+
161
+ c_idx = label.find(':')
162
+ s_idx = label.find('/')
163
+
164
+ if c_idx == -1:
165
+ quality_str = 'maj'
166
+ if s_idx == -1:
167
+ root_str = label
168
+ bass_str = ''
169
+ else:
170
+ root_str = label[:s_idx]
171
+ bass_str = label[s_idx + 1:]
172
+ else:
173
+ root_str = label[:c_idx]
174
+ if s_idx == -1:
175
+ quality_str = label[c_idx + 1:]
176
+ bass_str = ''
177
+ else:
178
+ quality_str = label[c_idx + 1:s_idx]
179
+ bass_str = label[s_idx + 1:]
180
+
181
+ root = self.pitch(root_str)
182
+ bass = self.interval(bass_str) if bass_str else 0
183
+ ivs = self.chord_intervals(quality_str)
184
+ ivs[bass] = 1
185
+
186
+ if 'min' in quality_str:
187
+ is_major = False
188
+ else:
189
+ is_major = True
190
+
191
+ except Exception as e:
192
+ print(e, label)
193
+
194
+ return root, bass, ivs, is_major
195
+
196
+ _l = [0, 1, 1, 0, 1, 1, 1]
197
+ _chroma_id = (np.arange(len(_l) * 2) + 1) + np.array(_l + _l).cumsum() - 1
198
+
199
+ def modify(self, base_pitch, modifier):
200
+ """
201
+ Modify a pitch class in integer representation by a given modifier string.
202
+
203
+ A modifier string can be any sequence of 'b' (one semitone down)
204
+ and '#' (one semitone up).
205
+
206
+ Parameters
207
+ ----------
208
+ base_pitch : int
209
+ Pitch class as integer.
210
+ modifier : str
211
+ String of modifiers ('b' or '#').
212
+
213
+ Returns
214
+ -------
215
+ modified_pitch : int
216
+ Modified root note.
217
+
218
+ """
219
+ for m in modifier:
220
+ if m == 'b':
221
+ base_pitch -= 1
222
+ elif m == '#':
223
+ base_pitch += 1
224
+ else:
225
+ raise ValueError('Unknown modifier: {}'.format(m))
226
+ return base_pitch
227
+
228
+ def pitch(self, pitch_str):
229
+ """
230
+ Convert a string representation of a pitch class (consisting of root
231
+ note and modifiers) to an integer representation.
232
+
233
+ Parameters
234
+ ----------
235
+ pitch_str : str
236
+ String representation of a pitch class.
237
+
238
+ Returns
239
+ -------
240
+ pitch : int
241
+ Integer representation of a pitch class.
242
+
243
+ """
244
+ return self.modify(self._chroma_id[(ord(pitch_str[0]) - ord('C')) % 7],
245
+ pitch_str[1:]) % 12
246
+
247
+ def interval(self, interval_str):
248
+ """
249
+ Convert a string representation of a musical interval into a pitch class
250
+ (e.g. a minor seventh 'b7' into 10, because it is 10 semitones above its
251
+ base note).
252
+
253
+ Parameters
254
+ ----------
255
+ interval_str : str
256
+ Musical interval.
257
+
258
+ Returns
259
+ -------
260
+ pitch_class : int
261
+ Number of semitones to base note of interval.
262
+
263
+ """
264
+ for i, c in enumerate(interval_str):
265
+ if c.isdigit():
266
+ return self.modify(self._chroma_id[int(interval_str[i:]) - 1],
267
+ interval_str[:i]) % 12
268
+
269
+ def interval_list(self, intervals_str, given_pitch_classes=None):
270
+ """
271
+ Convert a list of intervals given as string to a binary pitch class
272
+ representation. For example, 'b3, 5' would become
273
+ [0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0].
274
+
275
+ Parameters
276
+ ----------
277
+ intervals_str : str
278
+ List of intervals as comma-separated string (e.g. 'b3, 5').
279
+ given_pitch_classes : None or numpy array
280
+ If None, start with empty pitch class array, if numpy array of length
281
+ 12, this array will be modified.
282
+
283
+ Returns
284
+ -------
285
+ pitch_classes : numpy array
286
+ Binary pitch class representation of intervals.
287
+
288
+ """
289
+ if given_pitch_classes is None:
290
+ given_pitch_classes = np.zeros(12, dtype=int)
291
+ for int_def in intervals_str[1:-1].split(','):
292
+ int_def = int_def.strip()
293
+ if int_def[0] == '*':
294
+ given_pitch_classes[self.interval(int_def[1:])] = 0
295
+ else:
296
+ given_pitch_classes[self.interval(int_def)] = 1
297
+ return given_pitch_classes
298
+
299
+ # mapping of shorthand interval notations to the actual interval representation
300
+
301
+ def chord_intervals(self, quality_str):
302
+ """
303
+ Convert a chord quality string to a pitch class representation. For
304
+ example, 'maj' becomes [1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0].
305
+
306
+ Parameters
307
+ ----------
308
+ quality_str : str
309
+ String defining the chord quality.
310
+
311
+ Returns
312
+ -------
313
+ pitch_classes : numpy array
314
+ Binary pitch class representation of chord quality.
315
+
316
+ """
317
+ list_idx = quality_str.find('(')
318
+ if list_idx == -1:
319
+ return self._shorthands[quality_str].copy()
320
+ if list_idx != 0:
321
+ ivs = self._shorthands[quality_str[:list_idx]].copy()
322
+ else:
323
+ ivs = np.zeros(12, dtype=int)
324
+
325
+
326
+ return self.interval_list(quality_str[list_idx:], ivs)
327
+
328
+ def load_chords(self, filename):
329
+ """
330
+ Load chords from a text file.
331
+
332
+ The chord must follow the syntax defined in [1]_.
333
+
334
+ Parameters
335
+ ----------
336
+ filename : str
337
+ File containing chord segments.
338
+
339
+ Returns
340
+ -------
341
+ crds : numpy structured array
342
+ Structured array with columns "start", "end", and "chord",
343
+ containing the beginning, end, and chord definition of chord
344
+ segments.
345
+
346
+ References
347
+ ----------
348
+ .. [1] Christopher Harte, "Towards Automatic Extraction of Harmony
349
+ Information from Music Signals." Dissertation,
350
+ Department for Electronic Engineering, Queen Mary University of
351
+ London, 2010.
352
+
353
+ """
354
+ start, end, chord_labels = [], [], []
355
+ with open(filename, 'r') as f:
356
+ for line in f:
357
+ if line:
358
+
359
+ splits = line.split()
360
+ if len(splits) == 3:
361
+
362
+ s = splits[0]
363
+ e = splits[1]
364
+ l = splits[2]
365
+
366
+ start.append(float(s))
367
+ end.append(float(e))
368
+ chord_labels.append(l)
369
+
370
+ crds = np.zeros(len(start), dtype=CHORD_ANN_DTYPE)
371
+ crds['start'] = start
372
+ crds['end'] = end
373
+ crds['chord'] = self.chords(chord_labels)
374
+
375
+ return crds
376
+
377
+ def reduce_to_triads(self, chords, keep_bass=False):
378
+ """
379
+ Reduce chords to triads.
380
+
381
+ The function follows the reduction rules implemented in [1]_. If a chord
382
+ does not contain a third, major second, or fourth, it is reduced to
383
+ a power chord. If it contains neither a third nor a fifth, it is
384
+ reduced to a single note "chord".
385
+
386
+ Parameters
387
+ ----------
388
+ chords : numpy structured array
389
+ Chords to be reduced.
390
+ keep_bass : bool
391
+ Indicates whether to keep the bass note or set it to 0.
392
+
393
+ Returns
394
+ -------
395
+ reduced_chords : numpy structured array
396
+ Chords reduced to triads.
397
+
398
+ References
399
+ ----------
400
+ .. [1] Johan Pauwels and Geoffroy Peeters.
401
+ "Evaluating Automatically Estimated Chord Sequences."
402
+ In Proceedings of ICASSP 2013, Vancouver, Canada, 2013.
403
+
404
+ """
405
+ unison = chords['intervals'][:, 0].astype(bool)
406
+ maj_sec = chords['intervals'][:, 2].astype(bool)
407
+ min_third = chords['intervals'][:, 3].astype(bool)
408
+ maj_third = chords['intervals'][:, 4].astype(bool)
409
+ perf_fourth = chords['intervals'][:, 5].astype(bool)
410
+ dim_fifth = chords['intervals'][:, 6].astype(bool)
411
+ perf_fifth = chords['intervals'][:, 7].astype(bool)
412
+ aug_fifth = chords['intervals'][:, 8].astype(bool)
413
+ no_chord = (chords['intervals'] == NO_CHORD[-1]).all(axis=1)
414
+
415
+ reduced_chords = chords.copy()
416
+ ivs = reduced_chords['intervals']
417
+
418
+ ivs[~no_chord] = self.interval_list('(1)')
419
+ ivs[unison & perf_fifth] = self.interval_list('(1,5)')
420
+ ivs[~perf_fourth & maj_sec] = self._shorthands['sus2']
421
+ ivs[perf_fourth & ~maj_sec] = self._shorthands['sus4']
422
+
423
+ ivs[min_third] = self._shorthands['min']
424
+ ivs[min_third & aug_fifth & ~perf_fifth] = self.interval_list('(1,b3,#5)')
425
+ ivs[min_third & dim_fifth & ~perf_fifth] = self._shorthands['dim']
426
+
427
+ ivs[maj_third] = self._shorthands['maj']
428
+ ivs[maj_third & dim_fifth & ~perf_fifth] = self.interval_list('(1,3,b5)')
429
+ ivs[maj_third & aug_fifth & ~perf_fifth] = self._shorthands['aug']
430
+
431
+ if not keep_bass:
432
+ reduced_chords['bass'] = 0
433
+ else:
434
+ # remove bass notes if they are not part of the intervals anymore
435
+ reduced_chords['bass'] *= ivs[range(len(reduced_chords)),
436
+ reduced_chords['bass']]
437
+ # keep -1 in bass for no chords
438
+ reduced_chords['bass'][no_chord] = -1
439
+
440
+ return reduced_chords
441
+
442
+ def convert_to_id(self, root, is_major):
443
+ if root == -1:
444
+ return 24
445
+ else:
446
+ if is_major:
447
+ return root * 2
448
+ else:
449
+ return root * 2 + 1
450
+
451
+ def get_converted_chord(self, filename):
452
+ loaded_chord = self.load_chords(filename)
453
+ triads = self.reduce_to_triads(loaded_chord['chord'])
454
+
455
+ df = self.assign_chord_id(triads)
456
+ df['start'] = loaded_chord['start']
457
+ df['end'] = loaded_chord['end']
458
+
459
+ return df
460
+
461
+ def assign_chord_id(self, entry):
462
+ # maj, min chord only
463
+ # if you want to add other chords, change this part and get_converted_chord (reduce_to_triads)
464
+ df = pd.DataFrame(data=entry[['root', 'is_major']])
465
+ df['chord_id'] = df.apply(lambda row: self.convert_to_id(row['root'], row['is_major']), axis=1)
466
+ return df
467
+
468
+ def convert_to_id_voca(self, root, quality):
469
+ if root == -1:
470
+ return 169
471
+ else:
472
+ if quality == 'min':
473
+ return root * 14
474
+ elif quality == 'maj':
475
+ return root * 14 + 1
476
+ elif quality == 'dim':
477
+ return root * 14 + 2
478
+ elif quality == 'aug':
479
+ return root * 14 + 3
480
+ elif quality == 'min6':
481
+ return root * 14 + 4
482
+ elif quality == 'maj6':
483
+ return root * 14 + 5
484
+ elif quality == 'min7':
485
+ return root * 14 + 6
486
+ elif quality == 'minmaj7':
487
+ return root * 14 + 7
488
+ elif quality == 'maj7':
489
+ return root * 14 + 8
490
+ elif quality == '7':
491
+ return root * 14 + 9
492
+ elif quality == 'dim7':
493
+ return root * 14 + 10
494
+ elif quality == 'hdim7':
495
+ return root * 14 + 11
496
+ elif quality == 'sus2':
497
+ return root * 14 + 12
498
+ elif quality == 'sus4':
499
+ return root * 14 + 13
500
+ else:
501
+ return 168
502
+
503
+ def get_converted_chord_voca(self, filename):
504
+ loaded_chord = self.load_chords(filename)
505
+ triads = self.reduce_to_triads(loaded_chord['chord'])
506
+ df = pd.DataFrame(data=triads[['root', 'is_major']])
507
+
508
+ (ref_intervals, ref_labels) = mir_eval.io.load_labeled_intervals(filename)
509
+ ref_labels = self.lab_file_error_modify(ref_labels)
510
+ idxs = list()
511
+ for i in ref_labels:
512
+ chord_root, quality, scale_degrees, bass = mir_eval.chord.split(i, reduce_extended_chords=True)
513
+ root, bass, ivs, is_major = self.chord(i)
514
+ idxs.append(self.convert_to_id_voca(root=root, quality=quality))
515
+ df['chord_id'] = idxs
516
+
517
+ df['start'] = loaded_chord['start']
518
+ df['end'] = loaded_chord['end']
519
+
520
+ return df
521
+
522
+ def lab_file_error_modify(self, ref_labels):
523
+ for i in range(len(ref_labels)):
524
+ if ref_labels[i][-2:] == ':4':
525
+ ref_labels[i] = ref_labels[i].replace(':4', ':sus4')
526
+ elif ref_labels[i][-2:] == ':6':
527
+ ref_labels[i] = ref_labels[i].replace(':6', ':maj6')
528
+ elif ref_labels[i][-4:] == ':6/2':
529
+ ref_labels[i] = ref_labels[i].replace(':6/2', ':maj6/2')
530
+ elif ref_labels[i] == 'Emin/4':
531
+ ref_labels[i] = 'E:min/4'
532
+ elif ref_labels[i] == 'A7/3':
533
+ ref_labels[i] = 'A:7/3'
534
+ elif ref_labels[i] == 'Bb7/3':
535
+ ref_labels[i] = 'Bb:7/3'
536
+ elif ref_labels[i] == 'Bb7/5':
537
+ ref_labels[i] = 'Bb:7/5'
538
+ elif ref_labels[i].find(':') == -1:
539
+ if ref_labels[i].find('min') != -1:
540
+ ref_labels[i] = ref_labels[i][:ref_labels[i].find('min')] + ':' + ref_labels[i][ref_labels[i].find('min'):]
541
+ return ref_labels
542
+
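A quick usage sketch for the parser above; the expected outputs follow directly from `pitch`, `convert_to_id`, and `idx_to_chord` as defined in this file:

```python
from utils.chords import Chords, idx_to_chord

c = Chords()

# Harte-style label -> (root, bass, 12-dim interval vector, is_major)
root, bass, intervals, is_major = c.chord('C#:min7')
print(root, bass, is_major)        # 1 0 False
print(intervals)                   # binary pitch-class vector for (1, b3, 5, b7)

# Map to the 25-class maj/min vocabulary used elsewhere in the repo (24 = no chord)
chord_id = c.convert_to_id(root, is_major)
print(chord_id, idx_to_chord(chord_id))   # 3 C#m
```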
utils/constants.py ADDED
@@ -0,0 +1,60 @@
1
+
2
+
3
+
4
+
5
+ ### DEPRECATED - use hydra conf instead ######
6
+
7
+ import torch
8
+ import os
9
+
10
+ # --------------------------------------- #
11
+ VERSION = "1.24"
12
+
13
+ # --------------------------------------- #
14
+ ENCODER = "MERT"
15
+
16
+ # - - -
17
+ # MERT
18
+ # M2L
19
+ # LIBROSA
20
+ # - - -
21
+ # Encodec
22
+ # DAC
23
+
24
+ # --------------------------------------- #
25
+
26
+ SEGMENT = "all"
27
+ # all
28
+ # f10s - first 10s
29
+ # f30s - first 30s
30
+ # 10s
31
+ # 30s
32
+
33
+ AGGREGATION_METHOD = "mean"
34
+ # mean
35
+ # median
36
+ # 80th_percentile
37
+ # max
38
+
39
+ # --------------------------------------- #
40
+ CLASSIFIER = "linear-mt"
41
+ # transformer
42
+ # linear
43
+ # linear-small
44
+ # linear-multitask
45
+ # linear-small-multitask
46
+ # linear-mt (mert-like classifier)
47
+ #
48
+ # --------------------------------------- #
49
+ CHECKPOINT = "tb_logs/train_audio_classification/version_110/checkpoints/21-0.1202.ckpt"
50
+ # --------------------------------------- #
51
+ BATCH_SIZE = 8
52
+ N_EPOCHS = 50
53
+
54
+ # --------------------------------------- #
55
+ GENRE_CLASS_SIZE = 87
56
+ MOOD_CLASS_SIZE = 56
57
+ INSTR_CLASS_SIZE = 40
58
+ DAC_LATENTS_SIZE = 72
59
+ DAC_RVQ_SIZE = 9
60
+ # --------------------------------------- #
utils/custom_early_stopping.py ADDED
@@ -0,0 +1,93 @@
1
+ # custom_early_stopping.py
2
+
3
+ import pytorch_lightning as pl
4
+ from pytorch_lightning.callbacks.early_stopping import EarlyStopping
5
+
6
+
7
+ class MultiMetricEarlyStopping(EarlyStopping):
8
+ def __init__(self, monitor_mood, monitor_va, patience, min_delta, mode="min"):
9
+ super().__init__(monitor=None, patience=patience, min_delta=min_delta, mode=mode)
10
+ self.monitor_mood = monitor_mood
11
+ self.monitor_va = monitor_va
12
+ self.patience = patience
13
+ self.min_delta = min_delta
14
+ self.mode = mode
15
+
16
+ # Initialize tracking variables
17
+ self.wait_mood = 0
18
+ self.wait_va = 0
19
+ self.best_mood = float('inf') if mode == "min" else -float('inf')
20
+ self.best_va = float('inf') if mode == "min" else -float('inf')
21
+
22
+ def _check_stop(self, current, best, wait):
23
+ if self.mode == "min" and current < best - self.min_delta:
24
+ return current, 0
25
+ elif self.mode == "max" and current > best + self.min_delta:
26
+ return current, 0
27
+ else:
28
+ return best, wait + 1
29
+
30
+ def on_validation_epoch_end(self, trainer, pl_module):
31
+ logs = trainer.callback_metrics
32
+
33
+ if self.monitor_mood not in logs or self.monitor_va not in logs:
34
+ raise RuntimeError(f"Metrics {self.monitor_mood} or {self.monitor_va} not available.")
35
+
36
+ # Get current values for the monitored metrics
37
+ current_mood = logs[self.monitor_mood].item()
38
+ current_va = logs[self.monitor_va].item()
39
+
40
+ # Check stopping conditions for both metrics
41
+ self.best_mood, self.wait_mood = self._check_stop(current_mood, self.best_mood, self.wait_mood)
42
+ self.best_va, self.wait_va = self._check_stop(current_va, self.best_va, self.wait_va)
43
+
44
+ # Stop if patience exceeded for both metrics
45
+ if self.wait_mood > self.patience and self.wait_va > self.patience:
46
+ self.stopped_epoch = trainer.current_epoch
47
+ trainer.should_stop = True
48
+
49
+ # # custom_early_stopping.py
50
+
51
+ # import pytorch_lightning as pl
52
+ # from pytorch_lightning.callbacks.early_stopping import EarlyStopping
53
+
54
+ # class MultiMetricEarlyStopping(EarlyStopping):
55
+ # def __init__(self, monitor_mood: str, monitor_va: str, patience: int = 10, min_delta: float = 0.0, mode: str = "min"):
56
+ # super().__init__(monitor=None, patience=patience, min_delta=min_delta, mode=mode)
57
+ # self.monitor_mood = monitor_mood
58
+ # self.monitor_va = monitor_va
59
+ # self.wait_mood = 0
60
+ # self.wait_va = 0
61
+ # self.best_mood_score = None
62
+ # self.best_va_score = None
63
+ # self.patience = patience
64
+ # self.stopped_epoch = 0
65
+
66
+ # def on_validation_end(self, trainer, pl_module):
67
+ # current_mood = trainer.callback_metrics.get(self.monitor_mood)
68
+ # current_va = trainer.callback_metrics.get(self.monitor_va)
69
+
70
+ # # Check if current_mood improved
71
+ # if self.best_mood_score is None or self._compare(current_mood, self.best_mood_score):
72
+ # self.best_mood_score = current_mood
73
+ # self.wait_mood = 0
74
+ # else:
75
+ # self.wait_mood += 1
76
+
77
+ # # Check if current_va improved
78
+ # if self.best_va_score is None or self._compare(current_va, self.best_va_score):
79
+ # self.best_va_score = current_va
80
+ # self.wait_va = 0
81
+ # else:
82
+ # self.wait_va += 1
83
+
84
+ # # If both metrics are stagnant for patience epochs, stop training
85
+ # if self.wait_mood >= self.patience and self.wait_va >= self.patience:
86
+ # self.stopped_epoch = trainer.current_epoch
87
+ # trainer.should_stop = True
88
+
89
+ # def _compare(self, current, best):
90
+ # if self.mode == "min":
91
+ # return current < best - self.min_delta
92
+ # else:
93
+ # return current > best + self.min_delta
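A sketch of how the active callback above could be attached to a `pytorch_lightning.Trainer`; the metric names are placeholders and must match whatever the LightningModule logs during validation:

```python
import pytorch_lightning as pl
from utils.custom_early_stopping import MultiMetricEarlyStopping

# "val_loss_mood" / "val_loss_va" are assumed names for metrics logged via self.log(...)
early_stop = MultiMetricEarlyStopping(
    monitor_mood="val_loss_mood",
    monitor_va="val_loss_va",
    patience=10,
    min_delta=0.0,
    mode="min",
)

trainer = pl.Trainer(max_epochs=50, callbacks=[early_stop])
# trainer.fit(model, datamodule=data_module)   # stops only when BOTH metrics stall
```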
utils/hparams.py ADDED
@@ -0,0 +1,37 @@
1
+ import yaml
2
+
3
+
4
+ # TODO: the add() function should be changed
5
+ class HParams(object):
6
+ # Hyperparameter class using yaml
7
+ def __init__(self, **kwargs):
8
+ self.__dict__ = kwargs
9
+
10
+ def add(self, **kwargs):
11
+ # change is needed - if the key already exists, do not update.
12
+ self.__dict__.update(kwargs)
13
+
14
+ def update(self, **kwargs):
15
+ self.__dict__.update(kwargs)
16
+ return self
17
+
18
+ def save(self, path):
19
+ with open(path, 'w') as f:
20
+ yaml.dump(self.__dict__, f)
21
+ return self
22
+
23
+ def __repr__(self):
24
+ return '\nHyperparameters:\n' + '\n'.join([' {}={}'.format(k, v) for k, v in self.__dict__.items()])
25
+
26
+ @classmethod
27
+ def load(cls, path):
28
+ with open(path, 'r') as f:
29
+ return cls(**yaml.load(f, Loader=yaml.FullLoader))
30
+
31
+
32
+ if __name__ == '__main__':
33
+ hparams = HParams.load('hparams.yaml')
34
+ print(hparams)
35
+ d = {"MemoryNetwork": 0, "c": 1}
36
+ hparams.add(**d)
37
+ print(hparams)
utils/logger.py ADDED
@@ -0,0 +1,72 @@
1
+ import logging
2
+ import os
3
+ import sys
4
+ import time
5
+
6
+
7
+ project_name = os.getcwd().split('/')[-1]
8
+ _logger = logging.getLogger(project_name)
9
+ _logger.addHandler(logging.StreamHandler())
10
+
11
+ def _log_prefix():
12
+
13
+ # Returns (filename, line number) for the stack frame.
14
+ def _get_file_line():
15
+
16
+ # pylint: disable=protected-access
17
+ # noinspection PyProtectedMember
18
+ f = sys._getframe()
19
+ # pylint: enable=protected-access
20
+ our_file = f.f_code.co_filename
21
+ f = f.f_back
22
+ while f:
23
+ code = f.f_code
24
+ if code.co_filename != our_file:
25
+ return code.co_filename, f.f_lineno
26
+ f = f.f_back
27
+ return '<unknown>', 0
28
+
29
+ # current time
30
+ now = time.time()
31
+ now_tuple = time.localtime(now)
32
+ now_millisecond = int(1e3 * (now % 1.0))
33
+
34
+ # current filename and line
35
+ filename, line = _get_file_line()
36
+ basename = os.path.basename(filename)
37
+
38
+ s = '%02d-%02d %02d:%02d:%02d.%03d %s:%d] ' % (
39
+ now_tuple[1], # month
40
+ now_tuple[2], # day
41
+ now_tuple[3], # hour
42
+ now_tuple[4], # min
43
+ now_tuple[5], # sec
44
+ now_millisecond,
45
+ basename,
46
+ line)
47
+
48
+ return s
49
+
50
+
51
+ def logging_verbosity(verbosity=0):
52
+ _logger.setLevel(verbosity)
53
+
54
+
55
+ def debug(msg, *args, **kwargs):
56
+ _logger.debug('D ' + project_name + ' ' + _log_prefix() + msg, *args, **kwargs)
57
+
58
+
59
+ def info(msg, *args, **kwargs):
60
+ _logger.info('I ' + project_name + ' ' + _log_prefix() + msg, *args, **kwargs)
61
+
62
+
63
+ def warn(msg, *args, **kwargs):
64
+ _logger.warning('W ' + project_name + ' ' + _log_prefix() + msg, *args, **kwargs)
65
+
66
+
67
+ def error(msg, *args, **kwargs):
68
+ _logger.error('E ' + project_name + ' ' + _log_prefix() + msg, *args, **kwargs)
69
+
70
+
71
+ def fatal(msg, *args, **kwargs):
72
+ _logger.fatal('F ' + project_name + ' ' + _log_prefix() + msg, *args, **kwargs)
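Typical use of the module-level helpers above; the logger prefixes every record with the project directory name, a timestamp, and the calling file and line:

```python
import logging
from utils import logger

logger.logging_verbosity(logging.INFO)       # accepts any standard logging level
logger.info('starting training, batch_size=%d', 8)
logger.warn('validation loss did not improve this epoch')
```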
utils/mert.py ADDED
@@ -0,0 +1,32 @@
1
+ import torch
2
+ import numpy as np
3
+ from transformers import Wav2Vec2FeatureExtractor, AutoModel
4
+
5
+ class FeatureExtractorMERT:
6
+ def __init__(self, model_name="m-a-p/MERT-v1-95M", device = "None", sr=24000):
7
+ self.model_name = model_name
8
+ self.sr = sr
9
+ if device == "None":
10
+ use_cuda = torch.cuda.is_available()
11
+ self.device = torch.device("cuda" if use_cuda else "cpu")
12
+ else:
13
+ self.device = device
14
+
15
+ self.model = AutoModel.from_pretrained(self.model_name, trust_remote_code=True).to(self.device)
16
+ self.processor = Wav2Vec2FeatureExtractor.from_pretrained(self.model_name, trust_remote_code=True)
17
+
18
+ def extract_features_from_segment(self, segment, sample_rate, save_path):
19
+ input_audio = segment.float()
20
+ model_inputs = self.processor(input_audio, sampling_rate=sample_rate, return_tensors="pt")
21
+ model_inputs = model_inputs.to(self.device)
22
+
23
+ with torch.no_grad():
24
+ model_outputs = self.model(**model_inputs, output_hidden_states=True)
25
+
26
+ # Stack and process hidden states
27
+ all_layer_hidden_states = torch.stack(model_outputs.hidden_states).squeeze()[1:, :, :].unsqueeze(0)
28
+ all_layer_hidden_states = all_layer_hidden_states.mean(dim=2)
29
+ features = all_layer_hidden_states.cpu().detach().numpy()
30
+
31
+ # Save features
32
+ np.save(save_path, features)
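A short sketch of the extractor above; the model weights are fetched from the Hugging Face Hub on first use, and the input here is a silent placeholder segment at MERT's expected 24 kHz rate:

```python
import numpy as np
import torch
from utils.mert import FeatureExtractorMERT

extractor = FeatureExtractorMERT(device="cuda" if torch.cuda.is_available() else "cpu")

segment = torch.zeros(5 * 24000)             # 5 s placeholder waveform at 24 kHz
extractor.extract_features_from_segment(segment, sample_rate=24000,
                                        save_path="segment_mert.npy")

features = np.load("segment_mert.npy")
print(features.shape)                        # (1, num_layers, hidden_size), mean-pooled over time
```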
utils/mir_eval_modules.py ADDED
@@ -0,0 +1,486 @@
1
+ import numpy as np
2
+ import librosa
3
+ import mir_eval
4
+ import torch
5
+ import os
6
+
7
+ idx2chord = ['C', 'C:min', 'C#', 'C#:min', 'D', 'D:min', 'D#', 'D#:min', 'E', 'E:min', 'F', 'F:min', 'F#',
8
+ 'F#:min', 'G', 'G:min', 'G#', 'G#:min', 'A', 'A:min', 'A#', 'A#:min', 'B', 'B:min', 'N']
9
+
10
+ root_list = ['C', 'C#', 'D', 'D#', 'E', 'F', 'F#', 'G', 'G#', 'A', 'A#', 'B']
11
+ quality_list = ['min', 'maj', 'dim', 'aug', 'min6', 'maj6', 'min7', 'minmaj7', 'maj7', '7', 'dim7', 'hdim7', 'sus2', 'sus4']
12
+
13
+ def idx2voca_chord():
14
+ idx2voca_chord = {}
15
+ idx2voca_chord[169] = 'N'
16
+ idx2voca_chord[168] = 'X'
17
+ for i in range(168):
18
+ root = i // 14
19
+ root = root_list[root]
20
+ quality = i % 14
21
+ quality = quality_list[quality]
22
+ if i % 14 != 1:
23
+ chord = root + ':' + quality
24
+ else:
25
+ chord = root
26
+ idx2voca_chord[i] = chord
27
+ return idx2voca_chord
28
+
29
+ def audio_file_to_features(audio_file, config):
30
+ original_wav, sr = librosa.load(audio_file, sr=config.mp3['song_hz'], mono=True)
31
+ currunt_sec_hz = 0
32
+ while len(original_wav) > currunt_sec_hz + config.mp3['song_hz'] * config.mp3['inst_len']:
33
+ start_idx = int(currunt_sec_hz)
34
+ end_idx = int(currunt_sec_hz + config.mp3['song_hz'] * config.mp3['inst_len'])
35
+ tmp = librosa.cqt(original_wav[start_idx:end_idx], sr=sr, n_bins=config.feature['n_bins'], bins_per_octave=config.feature['bins_per_octave'], hop_length=config.feature['hop_length'])
36
+ if start_idx == 0:
37
+ feature = tmp
38
+ else:
39
+ feature = np.concatenate((feature, tmp), axis=1)
40
+ currunt_sec_hz = end_idx
41
+ tmp = librosa.cqt(original_wav[currunt_sec_hz:], sr=sr, n_bins=config.feature['n_bins'], bins_per_octave=config.feature['bins_per_octave'], hop_length=config.feature['hop_length'])
42
+ feature = np.concatenate((feature, tmp), axis=1)
43
+ feature = np.log(np.abs(feature) + 1e-6)
44
+ feature_per_second = config.mp3['inst_len'] / config.model['timestep']
45
+ song_length_second = len(original_wav)/config.mp3['song_hz']
46
+ return feature, feature_per_second, song_length_second
47
+
48
+ # Collect audio files in wav and mp3 formats
49
+ def get_audio_paths(audio_dir):
50
+ return [os.path.join(root, fname) for (root, dir_names, file_names) in os.walk(audio_dir, followlinks=True)
51
+ for fname in file_names if (fname.lower().endswith('.wav') or fname.lower().endswith('.mp3'))]
52
+
53
+ def get_lab_paths(lab_dir):
54
+ return [os.path.join(root, fname) for (root, dir_names, file_names) in os.walk(lab_dir, followlinks=True)
55
+ for fname in file_names if (fname.lower().endswith('.lab'))]
56
+
57
+
58
+ class metrics():
59
+ def __init__(self):
60
+ super(metrics, self).__init__()
61
+ self.score_metrics = ['root', 'thirds', 'triads', 'sevenths', 'tetrads', 'majmin', 'mirex']
62
+ self.score_list_dict = dict()
63
+ for i in self.score_metrics:
64
+ self.score_list_dict[i] = list()
65
+ self.average_score = dict()
66
+
67
+ def score(self, metric, gt_path, est_path):
68
+ if metric == 'root':
69
+ score = self.root_score(gt_path,est_path)
70
+ elif metric == 'thirds':
71
+ score = self.thirds_score(gt_path,est_path)
72
+ elif metric == 'triads':
73
+ score = self.triads_score(gt_path,est_path)
74
+ elif metric == 'sevenths':
75
+ score = self.sevenths_score(gt_path,est_path)
76
+ elif metric == 'tetrads':
77
+ score = self.tetrads_score(gt_path,est_path)
78
+ elif metric == 'majmin':
79
+ score = self.majmin_score(gt_path,est_path)
80
+ elif metric == 'mirex':
81
+ score = self.mirex_score(gt_path,est_path)
82
+ else:
83
+ raise NotImplementedError
84
+ return score
85
+
86
+ def root_score(self, gt_path, est_path):
87
+ (ref_intervals, ref_labels) = mir_eval.io.load_labeled_intervals(gt_path)
88
+ ref_labels = lab_file_error_modify(ref_labels)
89
+ (est_intervals, est_labels) = mir_eval.io.load_labeled_intervals(est_path)
90
+ est_intervals, est_labels = mir_eval.util.adjust_intervals(est_intervals, est_labels, ref_intervals.min(),
91
+ ref_intervals.max(), mir_eval.chord.NO_CHORD,
92
+ mir_eval.chord.NO_CHORD)
93
+ (intervals, ref_labels, est_labels) = mir_eval.util.merge_labeled_intervals(ref_intervals, ref_labels,
94
+ est_intervals, est_labels)
95
+ durations = mir_eval.util.intervals_to_durations(intervals)
96
+ comparisons = mir_eval.chord.root(ref_labels, est_labels)
97
+ score = mir_eval.chord.weighted_accuracy(comparisons, durations)
98
+ return score
99
+
100
+ def thirds_score(self, gt_path, est_path):
101
+ (ref_intervals, ref_labels) = mir_eval.io.load_labeled_intervals(gt_path)
102
+ ref_labels = lab_file_error_modify(ref_labels)
103
+ (est_intervals, est_labels) = mir_eval.io.load_labeled_intervals(est_path)
104
+ est_intervals, est_labels = mir_eval.util.adjust_intervals(est_intervals, est_labels, ref_intervals.min(),
105
+ ref_intervals.max(), mir_eval.chord.NO_CHORD,
106
+ mir_eval.chord.NO_CHORD)
107
+ (intervals, ref_labels, est_labels) = mir_eval.util.merge_labeled_intervals(ref_intervals, ref_labels,
108
+ est_intervals, est_labels)
109
+ durations = mir_eval.util.intervals_to_durations(intervals)
110
+ comparisons = mir_eval.chord.thirds(ref_labels, est_labels)
111
+ score = mir_eval.chord.weighted_accuracy(comparisons, durations)
112
+ return score
113
+
114
+ def triads_score(self, gt_path, est_path):
115
+ (ref_intervals, ref_labels) = mir_eval.io.load_labeled_intervals(gt_path)
116
+ ref_labels = lab_file_error_modify(ref_labels)
117
+ (est_intervals, est_labels) = mir_eval.io.load_labeled_intervals(est_path)
118
+ est_intervals, est_labels = mir_eval.util.adjust_intervals(est_intervals, est_labels, ref_intervals.min(),
119
+ ref_intervals.max(), mir_eval.chord.NO_CHORD,
120
+ mir_eval.chord.NO_CHORD)
121
+ (intervals, ref_labels, est_labels) = mir_eval.util.merge_labeled_intervals(ref_intervals, ref_labels,
122
+ est_intervals, est_labels)
123
+ durations = mir_eval.util.intervals_to_durations(intervals)
124
+ comparisons = mir_eval.chord.triads(ref_labels, est_labels)
125
+ score = mir_eval.chord.weighted_accuracy(comparisons, durations)
126
+ return score
127
+
128
+ def sevenths_score(self, gt_path, est_path):
129
+ (ref_intervals, ref_labels) = mir_eval.io.load_labeled_intervals(gt_path)
130
+ ref_labels = lab_file_error_modify(ref_labels)
131
+ (est_intervals, est_labels) = mir_eval.io.load_labeled_intervals(est_path)
132
+ est_intervals, est_labels = mir_eval.util.adjust_intervals(est_intervals, est_labels, ref_intervals.min(),
133
+ ref_intervals.max(), mir_eval.chord.NO_CHORD,
134
+ mir_eval.chord.NO_CHORD)
135
+ (intervals, ref_labels, est_labels) = mir_eval.util.merge_labeled_intervals(ref_intervals, ref_labels,
136
+ est_intervals, est_labels)
137
+ durations = mir_eval.util.intervals_to_durations(intervals)
138
+ comparisons = mir_eval.chord.sevenths(ref_labels, est_labels)
139
+ score = mir_eval.chord.weighted_accuracy(comparisons, durations)
140
+ return score
141
+
142
+ def tetrads_score(self, gt_path, est_path):
143
+ (ref_intervals, ref_labels) = mir_eval.io.load_labeled_intervals(gt_path)
144
+ ref_labels = lab_file_error_modify(ref_labels)
145
+ (est_intervals, est_labels) = mir_eval.io.load_labeled_intervals(est_path)
146
+ est_intervals, est_labels = mir_eval.util.adjust_intervals(est_intervals, est_labels, ref_intervals.min(),
147
+ ref_intervals.max(), mir_eval.chord.NO_CHORD,
148
+ mir_eval.chord.NO_CHORD)
149
+ (intervals, ref_labels, est_labels) = mir_eval.util.merge_labeled_intervals(ref_intervals, ref_labels,
150
+ est_intervals, est_labels)
151
+ durations = mir_eval.util.intervals_to_durations(intervals)
152
+ comparisons = mir_eval.chord.tetrads(ref_labels, est_labels)
153
+ score = mir_eval.chord.weighted_accuracy(comparisons, durations)
154
+ return score
155
+
156
+ def majmin_score(self, gt_path, est_path):
157
+ (ref_intervals, ref_labels) = mir_eval.io.load_labeled_intervals(gt_path)
158
+ ref_labels = lab_file_error_modify(ref_labels)
159
+ (est_intervals, est_labels) = mir_eval.io.load_labeled_intervals(est_path)
160
+ est_intervals, est_labels = mir_eval.util.adjust_intervals(est_intervals, est_labels, ref_intervals.min(),
161
+ ref_intervals.max(), mir_eval.chord.NO_CHORD,
162
+ mir_eval.chord.NO_CHORD)
163
+ (intervals, ref_labels, est_labels) = mir_eval.util.merge_labeled_intervals(ref_intervals, ref_labels,
164
+ est_intervals, est_labels)
165
+ durations = mir_eval.util.intervals_to_durations(intervals)
166
+ comparisons = mir_eval.chord.majmin(ref_labels, est_labels)
167
+ score = mir_eval.chord.weighted_accuracy(comparisons, durations)
168
+ return score
169
+
170
+ def mirex_score(self, gt_path, est_path):
171
+ (ref_intervals, ref_labels) = mir_eval.io.load_labeled_intervals(gt_path)
172
+ ref_labels = lab_file_error_modify(ref_labels)
173
+ (est_intervals, est_labels) = mir_eval.io.load_labeled_intervals(est_path)
174
+ est_intervals, est_labels = mir_eval.util.adjust_intervals(est_intervals, est_labels, ref_intervals.min(),
175
+ ref_intervals.max(), mir_eval.chord.NO_CHORD,
176
+ mir_eval.chord.NO_CHORD)
177
+ (intervals, ref_labels, est_labels) = mir_eval.util.merge_labeled_intervals(ref_intervals, ref_labels,
178
+ est_intervals, est_labels)
179
+ durations = mir_eval.util.intervals_to_durations(intervals)
180
+ comparisons = mir_eval.chord.mirex(ref_labels, est_labels)
181
+ score = mir_eval.chord.weighted_accuracy(comparisons, durations)
182
+ return score
183
+
184
+ def lab_file_error_modify(ref_labels):
185
+ for i in range(len(ref_labels)):
186
+ if ref_labels[i][-2:] == ':4':
187
+ ref_labels[i] = ref_labels[i].replace(':4', ':sus4')
188
+ elif ref_labels[i][-2:] == ':6':
189
+ ref_labels[i] = ref_labels[i].replace(':6', ':maj6')
190
+ elif ref_labels[i][-4:] == ':6/2':
191
+ ref_labels[i] = ref_labels[i].replace(':6/2', ':maj6/2')
192
+ elif ref_labels[i] == 'Emin/4':
193
+ ref_labels[i] = 'E:min/4'
194
+ elif ref_labels[i] == 'A7/3':
195
+ ref_labels[i] = 'A:7/3'
196
+ elif ref_labels[i] == 'Bb7/3':
197
+ ref_labels[i] = 'Bb:7/3'
198
+ elif ref_labels[i] == 'Bb7/5':
199
+ ref_labels[i] = 'Bb:7/5'
200
+ elif ref_labels[i].find(':') == -1:
201
+ if ref_labels[i].find('min') != -1:
202
+ ref_labels[i] = ref_labels[i][:ref_labels[i].find('min')] + ':' + ref_labels[i][ref_labels[i].find('min'):]
203
+ return ref_labels
204
+
205
+ def root_majmin_score_calculation(valid_dataset, config, mean, std, device, model, model_type, verbose=False):
206
+ valid_song_names = valid_dataset.song_names
207
+ paths = valid_dataset.preprocessor.get_all_files()
208
+
209
+ metrics_ = metrics()
210
+ song_length_list = list()
211
+ for path in paths:
212
+ song_name, lab_file_path, mp3_file_path, _ = path
213
+ if not song_name in valid_song_names:
214
+ continue
215
+ try:
216
+ n_timestep = config.model['timestep']
217
+ feature, feature_per_second, song_length_second = audio_file_to_features(mp3_file_path, config)
218
+ feature = feature.T
219
+ feature = (feature - mean) / std
220
+ time_unit = feature_per_second
221
+
222
+ num_pad = n_timestep - (feature.shape[0] % n_timestep)
223
+ feature = np.pad(feature, ((0, num_pad), (0, 0)), mode="constant", constant_values=0)
224
+ num_instance = feature.shape[0] // n_timestep
225
+
226
+ start_time = 0.0
227
+ lines = []
228
+ with torch.no_grad():
229
+ model.eval()
230
+ feature = torch.tensor(feature, dtype=torch.float32).unsqueeze(0).to(device)
231
+ for t in range(num_instance):
232
+ if model_type == 'btc':
233
+ encoder_output, _ = model.self_attn_layers(feature[:, n_timestep * t:n_timestep * (t + 1), :])
234
+ prediction, _ = model.output_layer(encoder_output)
235
+ prediction = prediction.squeeze()
236
+ elif model_type == 'cnn' or model_type =='crnn':
237
+ prediction, _, _, _ = model(feature[:, n_timestep * t:n_timestep * (t + 1), :], torch.randint(config.model['num_chords'], (n_timestep,)).to(device))
238
+ for i in range(n_timestep):
239
+ if t == 0 and i == 0:
240
+ prev_chord = prediction[i].item()
241
+ continue
242
+ if prediction[i].item() != prev_chord:
243
+ lines.append(
244
+ '%.6f %.6f %s\n' % (
245
+ start_time, time_unit * (n_timestep * t + i), idx2chord[prev_chord]))
246
+ start_time = time_unit * (n_timestep * t + i)
247
+ prev_chord = prediction[i].item()
248
+ if t == num_instance - 1 and i + num_pad == n_timestep:
249
+ if start_time != time_unit * (n_timestep * t + i):
250
+ lines.append(
251
+ '%.6f %.6f %s\n' % (
252
+ start_time, time_unit * (n_timestep * t + i), idx2chord[prev_chord]))
253
+ break
254
+ pid = os.getpid()
255
+ tmp_path = 'tmp_' + str(pid) + '.lab'
256
+ with open(tmp_path, 'w') as f:
257
+ for line in lines:
258
+ f.write(line)
259
+
260
+ root_majmin = ['root', 'majmin']
261
+ for m in root_majmin:
262
+ metrics_.score_list_dict[m].append(metrics_.score(metric=m, gt_path=lab_file_path, est_path=tmp_path))
263
+ song_length_list.append(song_length_second)
264
+ if verbose:
265
+ for m in root_majmin:
266
+ print('song name %s, %s score : %.4f' % (song_name, m, metrics_.score_list_dict[m][-1]))
267
+ except:
268
+ print('song name %s\' lab file error' % song_name)
269
+
270
+ tmp = song_length_list / np.sum(song_length_list)
271
+ for m in root_majmin:
272
+ metrics_.average_score[m] = np.sum(np.multiply(metrics_.score_list_dict[m], tmp))
273
+
274
+ return metrics_.score_list_dict, song_length_list, metrics_.average_score
275
+
276
+ def root_majmin_score_calculation_crf(valid_dataset, config, mean, std, device, pre_model, model, model_type, verbose=False):
277
+ valid_song_names = valid_dataset.song_names
278
+ paths = valid_dataset.preprocessor.get_all_files()
279
+
280
+ metrics_ = metrics()
281
+ song_length_list = list()
282
+ for path in paths:
283
+ song_name, lab_file_path, mp3_file_path, _ = path
284
+ if not song_name in valid_song_names:
285
+ continue
286
+ try:
287
+ n_timestep = config.model['timestep']
288
+ feature, feature_per_second, song_length_second = audio_file_to_features(mp3_file_path, config)
289
+ feature = feature.T
290
+ feature = (feature - mean) / std
291
+ time_unit = feature_per_second
292
+
293
+ num_pad = n_timestep - (feature.shape[0] % n_timestep)
294
+ feature = np.pad(feature, ((0, num_pad), (0, 0)), mode="constant", constant_values=0)
295
+ num_instance = feature.shape[0] // n_timestep
296
+
297
+ start_time = 0.0
298
+ lines = []
299
+ with torch.no_grad():
300
+ model.eval()
301
+ feature = torch.tensor(feature, dtype=torch.float32).unsqueeze(0).to(device)
302
+ for t in range(num_instance):
303
+ if (model_type == 'cnn') or (model_type == 'crnn') or (model_type == 'btc'):
304
+ logits = pre_model(feature[:, n_timestep * t:n_timestep * (t + 1), :], torch.randint(config.model['num_chords'], (n_timestep,)).to(device))
305
+ prediction, _ = model(logits, torch.randint(config.model['num_chords'], (n_timestep,)).to(device))
306
+ else:
307
+ raise NotImplementedError
308
+ for i in range(n_timestep):
309
+ if t == 0 and i == 0:
310
+ prev_chord = prediction[i].item()
311
+ continue
312
+ if prediction[i].item() != prev_chord:
313
+ lines.append(
314
+ '%.6f %.6f %s\n' % (
315
+ start_time, time_unit * (n_timestep * t + i), idx2chord[prev_chord]))
316
+ start_time = time_unit * (n_timestep * t + i)
317
+ prev_chord = prediction[i].item()
318
+ if t == num_instance - 1 and i + num_pad == n_timestep:
319
+ if start_time != time_unit * (n_timestep * t + i):
320
+ lines.append(
321
+ '%.6f %.6f %s\n' % (
322
+ start_time, time_unit * (n_timestep * t + i), idx2chord[prev_chord]))
323
+ break
324
+ pid = os.getpid()
325
+ tmp_path = 'tmp_' + str(pid) + '.lab'
326
+ with open(tmp_path, 'w') as f:
327
+ for line in lines:
328
+ f.write(line)
329
+
330
+ root_majmin = ['root', 'majmin']
331
+ for m in root_majmin:
332
+ metrics_.score_list_dict[m].append(metrics_.score(metric=m, gt_path=lab_file_path, est_path=tmp_path))
333
+ song_length_list.append(song_length_second)
334
+ if verbose:
335
+ for m in root_majmin:
336
+ print('song name %s, %s score : %.4f' % (song_name, m, metrics_.score_list_dict[m][-1]))
337
+ except:
338
+ print('song name %s\' lab file error' % song_name)
339
+
340
+ tmp = song_length_list / np.sum(song_length_list)
341
+ for m in root_majmin:
342
+ metrics_.average_score[m] = np.sum(np.multiply(metrics_.score_list_dict[m], tmp))
343
+
344
+ return metrics_.score_list_dict, song_length_list, metrics_.average_score
345
+
346
+
347
+ def large_voca_score_calculation(valid_dataset, config, mean, std, device, model, model_type, verbose=False):
348
+ idx2voca = idx2voca_chord()
349
+ valid_song_names = valid_dataset.song_names
350
+ paths = valid_dataset.preprocessor.get_all_files()
351
+
352
+ metrics_ = metrics()
353
+ song_length_list = list()
354
+ for path in paths:
355
+ song_name, lab_file_path, mp3_file_path, _ = path
356
+ if not song_name in valid_song_names:
357
+ continue
358
+ try:
359
+ n_timestep = config.model['timestep']
360
+ feature, feature_per_second, song_length_second = audio_file_to_features(mp3_file_path, config)
361
+ feature = feature.T
362
+ feature = (feature - mean) / std
363
+ time_unit = feature_per_second
364
+
365
+ num_pad = n_timestep - (feature.shape[0] % n_timestep)
366
+ feature = np.pad(feature, ((0, num_pad), (0, 0)), mode="constant", constant_values=0)
367
+ num_instance = feature.shape[0] // n_timestep
368
+
369
+ start_time = 0.0
370
+ lines = []
371
+ with torch.no_grad():
372
+ model.eval()
373
+ feature = torch.tensor(feature, dtype=torch.float32).unsqueeze(0).to(device)
374
+ for t in range(num_instance):
375
+ if model_type == 'btc':
376
+ encoder_output, _ = model.self_attn_layers(feature[:, n_timestep * t:n_timestep * (t + 1), :])
377
+ prediction, _ = model.output_layer(encoder_output)
378
+ prediction = prediction.squeeze()
379
+ elif model_type == 'cnn' or model_type =='crnn':
380
+ prediction, _, _, _ = model(feature[:, n_timestep * t:n_timestep * (t + 1), :], torch.randint(config.model['num_chords'], (n_timestep,)).to(device))
381
+ for i in range(n_timestep):
382
+ if t == 0 and i == 0:
383
+ prev_chord = prediction[i].item()
384
+ continue
385
+ if prediction[i].item() != prev_chord:
386
+ lines.append(
387
+ '%.6f %.6f %s\n' % (
388
+ start_time, time_unit * (n_timestep * t + i), idx2voca[prev_chord]))
389
+ start_time = time_unit * (n_timestep * t + i)
390
+ prev_chord = prediction[i].item()
391
+ if t == num_instance - 1 and i + num_pad == n_timestep:
392
+ if start_time != time_unit * (n_timestep * t + i):
393
+ lines.append(
394
+ '%.6f %.6f %s\n' % (
395
+ start_time, time_unit * (n_timestep * t + i), idx2voca[prev_chord]))
396
+ break
397
+ pid = os.getpid()
398
+ tmp_path = 'tmp_' + str(pid) + '.lab'
399
+ with open(tmp_path, 'w') as f:
400
+ for line in lines:
401
+ f.write(line)
402
+
403
+ for m in metrics_.score_metrics:
404
+ metrics_.score_list_dict[m].append(metrics_.score(metric=m, gt_path=lab_file_path, est_path=tmp_path))
405
+ song_length_list.append(song_length_second)
406
+ if verbose:
407
+ for m in metrics_.score_metrics:
408
+ print('song name %s, %s score : %.4f' % (song_name, m, metrics_.score_list_dict[m][-1]))
409
+ except:
410
+ print('song name %s\' lab file error' % song_name)
411
+
412
+ tmp = song_length_list / np.sum(song_length_list)
413
+ for m in metrics_.score_metrics:
414
+ metrics_.average_score[m] = np.sum(np.multiply(metrics_.score_list_dict[m], tmp))
415
+
416
+ return metrics_.score_list_dict, song_length_list, metrics_.average_score
417
+
418
+ def large_voca_score_calculation_crf(valid_dataset, config, mean, std, device, pre_model, model, model_type, verbose=False):
419
+ idx2voca = idx2voca_chord()
420
+ valid_song_names = valid_dataset.song_names
421
+ paths = valid_dataset.preprocessor.get_all_files()
422
+
423
+ metrics_ = metrics()
424
+ song_length_list = list()
425
+ for path in paths:
426
+ song_name, lab_file_path, mp3_file_path, _ = path
427
+ if song_name not in valid_song_names:
428
+ continue
429
+ try:
430
+ n_timestep = config.model['timestep']
431
+ feature, feature_per_second, song_length_second = audio_file_to_features(mp3_file_path, config)
432
+ feature = feature.T
433
+ feature = (feature - mean) / std
434
+ time_unit = feature_per_second
435
+
436
+ num_pad = n_timestep - (feature.shape[0] % n_timestep)
437
+ feature = np.pad(feature, ((0, num_pad), (0, 0)), mode="constant", constant_values=0)
438
+ num_instance = feature.shape[0] // n_timestep
439
+
440
+ start_time = 0.0
441
+ lines = []
442
+ with torch.no_grad():
443
+ model.eval()
444
+ feature = torch.tensor(feature, dtype=torch.float32).unsqueeze(0).to(device)
445
+ for t in range(num_instance):
446
+ if (model_type == 'cnn') or (model_type == 'crnn') or (model_type == 'btc'):
447
+ logits = pre_model(feature[:, n_timestep * t:n_timestep * (t + 1), :], torch.randint(config.model['num_chords'], (n_timestep,)).to(device))
448
+ prediction, _ = model(logits, torch.randint(config.model['num_chords'], (n_timestep,)).to(device))
449
+ else:
450
+ raise NotImplementedError
451
+ for i in range(n_timestep):
452
+ if t == 0 and i == 0:
453
+ prev_chord = prediction[i].item()
454
+ continue
455
+ if prediction[i].item() != prev_chord:
456
+ lines.append(
457
+ '%.6f %.6f %s\n' % (
458
+ start_time, time_unit * (n_timestep * t + i), idx2voca[prev_chord]))
459
+ start_time = time_unit * (n_timestep * t + i)
460
+ prev_chord = prediction[i].item()
461
+ if t == num_instance - 1 and i + num_pad == n_timestep:
462
+ if start_time != time_unit * (n_timestep * t + i):
463
+ lines.append(
464
+ '%.6f %.6f %s\n' % (
465
+ start_time, time_unit * (n_timestep * t + i), idx2voca[prev_chord]))
466
+ break
467
+ pid = os.getpid()
468
+ tmp_path = 'tmp_' + str(pid) + '.lab'
469
+ with open(tmp_path, 'w') as f:
470
+ for line in lines:
471
+ f.write(line)
472
+
473
+ for m in metrics_.score_metrics:
474
+ metrics_.score_list_dict[m].append(metrics_.score(metric=m, gt_path=lab_file_path, est_path=tmp_path))
475
+ song_length_list.append(song_length_second)
476
+ if verbose:
477
+ for m in metrics_.score_metrics:
478
+ print('song name %s, %s score : %.4f' % (song_name, m, metrics_.score_list_dict[m][-1]))
479
+ except:
480
+ print('song name %s\' lab file error' % song_name)
481
+
482
+ tmp = song_length_list / np.sum(song_length_list)
483
+ for m in metrics_.score_metrics:
484
+ metrics_.average_score[m] = np.sum(np.multiply(metrics_.score_list_dict[m], tmp))
485
+
486
+ return metrics_.score_list_dict, song_length_list, metrics_.average_score
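A minimal usage sketch for the scoring helpers above; valid_dataset, mean, std and the trained model are placeholders produced by the project's own training scripts, and HParams.load is assumed to read the run configuration as in the original BTC code.

import torch
from utils.hparams import HParams
from utils.mir_eval_modules import large_voca_score_calculation

# Sketch only: config path, dataset, normalization stats and model are assumptions.
config = HParams.load('run_config.yaml')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
score_list_dict, song_lengths, average_score = large_voca_score_calculation(
    valid_dataset, config, mean, std, device, model, model_type='btc', verbose=True)
print(average_score)  # per-metric scores, weighted by song length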
utils/preprocess.py ADDED
@@ -0,0 +1,466 @@
1
+ import os
2
+ import librosa
3
+ from utils.chords import Chords
4
+ import re
5
+ from enum import Enum
6
+ import pyrubberband as pyrb
7
+ import torch
8
+ import math
9
+
10
+ class FeatureTypes(Enum):
11
+ cqt = 'cqt'
12
+
13
+ class Preprocess():
14
+ def __init__(self, config, feature_to_use, dataset_names, root_dir):
15
+ self.config = config
16
+ self.dataset_names = dataset_names
17
+ self.root_path = root_dir + '/'
18
+
19
+ self.time_interval = config.feature["hop_length"]/config.mp3["song_hz"]
20
+ self.no_of_chord_datapoints_per_sequence = math.ceil(config.mp3['inst_len'] / self.time_interval)
21
+ self.Chord_class = Chords()
22
+
23
+ # isophonic
24
+ self.isophonic_directory = self.root_path + 'isophonic/'
25
+
26
+ # uspop
27
+ self.uspop_directory = self.root_path + 'uspop/'
28
+ self.uspop_audio_path = 'audio/'
29
+ self.uspop_lab_path = 'annotations/uspopLabels/'
30
+ self.uspop_index_path = 'annotations/uspopLabels.txt'
31
+
32
+ # robbie williams
33
+ self.robbie_williams_directory = self.root_path + 'robbiewilliams/'
34
+ self.robbie_williams_audio_path = 'audio/'
35
+ self.robbie_williams_lab_path = 'chords/'
36
+
37
+ self.feature_name = feature_to_use
38
+ self.is_cut_last_chord = False
39
+
40
+ def find_mp3_path(self, dirpath, word):
41
+ for filename in os.listdir(dirpath):
42
+ last_dir = dirpath.split("/")[-2]
43
+ if ".mp3" in filename:
44
+ tmp = filename.replace(".mp3", "")
45
+ tmp = tmp.replace(last_dir, "")
46
+ filename_lower = tmp.lower()
47
+ filename_lower = " ".join(re.findall("[a-zA-Z]+", filename_lower))
48
+ if word.lower().replace(" ", "") in filename_lower.replace(" ", ""):
49
+ return filename
50
+
51
+ def find_mp3_path_robbiewilliams(self, dirpath, word):
52
+ for filename in os.listdir(dirpath):
53
+ if ".mp3" in filename:
54
+ tmp = filename.replace(".mp3", "")
55
+ filename_lower = tmp.lower()
56
+ filename_lower = filename_lower.replace("robbie williams", "")
57
+ filename_lower = " ".join(re.findall("[a-zA-Z]+", filename_lower))
58
+ filename_lower = self.song_pre(filename_lower)
59
+ if self.song_pre(word.lower()).replace(" ", "") in filename_lower.replace(" ", ""):
60
+ return filename
61
+
62
+ def get_all_files(self):
63
+ res_list = []
64
+
65
+ # isophonic
66
+ if "isophonic" in self.dataset_names:
67
+ for dirpath, dirnames, filenames in os.walk(self.isophonic_directory):
68
+ if not dirnames:
69
+ for filename in filenames:
70
+ if ".lab" in filename:
71
+ tmp = filename.replace(".lab", "")
72
+ song_name = " ".join(re.findall("[a-zA-Z]+", tmp)).replace("CD", "")
73
+ mp3_path = self.find_mp3_path(dirpath, song_name)
74
+ res_list.append([song_name, os.path.join(dirpath, filename), os.path.join(dirpath, mp3_path),
75
+ os.path.join(self.root_path, "result", "isophonic")])
76
+
77
+ # uspop
78
+ if "uspop" in self.dataset_names:
79
+ with open(os.path.join(self.uspop_directory, self.uspop_index_path)) as f:
80
+ uspop_lab_list = f.readlines()
81
+ uspop_lab_list = [x.strip() for x in uspop_lab_list]
82
+
83
+ for lab_path in uspop_lab_list:
84
+ spl = lab_path.split('/')
85
+ lab_artist = self.uspop_pre(spl[2])
86
+ lab_title = self.uspop_pre(spl[4][3:-4])
87
+ lab_path = lab_path.replace('./uspopLabels/', '')
88
+ lab_path = os.path.join(self.uspop_directory, self.uspop_lab_path, lab_path)
89
+
90
+ for filename in os.listdir(os.path.join(self.uspop_directory, self.uspop_audio_path)):
91
+ if not '.csv' in filename:
92
+ spl = filename.split('-')
93
+ mp3_artist = self.uspop_pre(spl[0])
94
+ mp3_title = self.uspop_pre(spl[1][:-4])
95
+
96
+ if lab_artist == mp3_artist and lab_title == mp3_title:
97
+ res_list.append([mp3_artist + mp3_title, lab_path,
98
+ os.path.join(self.uspop_directory, self.uspop_audio_path, filename),
99
+ os.path.join(self.root_path, "result", "uspop")])
100
+ break
101
+
102
+ # robbie williams
103
+ if "robbiewilliams" in self.dataset_names:
104
+ for dirpath, dirnames, filenames in os.walk(self.robbie_williams_directory):
105
+ if not dirnames:
106
+ for filename in filenames:
107
+ if ".txt" in filename and (not 'README' in filename):
108
+ tmp = filename.replace(".txt", "")
109
+ song_name = " ".join(re.findall("[a-zA-Z]+", tmp)).replace("GTChords", "")
110
+ mp3_dir = dirpath.replace("chords", "audio")
111
+ mp3_path = self.find_mp3_path_robbiewilliams(mp3_dir, song_name)
112
+ res_list.append([song_name, os.path.join(dirpath, filename), os.path.join(mp3_dir, mp3_path),
113
+ os.path.join(self.root_path, "result", "robbiewilliams")])
114
+ return res_list
115
+
116
+ def uspop_pre(self, text):
117
+ text = text.lower()
118
+ text = text.replace('_', '')
119
+ text = text.replace(' ', '')
120
+ text = " ".join(re.findall("[a-zA-Z]+", text))
121
+ return text
122
+
123
+ def song_pre(self, text):
124
+ to_remove = ["'", '`', '(', ')', ' ', '&', 'and', 'And']
125
+
126
+ for remove in to_remove:
127
+ text = text.replace(remove, '')
128
+
129
+ return text
130
+
131
+ def config_to_folder(self):
132
+ mp3_config = self.config.mp3
133
+ feature_config = self.config.feature
134
+ mp3_string = "%d_%.1f_%.1f" % \
135
+ (mp3_config['song_hz'], mp3_config['inst_len'],
136
+ mp3_config['skip_interval'])
137
+ feature_string = "%s_%d_%d_%d" % \
138
+ (self.feature_name.value, feature_config['n_bins'], feature_config['bins_per_octave'], feature_config['hop_length'])
139
+
140
+ return mp3_config, feature_config, mp3_string, feature_string
141
+
142
+ def generate_labels_features_new(self, all_list):
143
+ pid = os.getpid()
144
+ mp3_config, feature_config, mp3_str, feature_str = self.config_to_folder()
145
+
146
+ i = 0 # number of songs
147
+ j = 0 # number of impossible songs
148
+ k = 0 # number of tried songs
149
+ total = 0 # number of generated instances
150
+
151
+ stretch_factors = [1.0]
152
+ shift_factors = [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6]
153
+
154
+ loop_broken = False
155
+ for song_name, lab_path, mp3_path, save_path in all_list:
156
+
157
+ # different song initialization
158
+ if loop_broken:
159
+ loop_broken = False
160
+
161
+ i += 1
162
+ print(pid, "generating features from ...", os.path.join(mp3_path))
163
+ if i % 10 == 0:
164
+ print(i, ' th song')
165
+
166
+ original_wav, sr = librosa.load(os.path.join(mp3_path), sr=mp3_config['song_hz'])
167
+
168
+ # make result path if not exists
169
+ # save_path, mp3_string, feature_string, song_name, aug.pt
170
+ result_path = os.path.join(save_path, mp3_str, feature_str, song_name.strip())
171
+ if not os.path.exists(result_path):
172
+ os.makedirs(result_path)
173
+
174
+ # calculate result
175
+ for stretch_factor in stretch_factors:
176
+ if loop_broken:
177
+ loop_broken = False
178
+ break
179
+
180
+ for shift_factor in shift_factors:
181
+ # for filename
182
+ idx = 0
183
+
184
+ chord_info = self.Chord_class.get_converted_chord(os.path.join(lab_path))
185
+
186
+ k += 1
187
+ # stretch original sound and chord info
188
+ x = pyrb.time_stretch(original_wav, sr, stretch_factor)
189
+ x = pyrb.pitch_shift(x, sr, shift_factor)
190
+ audio_length = x.shape[0]
191
+ chord_info['start'] = chord_info['start'] * 1/stretch_factor
192
+ chord_info['end'] = chord_info['end'] * 1/stretch_factor
193
+
194
+ last_sec = chord_info.iloc[-1]['end']
195
+ last_sec_hz = int(last_sec * mp3_config['song_hz'])
196
+
197
+ if audio_length + mp3_config['skip_interval'] < last_sec_hz:
198
+ print('loaded song is too short :', song_name)
199
+ loop_broken = True
200
+ j += 1
201
+ break
202
+ elif audio_length > last_sec_hz:
203
+ x = x[:last_sec_hz]
204
+
205
+ origin_length = last_sec_hz
206
+ origin_length_in_sec = origin_length / mp3_config['song_hz']
207
+
208
+ current_start_second = 0
209
+
210
+ # get chord list between current_start_second and current+song_length
211
+ while current_start_second + mp3_config['inst_len'] < origin_length_in_sec:
212
+ inst_start_sec = current_start_second
213
+ curSec = current_start_second
214
+
215
+ chord_list = []
216
+ # extract chord per 1/self.time_interval
217
+ while curSec < inst_start_sec + mp3_config['inst_len']:
218
+ try:
219
+ available_chords = chord_info.loc[(chord_info['start'] <= curSec) & (
220
+ chord_info['end'] > curSec + self.time_interval)].copy()
221
+ if len(available_chords) == 0:
222
+ available_chords = chord_info.loc[((chord_info['start'] >= curSec) & (
223
+ chord_info['start'] <= curSec + self.time_interval)) | (
224
+ (chord_info['end'] >= curSec) & (
225
+ chord_info['end'] <= curSec + self.time_interval))].copy()
226
+ if len(available_chords) == 1:
227
+ chord = available_chords['chord_id'].iloc[0]
228
+ elif len(available_chords) > 1:
229
+ max_starts = available_chords.apply(lambda row: max(row['start'], curSec),
230
+ axis=1)
231
+ available_chords['max_start'] = max_starts
232
+ min_ends = available_chords.apply(
233
+ lambda row: min(row.end, curSec + self.time_interval), axis=1)
234
+ available_chords['min_end'] = min_ends
235
+ chords_lengths = available_chords['min_end'] - available_chords['max_start']
236
+ available_chords['chord_length'] = chords_lengths
237
+ chord = available_chords.loc[available_chords['chord_length'].idxmax()]['chord_id']
238
+ else:
239
+ chord = 24
240
+ except Exception as e:
241
+ chord = 24
242
+ print(e)
243
+ print(pid, "no chord")
244
+ raise RuntimeError()
245
+ finally:
246
+ # convert chord by shift factor
247
+ if chord != 24:
248
+ chord += shift_factor * 2
249
+ chord = chord % 24
250
+
251
+ chord_list.append(chord)
252
+ curSec += self.time_interval
253
+
254
+ if len(chord_list) == self.no_of_chord_datapoints_per_sequence:
255
+ try:
256
+ sequence_start_time = current_start_second
257
+ sequence_end_time = current_start_second + mp3_config['inst_len']
258
+
259
+ start_index = int(sequence_start_time * mp3_config['song_hz'])
260
+ end_index = int(sequence_end_time * mp3_config['song_hz'])
261
+
262
+ song_seq = x[start_index:end_index]
263
+
264
+ etc = '%.1f_%.1f' % (
265
+ current_start_second, current_start_second + mp3_config['inst_len'])
266
+ aug = '%.2f_%i' % (stretch_factor, shift_factor)
267
+
268
+ if self.feature_name == FeatureTypes.cqt:
269
+ # print(pid, "make feature")
270
+ feature = librosa.cqt(song_seq, sr=sr, n_bins=feature_config['n_bins'],
271
+ bins_per_octave=feature_config['bins_per_octave'],
272
+ hop_length=feature_config['hop_length'])
273
+ else:
274
+ raise NotImplementedError
275
+
276
+ if feature.shape[1] > self.no_of_chord_datapoints_per_sequence:
277
+ feature = feature[:, :self.no_of_chord_datapoints_per_sequence]
278
+
279
+ if feature.shape[1] != self.no_of_chord_datapoints_per_sequence:
280
+ print('loaded features length is too short :', song_name)
281
+ loop_broken = True
282
+ j += 1
283
+ break
284
+
285
+ result = {
286
+ 'feature': feature,
287
+ 'chord': chord_list,
288
+ 'etc': etc
289
+ }
290
+
291
+ # save_path, mp3_string, feature_string, song_name, aug.pt
292
+ filename = aug + "_" + str(idx) + ".pt"
293
+ torch.save(result, os.path.join(result_path, filename))
294
+ idx += 1
295
+ total += 1
296
+ except Exception as e:
297
+ print(e)
298
+ print(pid, "feature error")
299
+ raise RuntimeError()
300
+ else:
301
+ print("invalid number of chord datapoints in sequence :", len(chord_list))
302
+ current_start_second += mp3_config['skip_interval']
303
+ print(pid, "total instances: %d" % total)
304
+
305
+ def generate_labels_features_voca(self, all_list):
306
+ pid = os.getpid()
307
+ mp3_config, feature_config, mp3_str, feature_str = self.config_to_folder()
308
+
309
+ i = 0 # number of songs
310
+ j = 0 # number of impossible songs
311
+ k = 0 # number of tried songs
312
+ total = 0 # number of generated instances
313
+ stretch_factors = [1.0]
314
+ shift_factors = [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6]
315
+
316
+ loop_broken = False
317
+ for song_name, lab_path, mp3_path, save_path in all_list:
318
+ save_path = save_path + '_voca'
319
+
320
+ # different song initialization
321
+ if loop_broken:
322
+ loop_broken = False
323
+
324
+ i += 1
325
+ print(pid, "generating features from ...", os.path.join(mp3_path))
326
+ if i % 10 == 0:
327
+ print(i, ' th song')
328
+
329
+ original_wav, sr = librosa.load(os.path.join(mp3_path), sr=mp3_config['song_hz'])
330
+
331
+ # save_path, mp3_string, feature_string, song_name, aug.pt
332
+ result_path = os.path.join(save_path, mp3_str, feature_str, song_name.strip())
333
+ if not os.path.exists(result_path):
334
+ os.makedirs(result_path)
335
+
336
+ # calculate result
337
+ for stretch_factor in stretch_factors:
338
+ if loop_broken:
339
+ loop_broken = False
340
+ break
341
+
342
+ for shift_factor in shift_factors:
343
+ # for filename
344
+ idx = 0
345
+
346
+ try:
347
+ chord_info = self.Chord_class.get_converted_chord_voca(os.path.join(lab_path))
348
+ except Exception as e:
349
+ print(e)
350
+ print(pid, " chord lab file error : %s" % song_name)
351
+ loop_broken = True
352
+ j += 1
353
+ break
354
+
355
+ k += 1
356
+ # stretch original sound and chord info
357
+ x = pyrb.time_stretch(original_wav, sr, stretch_factor)
358
+ x = pyrb.pitch_shift(x, sr, shift_factor)
359
+ audio_length = x.shape[0]
360
+ chord_info['start'] = chord_info['start'] * 1/stretch_factor
361
+ chord_info['end'] = chord_info['end'] * 1/stretch_factor
362
+
363
+ last_sec = chord_info.iloc[-1]['end']
364
+ last_sec_hz = int(last_sec * mp3_config['song_hz'])
365
+
366
+ if audio_length + mp3_config['skip_interval'] < last_sec_hz:
367
+ print('loaded song is too short :', song_name)
368
+ loop_broken = True
369
+ j += 1
370
+ break
371
+ elif audio_length > last_sec_hz:
372
+ x = x[:last_sec_hz]
373
+
374
+ origin_length = last_sec_hz
375
+ origin_length_in_sec = origin_length / mp3_config['song_hz']
376
+
377
+ current_start_second = 0
378
+
379
+ # get chord list between current_start_second and current+song_length
380
+ while current_start_second + mp3_config['inst_len'] < origin_length_in_sec:
381
+ inst_start_sec = current_start_second
382
+ curSec = current_start_second
383
+
384
+ chord_list = []
385
+ # extract chord per 1/self.time_interval
386
+ while curSec < inst_start_sec + mp3_config['inst_len']:
387
+ try:
388
+ available_chords = chord_info.loc[(chord_info['start'] <= curSec) & (chord_info['end'] > curSec + self.time_interval)].copy()
389
+ if len(available_chords) == 0:
390
+ available_chords = chord_info.loc[((chord_info['start'] >= curSec) & (chord_info['start'] <= curSec + self.time_interval)) | ((chord_info['end'] >= curSec) & (chord_info['end'] <= curSec + self.time_interval))].copy()
391
+
392
+ if len(available_chords) == 1:
393
+ chord = available_chords['chord_id'].iloc[0]
394
+ elif len(available_chords) > 1:
395
+ max_starts = available_chords.apply(lambda row: max(row['start'], curSec),axis=1)
396
+ available_chords['max_start'] = max_starts
397
+ min_ends = available_chords.apply(lambda row: min(row.end, curSec + self.time_interval), axis=1)
398
+ available_chords['min_end'] = min_ends
399
+ chords_lengths = available_chords['min_end'] - available_chords['max_start']
400
+ available_chords['chord_length'] = chords_lengths
401
+ chord = available_chords.loc[available_chords['chord_length'].idxmax()]['chord_id']
402
+ else:
403
+ chord = 169
404
+ except Exception as e:
405
+ chord = 169
406
+ print(e)
407
+ print(pid, "no chord")
408
+ raise RuntimeError()
409
+ finally:
410
+ # convert chord by shift factor
411
+ if chord != 169 and chord != 168:
412
+ chord += shift_factor * 14
413
+ chord = chord % 168
414
+
415
+ chord_list.append(chord)
416
+ curSec += self.time_interval
417
+
418
+ if len(chord_list) == self.no_of_chord_datapoints_per_sequence:
419
+ try:
420
+ sequence_start_time = current_start_second
421
+ sequence_end_time = current_start_second + mp3_config['inst_len']
422
+
423
+ start_index = int(sequence_start_time * mp3_config['song_hz'])
424
+ end_index = int(sequence_end_time * mp3_config['song_hz'])
425
+
426
+ song_seq = x[start_index:end_index]
427
+
428
+ etc = '%.1f_%.1f' % (
429
+ current_start_second, current_start_second + mp3_config['inst_len'])
430
+ aug = '%.2f_%i' % (stretch_factor, shift_factor)
431
+
432
+ if self.feature_name == FeatureTypes.cqt:
433
+ feature = librosa.cqt(song_seq, sr=sr, n_bins=feature_config['n_bins'],
434
+ bins_per_octave=feature_config['bins_per_octave'],
435
+ hop_length=feature_config['hop_length'])
436
+ else:
437
+ raise NotImplementedError
438
+
439
+ if feature.shape[1] > self.no_of_chord_datapoints_per_sequence:
440
+ feature = feature[:, :self.no_of_chord_datapoints_per_sequence]
441
+
442
+ if feature.shape[1] != self.no_of_chord_datapoints_per_sequence:
443
+ print('loaded features length is too short :', song_name)
444
+ loop_broken = True
445
+ j += 1
446
+ break
447
+
448
+ result = {
449
+ 'feature': feature,
450
+ 'chord': chord_list,
451
+ 'etc': etc
452
+ }
453
+
454
+ # save_path, mp3_string, feature_string, song_name, aug.pt
455
+ filename = aug + "_" + str(idx) + ".pt"
456
+ torch.save(result, os.path.join(result_path, filename))
457
+ idx += 1
458
+ total += 1
459
+ except Exception as e:
460
+ print(e)
461
+ print(pid, "feature error")
462
+ raise RuntimeError()
463
+ else:
464
+ print("invalid number of chord datapoints in sequence :", len(chord_list))
465
+ current_start_second += mp3_config['skip_interval']
466
+ print(pid, "total instances: %d" % total)
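A minimal usage sketch for the Preprocess class above; the config file name, dataset root and dataset names are placeholders, and HParams.load is assumed to behave as in the original BTC code.

from utils.hparams import HParams
from utils.preprocess import Preprocess, FeatureTypes

# Sketch only: paths and dataset names below are assumptions.
config = HParams.load('run_config.yaml')
preprocessor = Preprocess(config, FeatureTypes.cqt, dataset_names=['isophonic'], root_dir='./data')
all_files = preprocessor.get_all_files()               # [song_name, lab_path, mp3_path, save_path] per song
preprocessor.generate_labels_features_voca(all_files)  # saves <stretch>_<shift>_<idx>.pt instances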
utils/pytorch_utils.py ADDED
@@ -0,0 +1,33 @@
1
+
2
+ import torch
3
+ import numpy as np
4
+ import os
5
+ import math
6
+ from utils import logger
7
+
8
+ use_cuda = torch.cuda.is_available()
9
+
10
+
11
+ # optimization
12
+ # reference: http://pytorch.org/docs/master/_modules/torch/optim/lr_scheduler.html#ReduceLROnPlateau
13
+ def adjusting_learning_rate(optimizer, factor=.5, min_lr=0.00001):
14
+ for i, param_group in enumerate(optimizer.param_groups):
15
+ old_lr = float(param_group['lr'])
16
+ new_lr = max(old_lr * factor, min_lr)
17
+ param_group['lr'] = new_lr
18
+ logger.info('adjusting learning rate from %.6f to %.6f' % (old_lr, new_lr))
19
+
20
+
21
+ # model save and loading
22
+ def load_model(asset_path, model, optimizer, restore_epoch=0):
23
+ if os.path.isfile(os.path.join(asset_path, 'model', 'checkpoint_%d.pth.tar' % restore_epoch)):
24
+ checkpoint = torch.load(os.path.join(asset_path, 'model', 'checkpoint_%d.pth.tar' % restore_epoch), map_location=lambda storage, loc: storage)
25
+ model.load_state_dict(checkpoint['model'])
26
+ optimizer.load_state_dict(checkpoint['optimizer'])
27
+ current_step = checkpoint['current_step']
28
+ logger.info("restore model with %d epoch" % restore_epoch)
29
+ else:
30
+ logger.info("no checkpoint with %d epoch" % restore_epoch)
31
+ current_step = 0
32
+
33
+ return model, optimizer, current_step
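A minimal usage sketch for the checkpoint helpers above; the asset path and the small placeholder model are assumptions (the project itself restores a BTC-style chord model).

import torch
import torch.nn as nn
from utils.pytorch_utils import load_model, adjusting_learning_rate

# Sketch only: nn.Linear stands in for the real model, 'assets' is an assumed path.
model = nn.Linear(144, 170)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
model, optimizer, current_step = load_model('assets', model, optimizer, restore_epoch=10)
adjusting_learning_rate(optimizer, factor=0.95, min_lr=5e-6)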
utils/tf_logger.py ADDED
@@ -0,0 +1,70 @@
1
+ import tensorflow as tf
2
+ import numpy as np
3
+ import scipy.misc
4
+
5
+ try:
6
+ from StringIO import StringIO # Python 2.7
7
+ except ImportError:
8
+ from io import BytesIO # Python 3.x
9
+
10
+
11
+ class TF_Logger(object):
12
+ def __init__(self, log_dir):
13
+ """Create a summary writer logging to log_dir."""
14
+ self.writer = tf.summary.FileWriter(log_dir)
15
+
16
+ def scalar_summary(self, tag, value, step):
17
+ """Log a scalar variable."""
18
+ summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)])
19
+ self.writer.add_summary(summary, step)
20
+
21
+ def image_summary(self, tag, images, step):
22
+ """Log a list of images."""
23
+
24
+ img_summaries = []
25
+ for i, img in enumerate(images):
26
+ # Write the image to a string
27
+ try:
28
+ s = StringIO()
29
+ except:
30
+ s = BytesIO()
31
+ scipy.misc.toimage(img).save(s, format="png")
32
+
33
+ # Create an Image object
34
+ img_sum = tf.Summary.Image(encoded_image_string=s.getvalue(),
35
+ height=img.shape[0],
36
+ width=img.shape[1])
37
+ # Create a Summary value
38
+ img_summaries.append(tf.Summary.Value(tag='%s/%d' % (tag, i), image=img_sum))
39
+
40
+ # Create and write Summary
41
+ summary = tf.Summary(value=img_summaries)
42
+ self.writer.add_summary(summary, step)
43
+
44
+ def histo_summary(self, tag, values, step, bins=1000):
45
+ """Log a histogram of the tensor of values."""
46
+
47
+ # Create a histogram using numpy
48
+ counts, bin_edges = np.histogram(values, bins=bins)
49
+
50
+ # Fill the fields of the histogram proto
51
+ hist = tf.HistogramProto()
52
+ hist.min = float(np.min(values))
53
+ hist.max = float(np.max(values))
54
+ hist.num = int(np.prod(values.shape))
55
+ hist.sum = float(np.sum(values))
56
+ hist.sum_squares = float(np.sum(values ** 2))
57
+
58
+ # Drop the start of the first bin
59
+ bin_edges = bin_edges[1:]
60
+
61
+ # Add bin edges and counts
62
+ for edge in bin_edges:
63
+ hist.bucket_limit.append(edge)
64
+ for c in counts:
65
+ hist.bucket.append(c)
66
+
67
+ # Create and write Summary
68
+ summary = tf.Summary(value=[tf.Summary.Value(tag=tag, histo=hist)])
69
+ self.writer.add_summary(summary, step)
70
+ self.writer.flush()
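A minimal usage sketch for TF_Logger above; note that it depends on the TensorFlow 1.x summary API (tf.summary.FileWriter, tf.Summary), so it only runs under a TF1-compatible install, and the log directory is an assumption.

from utils.tf_logger import TF_Logger

# Sketch only: requires TensorFlow 1.x; 'logs/example_run' is an assumed directory.
tf_logger = TF_Logger('logs/example_run')
for step in range(3):
    tf_logger.scalar_summary('train/loss', 1.0 / (step + 1), step)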
utils/transformer_modules.py ADDED
@@ -0,0 +1,274 @@
1
+ from __future__ import absolute_import
2
+ from __future__ import division
3
+ from __future__ import print_function
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ import numpy as np
8
+ import math
9
+
10
+ def _gen_bias_mask(max_length):
11
+ """
12
+ Generates bias values (-Inf) to mask future timesteps during attention
13
+ """
14
+ np_mask = np.triu(np.full([max_length, max_length], -np.inf), 1)
15
+ torch_mask = torch.from_numpy(np_mask).type(torch.FloatTensor)
16
+ return torch_mask.unsqueeze(0).unsqueeze(1)
17
+
18
+ def _gen_timing_signal(length, channels, min_timescale=1.0, max_timescale=1.0e4):
19
+ """
20
+ Generates a [1, length, channels] timing signal consisting of sinusoids
21
+ Adapted from:
22
+ https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/layers/common_attention.py
23
+ """
24
+ position = np.arange(length)
25
+ num_timescales = channels // 2
26
+ log_timescale_increment = (
27
+ math.log(float(max_timescale) / float(min_timescale)) /
28
+ (float(num_timescales) - 1))
29
+ inv_timescales = min_timescale * np.exp(
30
+ np.arange(num_timescales).astype(np.float64) * -log_timescale_increment)
31
+ scaled_time = np.expand_dims(position, 1) * np.expand_dims(inv_timescales, 0)
32
+
33
+ signal = np.concatenate([np.sin(scaled_time), np.cos(scaled_time)], axis=1)
34
+ signal = np.pad(signal, [[0, 0], [0, channels % 2]],
35
+ 'constant', constant_values=[0.0, 0.0])
36
+ signal = signal.reshape([1, length, channels])
37
+
38
+ return torch.from_numpy(signal).type(torch.FloatTensor)
39
+
40
+ class LayerNorm(nn.Module):
41
+ # Borrowed from jekbradbury
42
+ # https://github.com/pytorch/pytorch/issues/1959
43
+ def __init__(self, features, eps=1e-6):
44
+ super(LayerNorm, self).__init__()
45
+ self.gamma = nn.Parameter(torch.ones(features))
46
+ self.beta = nn.Parameter(torch.zeros(features))
47
+ self.eps = eps
48
+
49
+ def forward(self, x):
50
+ mean = x.mean(-1, keepdim=True)
51
+ std = x.std(-1, keepdim=True)
52
+ return self.gamma * (x - mean) / (std + self.eps) + self.beta
53
+
54
+ class OutputLayer(nn.Module):
55
+ """
56
+ Abstract base class for output layer.
57
+ Handles projection to output labels
58
+ """
59
+ def __init__(self, hidden_size, output_size, probs_out=False):
60
+ super(OutputLayer, self).__init__()
61
+ self.output_size = output_size
62
+ self.output_projection = nn.Linear(hidden_size, output_size)
63
+ self.probs_out = probs_out
64
+ self.lstm = nn.LSTM(input_size=hidden_size, hidden_size=int(hidden_size/2), batch_first=True, bidirectional=True)
65
+ self.hidden_size = hidden_size
66
+
67
+ def loss(self, hidden, labels):
68
+ raise NotImplementedError('Must implement {}.loss'.format(self.__class__.__name__))
69
+
70
+ class SoftmaxOutputLayer(OutputLayer):
71
+ """
72
+ Implements a softmax based output layer
73
+ """
74
+ def forward(self, hidden):
75
+ logits = self.output_projection(hidden)
76
+ probs = F.softmax(logits, -1)
77
+ # _, predictions = torch.max(probs, dim=-1)
78
+ topk, indices = torch.topk(probs, 2)
79
+ predictions = indices[:,:,0]
80
+ second = indices[:,:,1]
81
+ if self.probs_out is True:
82
+ return logits
83
+ # return probs
84
+ return predictions, second
85
+
86
+ def loss(self, hidden, labels):
87
+ logits = self.output_projection(hidden)
88
+ log_probs = F.log_softmax(logits, -1)
89
+ return F.nll_loss(log_probs.view(-1, self.output_size), labels.view(-1))
90
+
91
+ class MultiHeadAttention(nn.Module):
92
+ """
93
+ Multi-head attention as per https://arxiv.org/pdf/1706.03762.pdf
94
+ Refer Figure 2
95
+ """
96
+
97
+ def __init__(self, input_depth, total_key_depth, total_value_depth, output_depth,
98
+ num_heads, bias_mask=None, dropout=0.0, attention_map=False):
99
+ """
100
+ Parameters:
101
+ input_depth: Size of last dimension of input
102
+ total_key_depth: Size of last dimension of keys. Must be divisible by num_head
103
+ total_value_depth: Size of last dimension of values. Must be divisible by num_head
104
+ output_depth: Size last dimension of the final output
105
+ num_heads: Number of attention heads
106
+ bias_mask: Masking tensor to prevent connections to future elements
107
+ dropout: Dropout probability (Should be non-zero only during training)
108
+ """
109
+ super(MultiHeadAttention, self).__init__()
110
+ # Checks borrowed from
111
+ # https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/layers/common_attention.py
112
+ if total_key_depth % num_heads != 0:
113
+ raise ValueError("Key depth (%d) must be divisible by the number of "
114
+ "attention heads (%d)." % (total_key_depth, num_heads))
115
+ if total_value_depth % num_heads != 0:
116
+ raise ValueError("Value depth (%d) must be divisible by the number of "
117
+ "attention heads (%d)." % (total_value_depth, num_heads))
118
+
119
+ self.attention_map = attention_map
120
+
121
+ self.num_heads = num_heads
122
+ self.query_scale = (total_key_depth // num_heads) ** -0.5
123
+ self.bias_mask = bias_mask
124
+
125
+ # Key and query depth will be same
126
+ self.query_linear = nn.Linear(input_depth, total_key_depth, bias=False)
127
+ self.key_linear = nn.Linear(input_depth, total_key_depth, bias=False)
128
+ self.value_linear = nn.Linear(input_depth, total_value_depth, bias=False)
129
+ self.output_linear = nn.Linear(total_value_depth, output_depth, bias=False)
130
+
131
+ self.dropout = nn.Dropout(dropout)
132
+
133
+ def _split_heads(self, x):
134
+ """
135
+ Split x such to add an extra num_heads dimension
136
+ Input:
137
+ x: a Tensor with shape [batch_size, seq_length, depth]
138
+ Returns:
139
+ A Tensor with shape [batch_size, num_heads, seq_length, depth/num_heads]
140
+ """
141
+ if len(x.shape) != 3:
142
+ raise ValueError("x must have rank 3")
143
+ shape = x.shape
144
+ return x.view(shape[0], shape[1], self.num_heads, shape[2] // self.num_heads).permute(0, 2, 1, 3)
145
+
146
+ def _merge_heads(self, x):
147
+ """
148
+ Merge the extra num_heads into the last dimension
149
+ Input:
150
+ x: a Tensor with shape [batch_size, num_heads, seq_length, depth/num_heads]
151
+ Returns:
152
+ A Tensor with shape [batch_size, seq_length, depth]
153
+ """
154
+ if len(x.shape) != 4:
155
+ raise ValueError("x must have rank 4")
156
+ shape = x.shape
157
+ return x.permute(0, 2, 1, 3).contiguous().view(shape[0], shape[2], shape[3] * self.num_heads)
158
+
159
+ def forward(self, queries, keys, values):
160
+
161
+ # Do a linear for each component
162
+ queries = self.query_linear(queries)
163
+ keys = self.key_linear(keys)
164
+ values = self.value_linear(values)
165
+
166
+ # Split into multiple heads
167
+ queries = self._split_heads(queries)
168
+ keys = self._split_heads(keys)
169
+ values = self._split_heads(values)
170
+
171
+ # Scale queries
172
+ queries *= self.query_scale
173
+
174
+ # Combine queries and keys
175
+ logits = torch.matmul(queries, keys.permute(0, 1, 3, 2))
176
+
177
+ # Add bias to mask future values
178
+ if self.bias_mask is not None:
179
+ logits += self.bias_mask[:, :, :logits.shape[-2], :logits.shape[-1]].type_as(logits.data)
180
+
181
+ # Convert to probabilites
182
+ weights = nn.functional.softmax(logits, dim=-1)
183
+
184
+ # Dropout
185
+ weights = self.dropout(weights)
186
+
187
+ # Combine with values to get context
188
+ contexts = torch.matmul(weights, values)
189
+
190
+ # Merge heads
191
+ contexts = self._merge_heads(contexts)
192
+ # contexts = torch.tanh(contexts)
193
+
194
+ # Linear to get output
195
+ outputs = self.output_linear(contexts)
196
+
197
+ if self.attention_map is True:
198
+ return outputs, weights
199
+
200
+ return outputs
201
+
202
+
203
+ class Conv(nn.Module):
204
+ """
205
+ Convenience class that does padding and convolution for inputs in the format
206
+ [batch_size, sequence length, hidden size]
207
+ """
208
+
209
+ def __init__(self, input_size, output_size, kernel_size, pad_type):
210
+ """
211
+ Parameters:
212
+ input_size: Input feature size
213
+ output_size: Output feature size
214
+ kernel_size: Kernel width
215
+ pad_type: left -> pad on the left side (to mask future data_loader),
216
+ both -> pad on both sides
217
+ """
218
+ super(Conv, self).__init__()
219
+ padding = (kernel_size - 1, 0) if pad_type == 'left' else (kernel_size // 2, (kernel_size - 1) // 2)
220
+ self.pad = nn.ConstantPad1d(padding, 0)
221
+ self.conv = nn.Conv1d(input_size, output_size, kernel_size=kernel_size, padding=0)
222
+
223
+ def forward(self, inputs):
224
+ inputs = self.pad(inputs.permute(0, 2, 1))
225
+ outputs = self.conv(inputs).permute(0, 2, 1)
226
+
227
+ return outputs
228
+
229
+
230
+ class PositionwiseFeedForward(nn.Module):
231
+ """
232
+ Does a Linear + RELU + Linear on each of the timesteps
233
+ """
234
+
235
+ def __init__(self, input_depth, filter_size, output_depth, layer_config='ll', padding='left', dropout=0.0):
236
+ """
237
+ Parameters:
238
+ input_depth: Size of last dimension of input
239
+ filter_size: Hidden size of the middle layer
240
+ output_depth: Size last dimension of the final output
241
+ layer_config: ll -> linear + ReLU + linear
242
+ cc -> conv + ReLU + conv etc.
243
+ padding: left -> pad on the left side (to mask future data_loader),
244
+ both -> pad on both sides
245
+ dropout: Dropout probability (Should be non-zero only during training)
246
+ """
247
+ super(PositionwiseFeedForward, self).__init__()
248
+
249
+ layers = []
250
+ sizes = ([(input_depth, filter_size)] +
251
+ [(filter_size, filter_size)] * (len(layer_config) - 2) +
252
+ [(filter_size, output_depth)])
253
+
254
+ for lc, s in zip(list(layer_config), sizes):
255
+ if lc == 'l':
256
+ layers.append(nn.Linear(*s))
257
+ elif lc == 'c':
258
+ layers.append(Conv(*s, kernel_size=3, pad_type=padding))
259
+ else:
260
+ raise ValueError("Unknown layer type {}".format(lc))
261
+
262
+ self.layers = nn.ModuleList(layers)
263
+ self.relu = nn.ReLU()
264
+ self.dropout = nn.Dropout(dropout)
265
+
266
+ def forward(self, inputs):
267
+ x = inputs
268
+ for i, layer in enumerate(self.layers):
269
+ x = layer(x)
270
+ if i < len(self.layers):
271
+ x = self.relu(x)
272
+ x = self.dropout(x)
273
+
274
+ return x
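A minimal shape check for the attention building blocks above; the batch size, sequence length and hidden size are arbitrary example values.

import torch
from utils.transformer_modules import MultiHeadAttention, _gen_bias_mask

# Sketch only: sizes are arbitrary; the bias mask makes the attention causal.
batch, seq_len, hidden = 2, 108, 128
attention = MultiHeadAttention(input_depth=hidden, total_key_depth=hidden,
                               total_value_depth=hidden, output_depth=hidden,
                               num_heads=4, bias_mask=_gen_bias_mask(seq_len))
x = torch.randn(batch, seq_len, hidden)
out = attention(x, x, x)
print(out.shape)   # torch.Size([2, 108, 128])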