Upload 22 files
- utils/__init__.py +0 -0
- utils/__pycache__/__init__.cpython-310.pyc +0 -0
- utils/__pycache__/btc_model.cpython-310.pyc +0 -0
- utils/__pycache__/constants.cpython-310.pyc +0 -0
- utils/__pycache__/custom_early_stopping.cpython-310.pyc +0 -0
- utils/__pycache__/hparams.cpython-310.pyc +0 -0
- utils/__pycache__/logger.cpython-310.pyc +0 -0
- utils/__pycache__/mert.cpython-310.pyc +0 -0
- utils/__pycache__/mir_eval_modules.cpython-310.pyc +0 -0
- utils/__pycache__/transformer_modules.cpython-310.pyc +0 -0
- utils/btc_model.py +198 -0
- utils/chords.py +542 -0
- utils/constants.py +60 -0
- utils/custom_early_stopping.py +93 -0
- utils/hparams.py +37 -0
- utils/logger.py +72 -0
- utils/mert.py +32 -0
- utils/mir_eval_modules.py +486 -0
- utils/preprocess.py +466 -0
- utils/pytorch_utils.py +33 -0
- utils/tf_logger.py +70 -0
- utils/transformer_modules.py +274 -0
utils/__init__.py
ADDED
File without changes
utils/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (137 Bytes)
utils/__pycache__/btc_model.cpython-310.pyc
ADDED
Binary file (5.19 kB)
utils/__pycache__/constants.cpython-310.pyc
ADDED
Binary file (574 Bytes)
utils/__pycache__/custom_early_stopping.cpython-310.pyc
ADDED
Binary file (1.77 kB)
utils/__pycache__/hparams.cpython-310.pyc
ADDED
Binary file (1.69 kB)
utils/__pycache__/logger.cpython-310.pyc
ADDED
Binary file (1.87 kB)
utils/__pycache__/mert.cpython-310.pyc
ADDED
Binary file (1.56 kB)
utils/__pycache__/mir_eval_modules.cpython-310.pyc
ADDED
Binary file (12.8 kB)
utils/__pycache__/transformer_modules.cpython-310.pyc
ADDED
Binary file (9.98 kB)
utils/btc_model.py
ADDED
@@ -0,0 +1,198 @@
from utils.transformer_modules import *
from utils.transformer_modules import _gen_timing_signal, _gen_bias_mask
from utils.hparams import HParams

use_cuda = torch.cuda.is_available()

class self_attention_block(nn.Module):
    def __init__(self, hidden_size, total_key_depth, total_value_depth, filter_size, num_heads,
                 bias_mask=None, layer_dropout=0.0, attention_dropout=0.0, relu_dropout=0.0, attention_map=False):
        super(self_attention_block, self).__init__()

        self.attention_map = attention_map
        self.multi_head_attention = MultiHeadAttention(hidden_size, total_key_depth, total_value_depth, hidden_size, num_heads, bias_mask, attention_dropout, attention_map)
        self.positionwise_convolution = PositionwiseFeedForward(hidden_size, filter_size, hidden_size, layer_config='cc', padding='both', dropout=relu_dropout)
        self.dropout = nn.Dropout(layer_dropout)
        self.layer_norm_mha = LayerNorm(hidden_size)
        self.layer_norm_ffn = LayerNorm(hidden_size)

    def forward(self, inputs):
        x = inputs

        # Layer Normalization
        x_norm = self.layer_norm_mha(x)

        # Multi-head attention
        if self.attention_map is True:
            y, weights = self.multi_head_attention(x_norm, x_norm, x_norm)
        else:
            y = self.multi_head_attention(x_norm, x_norm, x_norm)

        # Dropout and residual
        x = self.dropout(x + y)

        # Layer Normalization
        x_norm = self.layer_norm_ffn(x)

        # Positionwise Feedforward
        y = self.positionwise_convolution(x_norm)

        # Dropout and residual
        y = self.dropout(x + y)

        if self.attention_map is True:
            return y, weights
        return y

class bi_directional_self_attention(nn.Module):
    def __init__(self, hidden_size, total_key_depth, total_value_depth, filter_size, num_heads, max_length,
                 layer_dropout=0.0, attention_dropout=0.0, relu_dropout=0.0):

        super(bi_directional_self_attention, self).__init__()

        self.weights_list = list()

        params = (hidden_size,
                  total_key_depth or hidden_size,
                  total_value_depth or hidden_size,
                  filter_size,
                  num_heads,
                  _gen_bias_mask(max_length),
                  layer_dropout,
                  attention_dropout,
                  relu_dropout,
                  True)

        self.attn_block = self_attention_block(*params)

        params = (hidden_size,
                  total_key_depth or hidden_size,
                  total_value_depth or hidden_size,
                  filter_size,
                  num_heads,
                  torch.transpose(_gen_bias_mask(max_length), dim0=2, dim1=3),
                  layer_dropout,
                  attention_dropout,
                  relu_dropout,
                  True)

        self.backward_attn_block = self_attention_block(*params)

        self.linear = nn.Linear(hidden_size * 2, hidden_size)

    def forward(self, inputs):
        x, list = inputs

        # Forward Self-attention Block
        encoder_outputs, weights = self.attn_block(x)
        # Backward Self-attention Block
        reverse_outputs, reverse_weights = self.backward_attn_block(x)
        # Concatenation and Fully-connected Layer
        outputs = torch.cat((encoder_outputs, reverse_outputs), dim=2)
        y = self.linear(outputs)

        # Attention weights for Visualization
        self.weights_list = list
        self.weights_list.append(weights)
        self.weights_list.append(reverse_weights)
        return y, self.weights_list

class bi_directional_self_attention_layers(nn.Module):
    def __init__(self, embedding_size, hidden_size, num_layers, num_heads, total_key_depth, total_value_depth,
                 filter_size, max_length=100, input_dropout=0.0, layer_dropout=0.0,
                 attention_dropout=0.0, relu_dropout=0.0):
        super(bi_directional_self_attention_layers, self).__init__()

        self.timing_signal = _gen_timing_signal(max_length, hidden_size)
        params = (hidden_size,
                  total_key_depth or hidden_size,
                  total_value_depth or hidden_size,
                  filter_size,
                  num_heads,
                  max_length,
                  layer_dropout,
                  attention_dropout,
                  relu_dropout)
        self.embedding_proj = nn.Linear(embedding_size, hidden_size, bias=False)
        self.self_attn_layers = nn.Sequential(*[bi_directional_self_attention(*params) for l in range(num_layers)])
        self.layer_norm = LayerNorm(hidden_size)
        self.input_dropout = nn.Dropout(input_dropout)

    def forward(self, inputs):
        # Add input dropout
        x = self.input_dropout(inputs)

        # Project to hidden size
        x = self.embedding_proj(x)

        # Add timing signal
        x += self.timing_signal[:, :inputs.shape[1], :].type_as(inputs.data)

        # A Stack of Bi-directional Self-attention Layers
        y, weights_list = self.self_attn_layers((x, []))

        # Layer Normalization
        y = self.layer_norm(y)
        return y, weights_list

class BTC_model(nn.Module):
    def __init__(self, config):
        super(BTC_model, self).__init__()

        self.timestep = config['timestep']
        self.probs_out = config['probs_out']

        params = (config['feature_size'],
                  config['hidden_size'],
                  config['num_layers'],
                  config['num_heads'],
                  config['total_key_depth'],
                  config['total_value_depth'],
                  config['filter_size'],
                  config['timestep'],
                  config['input_dropout'],
                  config['layer_dropout'],
                  config['attention_dropout'],
                  config['relu_dropout'])

        self.self_attn_layers = bi_directional_self_attention_layers(*params)
        self.output_layer = SoftmaxOutputLayer(hidden_size=config['hidden_size'], output_size=config['num_chords'], probs_out=config['probs_out'])

    def forward(self, x, labels):
        labels = labels.view(-1, self.timestep)
        # Output of Bi-directional Self-attention Layers
        self_attn_output, weights_list = self.self_attn_layers(x)

        # return logit values for CRF
        if self.probs_out is True:
            logits = self.output_layer(self_attn_output)
            return logits

        # Output layer and Soft-max
        prediction, second = self.output_layer(self_attn_output)
        prediction = prediction.view(-1)
        second = second.view(-1)

        # Loss Calculation
        loss = self.output_layer.loss(self_attn_output, labels)
        return prediction, loss, weights_list, second

if __name__ == "__main__":
    config = HParams.load("run_config.yaml")
    device = torch.device("cuda" if use_cuda else "cpu")

    batch_size = 2
    timestep = 108
    feature_size = 144
    num_chords = 25

    features = torch.randn(batch_size, timestep, feature_size, requires_grad=True).to(device)
    chords = torch.randint(25, (batch_size * timestep,)).to(device)

    model = BTC_model(config=config.model).to(device)

    prediction, loss, weights_list, second = model(features, chords)
    print(prediction.size())
    print(loss)
utils/chords.py
ADDED
@@ -0,0 +1,542 @@
# encoding: utf-8
"""
This module contains chord evaluation functionality.

It provides the evaluation measures used for the MIREX ACE task, and
tries to follow [1]_ and [2]_ as closely as possible.

Notes
-----
This implementation tries to follow the references and their implementation
(e.g., https://github.com/jpauwels/MusOOEvaluator for [2]_). However, there
are some known (and possibly some unknown) differences. If you find one not
listed in the following, please file an issue:

 - Detected chord segments are adjusted to fit the length of the annotations.
   In particular, this means that, if necessary, filler segments of 'no chord'
   are added at beginnings and ends. This can result in different segmentation
   scores compared to the original implementation.

References
----------
.. [1] Christopher Harte, "Towards Automatic Extraction of Harmony Information
       from Music Signals." Dissertation,
       Department for Electronic Engineering, Queen Mary University of London,
       2010.
.. [2] Johan Pauwels and Geoffroy Peeters.
       "Evaluating Automatically Estimated Chord Sequences."
       In Proceedings of ICASSP 2013, Vancouver, Canada, 2013.

"""

import numpy as np
import pandas as pd
import mir_eval


CHORD_DTYPE = [('root', np.int),
               ('bass', np.int),
               ('intervals', np.int, (12,)),
               ('is_major', np.bool)]

CHORD_ANN_DTYPE = [('start', np.float),
                   ('end', np.float),
                   ('chord', CHORD_DTYPE)]

NO_CHORD = (-1, -1, np.zeros(12, dtype=np.int), False)
UNKNOWN_CHORD = (-1, -1, np.ones(12, dtype=np.int) * -1, False)

PITCH_CLASS = ['C', 'C#', 'D', 'D#', 'E', 'F', 'F#', 'G', 'G#', 'A', 'A#', 'B']


def idx_to_chord(idx):
    if idx == 24:
        return "-"
    elif idx == 25:
        return u"\u03B5"

    minmaj = idx % 2
    root = idx // 2

    return PITCH_CLASS[root] + ("M" if minmaj == 0 else "m")

class Chords:

    def __init__(self):
        self._shorthands = {
            'maj': self.interval_list('(1,3,5)'),
            'min': self.interval_list('(1,b3,5)'),
            'dim': self.interval_list('(1,b3,b5)'),
            'aug': self.interval_list('(1,3,#5)'),
            'maj7': self.interval_list('(1,3,5,7)'),
            'min7': self.interval_list('(1,b3,5,b7)'),
            '7': self.interval_list('(1,3,5,b7)'),
            '6': self.interval_list('(1,6)'),  # custom
            '5': self.interval_list('(1,5)'),
            '4': self.interval_list('(1,4)'),  # custom
            '1': self.interval_list('(1)'),
            'dim7': self.interval_list('(1,b3,b5,bb7)'),
            'hdim7': self.interval_list('(1,b3,b5,b7)'),
            'minmaj7': self.interval_list('(1,b3,5,7)'),
            'maj6': self.interval_list('(1,3,5,6)'),
            'min6': self.interval_list('(1,b3,5,6)'),
            '9': self.interval_list('(1,3,5,b7,9)'),
            'maj9': self.interval_list('(1,3,5,7,9)'),
            'min9': self.interval_list('(1,b3,5,b7,9)'),
            'sus2': self.interval_list('(1,2,5)'),
            'sus4': self.interval_list('(1,4,5)'),
            '11': self.interval_list('(1,3,5,b7,9,11)'),
            'min11': self.interval_list('(1,b3,5,b7,9,11)'),
            '13': self.interval_list('(1,3,5,b7,13)'),
            'maj13': self.interval_list('(1,3,5,7,13)'),
            'min13': self.interval_list('(1,b3,5,b7,13)')
        }

    def chords(self, labels):

        """
        Transform a list of chord labels into an array of internal numeric
        representations.

        Parameters
        ----------
        labels : list
            List of chord labels (str).

        Returns
        -------
        chords : numpy.array
            Structured array with columns 'root', 'bass', and 'intervals',
            containing a numeric representation of chords.

        """
        crds = np.zeros(len(labels), dtype=CHORD_DTYPE)
        cache = {}
        for i, lbl in enumerate(labels):
            cv = cache.get(lbl, None)
            if cv is None:
                cv = self.chord(lbl)
                cache[lbl] = cv
            crds[i] = cv

        return crds

    def label_error_modify(self, label):
        if label == 'Emin/4': label = 'E:min/4'
        elif label == 'A7/3': label = 'A:7/3'
        elif label == 'Bb7/3': label = 'Bb:7/3'
        elif label == 'Bb7/5': label = 'Bb:7/5'
        elif label.find(':') == -1:
            if label.find('min') != -1:
                label = label[:label.find('min')] + ':' + label[label.find('min'):]
        return label

    def chord(self, label):
        """
        Transform a chord label into the internal numeric representation of
        (root, bass, intervals array).

        Parameters
        ----------
        label : str
            Chord label.

        Returns
        -------
        chord : tuple
            Numeric representation of the chord: (root, bass, intervals array).

        """

        try:
            is_major = False

            if label == 'N':
                return NO_CHORD
            if label == 'X':
                return UNKNOWN_CHORD

            label = self.label_error_modify(label)

            c_idx = label.find(':')
            s_idx = label.find('/')

            if c_idx == -1:
                quality_str = 'maj'
                if s_idx == -1:
                    root_str = label
                    bass_str = ''
                else:
                    root_str = label[:s_idx]
                    bass_str = label[s_idx + 1:]
            else:
                root_str = label[:c_idx]
                if s_idx == -1:
                    quality_str = label[c_idx + 1:]
                    bass_str = ''
                else:
                    quality_str = label[c_idx + 1:s_idx]
                    bass_str = label[s_idx + 1:]

            root = self.pitch(root_str)
            bass = self.interval(bass_str) if bass_str else 0
            ivs = self.chord_intervals(quality_str)
            ivs[bass] = 1

            if 'min' in quality_str:
                is_major = False
            else:
                is_major = True

        except Exception as e:
            print(e, label)

        return root, bass, ivs, is_major

    _l = [0, 1, 1, 0, 1, 1, 1]
    _chroma_id = (np.arange(len(_l) * 2) + 1) + np.array(_l + _l).cumsum() - 1

    def modify(self, base_pitch, modifier):
        """
        Modify a pitch class in integer representation by a given modifier string.

        A modifier string can be any sequence of 'b' (one semitone down)
        and '#' (one semitone up).

        Parameters
        ----------
        base_pitch : int
            Pitch class as integer.
        modifier : str
            String of modifiers ('b' or '#').

        Returns
        -------
        modified_pitch : int
            Modified root note.

        """
        for m in modifier:
            if m == 'b':
                base_pitch -= 1
            elif m == '#':
                base_pitch += 1
            else:
                raise ValueError('Unknown modifier: {}'.format(m))
        return base_pitch

    def pitch(self, pitch_str):
        """
        Convert a string representation of a pitch class (consisting of root
        note and modifiers) to an integer representation.

        Parameters
        ----------
        pitch_str : str
            String representation of a pitch class.

        Returns
        -------
        pitch : int
            Integer representation of a pitch class.

        """
        return self.modify(self._chroma_id[(ord(pitch_str[0]) - ord('C')) % 7],
                           pitch_str[1:]) % 12

    def interval(self, interval_str):
        """
        Convert a string representation of a musical interval into a pitch class
        (e.g. a minor seventh 'b7' into 10, because it is 10 semitones above its
        base note).

        Parameters
        ----------
        interval_str : str
            Musical interval.

        Returns
        -------
        pitch_class : int
            Number of semitones to base note of interval.

        """
        for i, c in enumerate(interval_str):
            if c.isdigit():
                return self.modify(self._chroma_id[int(interval_str[i:]) - 1],
                                   interval_str[:i]) % 12

    def interval_list(self, intervals_str, given_pitch_classes=None):
        """
        Convert a list of intervals given as string to a binary pitch class
        representation. For example, 'b3, 5' would become
        [0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0].

        Parameters
        ----------
        intervals_str : str
            List of intervals as comma-separated string (e.g. 'b3, 5').
        given_pitch_classes : None or numpy array
            If None, start with empty pitch class array, if numpy array of length
            12, this array will be modified.

        Returns
        -------
        pitch_classes : numpy array
            Binary pitch class representation of intervals.

        """
        if given_pitch_classes is None:
            given_pitch_classes = np.zeros(12, dtype=np.int)
        for int_def in intervals_str[1:-1].split(','):
            int_def = int_def.strip()
            if int_def[0] == '*':
                given_pitch_classes[self.interval(int_def[1:])] = 0
            else:
                given_pitch_classes[self.interval(int_def)] = 1
        return given_pitch_classes

    # mapping of shorthand interval notations to the actual interval representation

    def chord_intervals(self, quality_str):
        """
        Convert a chord quality string to a pitch class representation. For
        example, 'maj' becomes [1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0].

        Parameters
        ----------
        quality_str : str
            String defining the chord quality.

        Returns
        -------
        pitch_classes : numpy array
            Binary pitch class representation of chord quality.

        """
        list_idx = quality_str.find('(')
        if list_idx == -1:
            return self._shorthands[quality_str].copy()
        if list_idx != 0:
            ivs = self._shorthands[quality_str[:list_idx]].copy()
        else:
            ivs = np.zeros(12, dtype=np.int)

        return self.interval_list(quality_str[list_idx:], ivs)

    def load_chords(self, filename):
        """
        Load chords from a text file.

        The chord must follow the syntax defined in [1]_.

        Parameters
        ----------
        filename : str
            File containing chord segments.

        Returns
        -------
        crds : numpy structured array
            Structured array with columns "start", "end", and "chord",
            containing the beginning, end, and chord definition of chord
            segments.

        References
        ----------
        .. [1] Christopher Harte, "Towards Automatic Extraction of Harmony
               Information from Music Signals." Dissertation,
               Department for Electronic Engineering, Queen Mary University of
               London, 2010.

        """
        start, end, chord_labels = [], [], []
        with open(filename, 'r') as f:
            for line in f:
                if line:

                    splits = line.split()
                    if len(splits) == 3:

                        s = splits[0]
                        e = splits[1]
                        l = splits[2]

                        start.append(float(s))
                        end.append(float(e))
                        chord_labels.append(l)

        crds = np.zeros(len(start), dtype=CHORD_ANN_DTYPE)
        crds['start'] = start
        crds['end'] = end
        crds['chord'] = self.chords(chord_labels)

        return crds

    def reduce_to_triads(self, chords, keep_bass=False):
        """
        Reduce chords to triads.

        The function follows the reduction rules implemented in [1]_. If a
        chord does not contain a third, major second or fourth, it is reduced to
        a power chord. If it contains neither a third nor a fifth, it is
        reduced to a single note "chord".

        Parameters
        ----------
        chords : numpy structured array
            Chords to be reduced.
        keep_bass : bool
            Indicates whether to keep the bass note or set it to 0.

        Returns
        -------
        reduced_chords : numpy structured array
            Chords reduced to triads.

        References
        ----------
        .. [1] Johan Pauwels and Geoffroy Peeters.
               "Evaluating Automatically Estimated Chord Sequences."
               In Proceedings of ICASSP 2013, Vancouver, Canada, 2013.

        """
        unison = chords['intervals'][:, 0].astype(bool)
        maj_sec = chords['intervals'][:, 2].astype(bool)
        min_third = chords['intervals'][:, 3].astype(bool)
        maj_third = chords['intervals'][:, 4].astype(bool)
        perf_fourth = chords['intervals'][:, 5].astype(bool)
        dim_fifth = chords['intervals'][:, 6].astype(bool)
        perf_fifth = chords['intervals'][:, 7].astype(bool)
        aug_fifth = chords['intervals'][:, 8].astype(bool)
        no_chord = (chords['intervals'] == NO_CHORD[-1]).all(axis=1)

        reduced_chords = chords.copy()
        ivs = reduced_chords['intervals']

        ivs[~no_chord] = self.interval_list('(1)')
        ivs[unison & perf_fifth] = self.interval_list('(1,5)')
        ivs[~perf_fourth & maj_sec] = self._shorthands['sus2']
        ivs[perf_fourth & ~maj_sec] = self._shorthands['sus4']

        ivs[min_third] = self._shorthands['min']
        ivs[min_third & aug_fifth & ~perf_fifth] = self.interval_list('(1,b3,#5)')
        ivs[min_third & dim_fifth & ~perf_fifth] = self._shorthands['dim']

        ivs[maj_third] = self._shorthands['maj']
        ivs[maj_third & dim_fifth & ~perf_fifth] = self.interval_list('(1,3,b5)')
        ivs[maj_third & aug_fifth & ~perf_fifth] = self._shorthands['aug']

        if not keep_bass:
            reduced_chords['bass'] = 0
        else:
            # remove bass notes if they are not part of the intervals anymore
            reduced_chords['bass'] *= ivs[range(len(reduced_chords)),
                                          reduced_chords['bass']]
        # keep -1 in bass for no chords
        reduced_chords['bass'][no_chord] = -1

        return reduced_chords

    def convert_to_id(self, root, is_major):
        if root == -1:
            return 24
        else:
            if is_major:
                return root * 2
            else:
                return root * 2 + 1

    def get_converted_chord(self, filename):
        loaded_chord = self.load_chords(filename)
        triads = self.reduce_to_triads(loaded_chord['chord'])

        df = self.assign_chord_id(triads)
        df['start'] = loaded_chord['start']
        df['end'] = loaded_chord['end']

        return df

    def assign_chord_id(self, entry):
        # maj, min chords only
        # if you want to add other chords, change this part and get_converted_chord (reduce_to_triads)
        df = pd.DataFrame(data=entry[['root', 'is_major']])
        df['chord_id'] = df.apply(lambda row: self.convert_to_id(row['root'], row['is_major']), axis=1)
        return df

    def convert_to_id_voca(self, root, quality):
        if root == -1:
            return 169
        else:
            if quality == 'min':
                return root * 14
            elif quality == 'maj':
                return root * 14 + 1
            elif quality == 'dim':
                return root * 14 + 2
            elif quality == 'aug':
                return root * 14 + 3
            elif quality == 'min6':
                return root * 14 + 4
            elif quality == 'maj6':
                return root * 14 + 5
            elif quality == 'min7':
                return root * 14 + 6
            elif quality == 'minmaj7':
                return root * 14 + 7
            elif quality == 'maj7':
                return root * 14 + 8
            elif quality == '7':
                return root * 14 + 9
            elif quality == 'dim7':
                return root * 14 + 10
            elif quality == 'hdim7':
                return root * 14 + 11
            elif quality == 'sus2':
                return root * 14 + 12
            elif quality == 'sus4':
                return root * 14 + 13
            else:
                return 168

    def get_converted_chord_voca(self, filename):
        loaded_chord = self.load_chords(filename)
        triads = self.reduce_to_triads(loaded_chord['chord'])
        df = pd.DataFrame(data=triads[['root', 'is_major']])

        (ref_intervals, ref_labels) = mir_eval.io.load_labeled_intervals(filename)
        ref_labels = self.lab_file_error_modify(ref_labels)
        idxs = list()
        for i in ref_labels:
            chord_root, quality, scale_degrees, bass = mir_eval.chord.split(i, reduce_extended_chords=True)
            root, bass, ivs, is_major = self.chord(i)
            idxs.append(self.convert_to_id_voca(root=root, quality=quality))
        df['chord_id'] = idxs

        df['start'] = loaded_chord['start']
        df['end'] = loaded_chord['end']

        return df

    def lab_file_error_modify(self, ref_labels):
        for i in range(len(ref_labels)):
            if ref_labels[i][-2:] == ':4':
                ref_labels[i] = ref_labels[i].replace(':4', ':sus4')
            elif ref_labels[i][-2:] == ':6':
                ref_labels[i] = ref_labels[i].replace(':6', ':maj6')
            elif ref_labels[i][-4:] == ':6/2':
                ref_labels[i] = ref_labels[i].replace(':6/2', ':maj6/2')
            elif ref_labels[i] == 'Emin/4':
                ref_labels[i] = 'E:min/4'
            elif ref_labels[i] == 'A7/3':
                ref_labels[i] = 'A:7/3'
            elif ref_labels[i] == 'Bb7/3':
                ref_labels[i] = 'Bb:7/3'
            elif ref_labels[i] == 'Bb7/5':
                ref_labels[i] = 'Bb:7/5'
            elif ref_labels[i].find(':') == -1:
                if ref_labels[i].find('min') != -1:
                    ref_labels[i] = ref_labels[i][:ref_labels[i].find('min')] + ':' + ref_labels[i][ref_labels[i].find('min'):]
        return ref_labels
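
A minimal usage sketch for the module above (not part of the uploaded files): it converts one chord label to the internal (root, bass, intervals, is_major) representation and maps it to the 25-class maj/min id. The example label 'G:min7/b7' is arbitrary, and the import path assumes the utils package is on the Python path; it also assumes an older NumPy where the np.int / np.bool aliases used above still exist.

from utils.chords import Chords, idx_to_chord

chord_tool = Chords()

# 'G:min7/b7' -> root 7 (G), bass 10 (b7), 12-d interval vector, is_major=False
root, bass, intervals, is_major = chord_tool.chord('G:min7/b7')

# map (root, is_major) to the 25-class maj/min vocabulary and back to a name
chord_id = chord_tool.convert_to_id(root, is_major)   # 7 * 2 + 1 = 15
print(chord_id, idx_to_chord(chord_id))               # 15 Gm
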
utils/constants.py
ADDED
@@ -0,0 +1,60 @@
### DEPRECATED - use hydra conf instead ######

import torch
import os

# --------------------------------------- #
VERSION = "1.24"

# --------------------------------------- #
ENCODER = "MERT"

# - - -
# MERT
# M2L
# LIBROSA
# - - -
# Encodec
# DAC

# --------------------------------------- #

SEGMENT = "all"
# all
# f10s - first 10s
# f30s - first 30s
# 10s
# 30s

AGGREGATION_METHOD = "mean"
# mean
# median
# 80th_percentile
# max

# --------------------------------------- #
CLASSIFIER = "linear-mt"
# transformer
# linear
# linear-small
# linear-multitask
# linear-small-multitask
# linear-mt (mert-like classifier)
#
# --------------------------------------- #
CHECKPOINT = "tb_logs/train_audio_classification/version_110/checkpoints/21-0.1202.ckpt"
# --------------------------------------- #
BATCH_SIZE = 8
N_EPOCHS = 50

# --------------------------------------- #
GENRE_CLASS_SIZE = 87
MOOD_CLASS_SIZE = 56
INSTR_CLASS_SIZE = 40
DAC_LATENTS_SIZE = 72
DAC_RVQ_SIZE = 9
# --------------------------------------- #
utils/custom_early_stopping.py
ADDED
@@ -0,0 +1,93 @@
# custom_early_stopping.py

import pytorch_lightning as pl
from pytorch_lightning.callbacks.early_stopping import EarlyStopping


class MultiMetricEarlyStopping(EarlyStopping):
    def __init__(self, monitor_mood, monitor_va, patience, min_delta, mode="min"):
        super().__init__(monitor=None, patience=patience, min_delta=min_delta, mode=mode)
        self.monitor_mood = monitor_mood
        self.monitor_va = monitor_va
        self.patience = patience
        self.min_delta = min_delta
        self.mode = mode

        # Initialize tracking variables
        self.wait_mood = 0
        self.wait_va = 0
        self.best_mood = float('inf') if mode == "min" else -float('inf')
        self.best_va = float('inf') if mode == "min" else -float('inf')

    def _check_stop(self, current, best, wait):
        if self.mode == "min" and current < best - self.min_delta:
            return current, 0
        elif self.mode == "max" and current > best + self.min_delta:
            return current, 0
        else:
            return best, wait + 1

    def on_validation_epoch_end(self, trainer, pl_module):
        logs = trainer.callback_metrics

        if self.monitor_mood not in logs or self.monitor_va not in logs:
            raise RuntimeError(f"Metrics {self.monitor_mood} or {self.monitor_va} not available.")

        # Get current values for the monitored metrics
        current_mood = logs[self.monitor_mood].item()
        current_va = logs[self.monitor_va].item()

        # Check stopping conditions for both metrics
        self.best_mood, self.wait_mood = self._check_stop(current_mood, self.best_mood, self.wait_mood)
        self.best_va, self.wait_va = self._check_stop(current_va, self.best_va, self.wait_va)

        # Stop if patience exceeded for both metrics
        if self.wait_mood > self.patience and self.wait_va > self.patience:
            self.stopped_epoch = trainer.current_epoch
            trainer.should_stop = True

# # custom_early_stopping.py

# import pytorch_lightning as pl
# from pytorch_lightning.callbacks.early_stopping import EarlyStopping

# class MultiMetricEarlyStopping(EarlyStopping):
#     def __init__(self, monitor_mood: str, monitor_va: str, patience: int = 10, min_delta: float = 0.0, mode: str = "min"):
#         super().__init__(monitor=None, patience=patience, min_delta=min_delta, mode=mode)
#         self.monitor_mood = monitor_mood
#         self.monitor_va = monitor_va
#         self.wait_mood = 0
#         self.wait_va = 0
#         self.best_mood_score = None
#         self.best_va_score = None
#         self.patience = patience
#         self.stopped_epoch = 0

#     def on_validation_end(self, trainer, pl_module):
#         current_mood = trainer.callback_metrics.get(self.monitor_mood)
#         current_va = trainer.callback_metrics.get(self.monitor_va)

#         # Check if current_mood improved
#         if self.best_mood_score is None or self._compare(current_mood, self.best_mood_score):
#             self.best_mood_score = current_mood
#             self.wait_mood = 0
#         else:
#             self.wait_mood += 1

#         # Check if current_va improved
#         if self.best_va_score is None or self._compare(current_va, self.best_va_score):
#             self.best_va_score = current_va
#             self.wait_va = 0
#         else:
#             self.wait_va += 1

#         # If both metrics are stagnant for patience epochs, stop training
#         if self.wait_mood >= self.patience and self.wait_va >= self.patience:
#             self.stopped_epoch = trainer.current_epoch
#             trainer.should_stop = True

#     def _compare(self, current, best):
#         if self.mode == "min":
#             return current < best - self.min_delta
#         else:
#             return current > best + self.min_delta
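
A minimal usage sketch for the callback above (not part of the uploaded files): it wires MultiMetricEarlyStopping into a pytorch_lightning Trainer. The metric names "val_mood_loss" and "val_va_loss" are hypothetical and would need to match whatever the LightningModule logs during validation; it also assumes a pytorch_lightning version compatible with the callback as written (monitor=None passed to the base class).

import pytorch_lightning as pl
from utils.custom_early_stopping import MultiMetricEarlyStopping

early_stop = MultiMetricEarlyStopping(
    monitor_mood="val_mood_loss",   # hypothetical metric name
    monitor_va="val_va_loss",       # hypothetical metric name
    patience=10,
    min_delta=0.0,
    mode="min",
)

trainer = pl.Trainer(max_epochs=50, callbacks=[early_stop])
# trainer.fit(model, datamodule=dm)  # model and datamodule defined elsewhere
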
utils/hparams.py
ADDED
@@ -0,0 +1,37 @@
import yaml


# TODO: the add function should be changed
class HParams(object):
    # Hyperparameter class using yaml
    def __init__(self, **kwargs):
        self.__dict__ = kwargs

    def add(self, **kwargs):
        # change is needed - if the key already exists, do not update it.
        self.__dict__.update(kwargs)

    def update(self, **kwargs):
        self.__dict__.update(kwargs)
        return self

    def save(self, path):
        with open(path, 'w') as f:
            yaml.dump(self.__dict__, f)
        return self

    def __repr__(self):
        return '\nHyperparameters:\n' + '\n'.join(['    {}={}'.format(k, v) for k, v in self.__dict__.items()])

    @classmethod
    def load(cls, path):
        with open(path, 'r') as f:
            return cls(**yaml.load(f, Loader=yaml.FullLoader))


if __name__ == '__main__':
    hparams = HParams.load('hparams.yaml')
    print(hparams)
    d = {"MemoryNetwork": 0, "c": 1}
    hparams.add(**d)
    print(hparams)
utils/logger.py
ADDED
@@ -0,0 +1,72 @@
import logging
import os
import sys
import time


project_name = os.getcwd().split('/')[-1]
_logger = logging.getLogger(project_name)
_logger.addHandler(logging.StreamHandler())

def _log_prefix():

    # Returns (filename, line number) for the stack frame.
    def _get_file_line():

        # pylint: disable=protected-access
        # noinspection PyProtectedMember
        f = sys._getframe()
        # pylint: enable=protected-access
        our_file = f.f_code.co_filename
        f = f.f_back
        while f:
            code = f.f_code
            if code.co_filename != our_file:
                return code.co_filename, f.f_lineno
            f = f.f_back
        return '<unknown>', 0

    # current time
    now = time.time()
    now_tuple = time.localtime(now)
    now_millisecond = int(1e3 * (now % 1.0))

    # current filename and line
    filename, line = _get_file_line()
    basename = os.path.basename(filename)

    s = '%02d-%02d %02d:%02d:%02d.%03d %s:%d] ' % (
        now_tuple[1],  # month
        now_tuple[2],  # day
        now_tuple[3],  # hour
        now_tuple[4],  # min
        now_tuple[5],  # sec
        now_millisecond,
        basename,
        line)

    return s


def logging_verbosity(verbosity=0):
    _logger.setLevel(verbosity)


def debug(msg, *args, **kwargs):
    _logger.debug('D ' + project_name + ' ' + _log_prefix() + msg, *args, **kwargs)


def info(msg, *args, **kwargs):
    _logger.info('I ' + project_name + ' ' + _log_prefix() + msg, *args, **kwargs)


def warn(msg, *args, **kwargs):
    _logger.warning('W ' + project_name + ' ' + _log_prefix() + msg, *args, **kwargs)


def error(msg, *args, **kwargs):
    _logger.error('E ' + project_name + ' ' + _log_prefix() + msg, *args, **kwargs)


def fatal(msg, *args, **kwargs):
    _logger.fatal('F ' + project_name + ' ' + _log_prefix() + msg, *args, **kwargs)
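
A minimal usage sketch for the logger module above (not part of the uploaded files): the module-level helpers prepend a severity letter, the working-directory name, and a file:line prefix to each message. It assumes the utils package (with its empty __init__.py) is importable.

from utils import logger

logger.logging_verbosity(20)            # 20 == logging.INFO
logger.info('loaded %d files', 22)
logger.warn('missing lab file for %s', 'some_song')
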
utils/mert.py
ADDED
@@ -0,0 +1,32 @@
import torch
import numpy as np
from transformers import Wav2Vec2FeatureExtractor, AutoModel

class FeatureExtractorMERT:
    def __init__(self, model_name="m-a-p/MERT-v1-95M", device="None", sr=24000):
        self.model_name = model_name
        self.sr = sr
        if device == "None":
            use_cuda = torch.cuda.is_available()
            # store on self so the model below is moved to the selected device
            self.device = torch.device("cuda" if use_cuda else "cpu")
        else:
            self.device = device

        self.model = AutoModel.from_pretrained(self.model_name, trust_remote_code=True).to(self.device)
        self.processor = Wav2Vec2FeatureExtractor.from_pretrained(self.model_name, trust_remote_code=True)

    def extract_features_from_segment(self, segment, sample_rate, save_path):
        input_audio = segment.float()
        model_inputs = self.processor(input_audio, sampling_rate=sample_rate, return_tensors="pt")
        model_inputs = model_inputs.to(self.device)

        with torch.no_grad():
            model_outputs = self.model(**model_inputs, output_hidden_states=True)

        # Stack and process hidden states
        all_layer_hidden_states = torch.stack(model_outputs.hidden_states).squeeze()[1:, :, :].unsqueeze(0)
        all_layer_hidden_states = all_layer_hidden_states.mean(dim=2)
        features = all_layer_hidden_states.cpu().detach().numpy()

        # Save features
        np.save(save_path, features)
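
A minimal usage sketch for the extractor above (not part of the uploaded files): it loads one audio file, resamples it to the 24 kHz rate MERT expects, and writes the layer-averaged hidden states to disk. The file names are hypothetical, and torchaudio is only used here as one way to obtain a mono tensor at the right sample rate.

import torch
import torchaudio
from utils.mert import FeatureExtractorMERT

extractor = FeatureExtractorMERT(device="cuda" if torch.cuda.is_available() else "cpu")

wav, sr = torchaudio.load("example.mp3")                                  # hypothetical input file
wav = torchaudio.functional.resample(wav, sr, extractor.sr).mean(dim=0)   # mono, 24 kHz

# Saves an array of shape (1, num_layers, hidden_size): hidden states averaged over time.
extractor.extract_features_from_segment(wav, extractor.sr, "example_mert.npy")
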
utils/mir_eval_modules.py
ADDED
@@ -0,0 +1,486 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import librosa
|
3 |
+
import mir_eval
|
4 |
+
import torch
|
5 |
+
import os
|
6 |
+
|
7 |
+
idx2chord = ['C', 'C:min', 'C#', 'C#:min', 'D', 'D:min', 'D#', 'D#:min', 'E', 'E:min', 'F', 'F:min', 'F#',
|
8 |
+
'F#:min', 'G', 'G:min', 'G#', 'G#:min', 'A', 'A:min', 'A#', 'A#:min', 'B', 'B:min', 'N']
|
9 |
+
|
10 |
+
root_list = ['C', 'C#', 'D', 'D#', 'E', 'F', 'F#', 'G', 'G#', 'A', 'A#', 'B']
|
11 |
+
quality_list = ['min', 'maj', 'dim', 'aug', 'min6', 'maj6', 'min7', 'minmaj7', 'maj7', '7', 'dim7', 'hdim7', 'sus2', 'sus4']
|
12 |
+
|
13 |
+
def idx2voca_chord():
|
14 |
+
idx2voca_chord = {}
|
15 |
+
idx2voca_chord[169] = 'N'
|
16 |
+
idx2voca_chord[168] = 'X'
|
17 |
+
for i in range(168):
|
18 |
+
root = i // 14
|
19 |
+
root = root_list[root]
|
20 |
+
quality = i % 14
|
21 |
+
quality = quality_list[quality]
|
22 |
+
if i % 14 != 1:
|
23 |
+
chord = root + ':' + quality
|
24 |
+
else:
|
25 |
+
chord = root
|
26 |
+
idx2voca_chord[i] = chord
|
27 |
+
return idx2voca_chord
|
28 |
+
|
29 |
+
def audio_file_to_features(audio_file, config):
|
30 |
+
original_wav, sr = librosa.load(audio_file, sr=config.mp3['song_hz'], mono=True)
|
31 |
+
currunt_sec_hz = 0
|
32 |
+
while len(original_wav) > currunt_sec_hz + config.mp3['song_hz'] * config.mp3['inst_len']:
|
33 |
+
start_idx = int(currunt_sec_hz)
|
34 |
+
end_idx = int(currunt_sec_hz + config.mp3['song_hz'] * config.mp3['inst_len'])
|
35 |
+
tmp = librosa.cqt(original_wav[start_idx:end_idx], sr=sr, n_bins=config.feature['n_bins'], bins_per_octave=config.feature['bins_per_octave'], hop_length=config.feature['hop_length'])
|
36 |
+
if start_idx == 0:
|
37 |
+
feature = tmp
|
38 |
+
else:
|
39 |
+
feature = np.concatenate((feature, tmp), axis=1)
|
40 |
+
currunt_sec_hz = end_idx
|
41 |
+
tmp = librosa.cqt(original_wav[currunt_sec_hz:], sr=sr, n_bins=config.feature['n_bins'], bins_per_octave=config.feature['bins_per_octave'], hop_length=config.feature['hop_length'])
|
42 |
+
feature = np.concatenate((feature, tmp), axis=1)
|
43 |
+
feature = np.log(np.abs(feature) + 1e-6)
|
44 |
+
feature_per_second = config.mp3['inst_len'] / config.model['timestep']
|
45 |
+
song_length_second = len(original_wav)/config.mp3['song_hz']
|
46 |
+
return feature, feature_per_second, song_length_second
|
47 |
+
|
48 |
+
# Audio files with format of wav and mp3
|
49 |
+
def get_audio_paths(audio_dir):
|
50 |
+
return [os.path.join(root, fname) for (root, dir_names, file_names) in os.walk(audio_dir, followlinks=True)
|
51 |
+
for fname in file_names if (fname.lower().endswith('.wav') or fname.lower().endswith('.mp3'))]
|
52 |
+
|
53 |
+
def get_lab_paths(lab_dir):
|
54 |
+
return [os.path.join(root, fname) for (root, dir_names, file_names) in os.walk(lab_dir, followlinks=True)
|
55 |
+
for fname in file_names if (fname.lower().endswith('.lab'))]
|
56 |
+
|
57 |
+
|
58 |
+
class metrics():
|
59 |
+
def __init__(self):
|
60 |
+
super(metrics, self).__init__()
|
61 |
+
self.score_metrics = ['root', 'thirds', 'triads', 'sevenths', 'tetrads', 'majmin', 'mirex']
|
62 |
+
self.score_list_dict = dict()
|
63 |
+
for i in self.score_metrics:
|
64 |
+
self.score_list_dict[i] = list()
|
65 |
+
self.average_score = dict()
|
66 |
+
|
67 |
+
def score(self, metric, gt_path, est_path):
|
68 |
+
if metric == 'root':
|
69 |
+
score = self.root_score(gt_path,est_path)
|
70 |
+
elif metric == 'thirds':
|
71 |
+
score = self.thirds_score(gt_path,est_path)
|
72 |
+
elif metric == 'triads':
|
73 |
+
score = self.triads_score(gt_path,est_path)
|
74 |
+
elif metric == 'sevenths':
|
75 |
+
score = self.sevenths_score(gt_path,est_path)
|
76 |
+
elif metric == 'tetrads':
|
77 |
+
score = self.tetrads_score(gt_path,est_path)
|
78 |
+
elif metric == 'majmin':
|
79 |
+
score = self.majmin_score(gt_path,est_path)
|
80 |
+
elif metric == 'mirex':
|
81 |
+
score = self.mirex_score(gt_path,est_path)
|
82 |
+
else:
|
83 |
+
raise NotImplementedError
|
84 |
+
return score
|
85 |
+
|
86 |
+
def root_score(self, gt_path, est_path):
|
87 |
+
(ref_intervals, ref_labels) = mir_eval.io.load_labeled_intervals(gt_path)
|
88 |
+
ref_labels = lab_file_error_modify(ref_labels)
|
89 |
+
(est_intervals, est_labels) = mir_eval.io.load_labeled_intervals(est_path)
|
90 |
+
est_intervals, est_labels = mir_eval.util.adjust_intervals(est_intervals, est_labels, ref_intervals.min(),
|
91 |
+
ref_intervals.max(), mir_eval.chord.NO_CHORD,
|
92 |
+
mir_eval.chord.NO_CHORD)
|
93 |
+
(intervals, ref_labels, est_labels) = mir_eval.util.merge_labeled_intervals(ref_intervals, ref_labels,
|
94 |
+
est_intervals, est_labels)
|
95 |
+
durations = mir_eval.util.intervals_to_durations(intervals)
|
96 |
+
comparisons = mir_eval.chord.root(ref_labels, est_labels)
|
97 |
+
score = mir_eval.chord.weighted_accuracy(comparisons, durations)
|
98 |
+
return score
|
99 |
+
|
100 |
+
def thirds_score(self, gt_path, est_path):
|
101 |
+
(ref_intervals, ref_labels) = mir_eval.io.load_labeled_intervals(gt_path)
|
102 |
+
ref_labels = lab_file_error_modify(ref_labels)
|
103 |
+
(est_intervals, est_labels) = mir_eval.io.load_labeled_intervals(est_path)
|
104 |
+
est_intervals, est_labels = mir_eval.util.adjust_intervals(est_intervals, est_labels, ref_intervals.min(),
|
105 |
+
ref_intervals.max(), mir_eval.chord.NO_CHORD,
|
106 |
+
mir_eval.chord.NO_CHORD)
|
107 |
+
(intervals, ref_labels, est_labels) = mir_eval.util.merge_labeled_intervals(ref_intervals, ref_labels,
|
108 |
+
est_intervals, est_labels)
|
109 |
+
durations = mir_eval.util.intervals_to_durations(intervals)
|
110 |
+
comparisons = mir_eval.chord.thirds(ref_labels, est_labels)
|
111 |
+
score = mir_eval.chord.weighted_accuracy(comparisons, durations)
|
112 |
+
return score
|
113 |
+
|
114 |
+
def triads_score(self, gt_path, est_path):
|
115 |
+
(ref_intervals, ref_labels) = mir_eval.io.load_labeled_intervals(gt_path)
|
116 |
+
ref_labels = lab_file_error_modify(ref_labels)
|
117 |
+
(est_intervals, est_labels) = mir_eval.io.load_labeled_intervals(est_path)
|
118 |
+
est_intervals, est_labels = mir_eval.util.adjust_intervals(est_intervals, est_labels, ref_intervals.min(),
|
119 |
+
ref_intervals.max(), mir_eval.chord.NO_CHORD,
|
120 |
+
mir_eval.chord.NO_CHORD)
|
121 |
+
(intervals, ref_labels, est_labels) = mir_eval.util.merge_labeled_intervals(ref_intervals, ref_labels,
|
122 |
+
est_intervals, est_labels)
|
123 |
+
durations = mir_eval.util.intervals_to_durations(intervals)
|
124 |
+
comparisons = mir_eval.chord.triads(ref_labels, est_labels)
|
125 |
+
score = mir_eval.chord.weighted_accuracy(comparisons, durations)
|
126 |
+
return score
|
127 |
+
|
128 |
+
def sevenths_score(self, gt_path, est_path):
|
129 |
+
(ref_intervals, ref_labels) = mir_eval.io.load_labeled_intervals(gt_path)
|
130 |
+
ref_labels = lab_file_error_modify(ref_labels)
|
131 |
+
(est_intervals, est_labels) = mir_eval.io.load_labeled_intervals(est_path)
|
132 |
+
est_intervals, est_labels = mir_eval.util.adjust_intervals(est_intervals, est_labels, ref_intervals.min(),
|
133 |
+
ref_intervals.max(), mir_eval.chord.NO_CHORD,
|
134 |
+
mir_eval.chord.NO_CHORD)
|
135 |
+
(intervals, ref_labels, est_labels) = mir_eval.util.merge_labeled_intervals(ref_intervals, ref_labels,
|
136 |
+
est_intervals, est_labels)
|
137 |
+
durations = mir_eval.util.intervals_to_durations(intervals)
|
138 |
+
comparisons = mir_eval.chord.sevenths(ref_labels, est_labels)
|
139 |
+
score = mir_eval.chord.weighted_accuracy(comparisons, durations)
|
140 |
+
return score
|
141 |
+
|
142 |
+
def tetrads_score(self, gt_path, est_path):
|
143 |
+
(ref_intervals, ref_labels) = mir_eval.io.load_labeled_intervals(gt_path)
|
144 |
+
ref_labels = lab_file_error_modify(ref_labels)
|
145 |
+
(est_intervals, est_labels) = mir_eval.io.load_labeled_intervals(est_path)
|
146 |
+
est_intervals, est_labels = mir_eval.util.adjust_intervals(est_intervals, est_labels, ref_intervals.min(),
|
147 |
+
ref_intervals.max(), mir_eval.chord.NO_CHORD,
|
148 |
+
mir_eval.chord.NO_CHORD)
|
149 |
+
(intervals, ref_labels, est_labels) = mir_eval.util.merge_labeled_intervals(ref_intervals, ref_labels,
|
150 |
+
est_intervals, est_labels)
|
151 |
+
durations = mir_eval.util.intervals_to_durations(intervals)
|
152 |
+
comparisons = mir_eval.chord.tetrads(ref_labels, est_labels)
|
153 |
+
score = mir_eval.chord.weighted_accuracy(comparisons, durations)
|
154 |
+
return score
|
155 |
+
|
156 |
+
def majmin_score(self, gt_path, est_path):
|
157 |
+
(ref_intervals, ref_labels) = mir_eval.io.load_labeled_intervals(gt_path)
|
158 |
+
ref_labels = lab_file_error_modify(ref_labels)
|
159 |
+
(est_intervals, est_labels) = mir_eval.io.load_labeled_intervals(est_path)
|
160 |
+
est_intervals, est_labels = mir_eval.util.adjust_intervals(est_intervals, est_labels, ref_intervals.min(),
|
161 |
+
ref_intervals.max(), mir_eval.chord.NO_CHORD,
|
162 |
+
mir_eval.chord.NO_CHORD)
|
163 |
+
(intervals, ref_labels, est_labels) = mir_eval.util.merge_labeled_intervals(ref_intervals, ref_labels,
|
164 |
+
est_intervals, est_labels)
|
165 |
+
durations = mir_eval.util.intervals_to_durations(intervals)
|
166 |
+
comparisons = mir_eval.chord.majmin(ref_labels, est_labels)
|
167 |
+
score = mir_eval.chord.weighted_accuracy(comparisons, durations)
|
168 |
+
return score
|
169 |
+
|
170 |
+
def mirex_score(self, gt_path, est_path):
|
171 |
+
(ref_intervals, ref_labels) = mir_eval.io.load_labeled_intervals(gt_path)
|
172 |
+
ref_labels = lab_file_error_modify(ref_labels)
|
173 |
+
(est_intervals, est_labels) = mir_eval.io.load_labeled_intervals(est_path)
|
174 |
+
est_intervals, est_labels = mir_eval.util.adjust_intervals(est_intervals, est_labels, ref_intervals.min(),
|
175 |
+
ref_intervals.max(), mir_eval.chord.NO_CHORD,
|
176 |
+
mir_eval.chord.NO_CHORD)
|
177 |
+
(intervals, ref_labels, est_labels) = mir_eval.util.merge_labeled_intervals(ref_intervals, ref_labels,
|
178 |
+
est_intervals, est_labels)
|
179 |
+
durations = mir_eval.util.intervals_to_durations(intervals)
|
180 |
+
comparisons = mir_eval.chord.mirex(ref_labels, est_labels)
|
181 |
+
score = mir_eval.chord.weighted_accuracy(comparisons, durations)
|
182 |
+
return score
|
183 |
+
|
184 |
+
def lab_file_error_modify(ref_labels):
|
185 |
+
for i in range(len(ref_labels)):
|
186 |
+
if ref_labels[i][-2:] == ':4':
|
187 |
+
ref_labels[i] = ref_labels[i].replace(':4', ':sus4')
|
188 |
+
elif ref_labels[i][-2:] == ':6':
|
189 |
+
ref_labels[i] = ref_labels[i].replace(':6', ':maj6')
|
190 |
+
elif ref_labels[i][-4:] == ':6/2':
|
191 |
+
ref_labels[i] = ref_labels[i].replace(':6/2', ':maj6/2')
|
192 |
+
elif ref_labels[i] == 'Emin/4':
|
193 |
+
ref_labels[i] = 'E:min/4'
|
194 |
+
elif ref_labels[i] == 'A7/3':
|
195 |
+
ref_labels[i] = 'A:7/3'
|
196 |
+
elif ref_labels[i] == 'Bb7/3':
|
197 |
+
ref_labels[i] = 'Bb:7/3'
|
198 |
+
elif ref_labels[i] == 'Bb7/5':
|
199 |
+
ref_labels[i] = 'Bb:7/5'
|
200 |
+
elif ref_labels[i].find(':') == -1:
|
201 |
+
if ref_labels[i].find('min') != -1:
|
202 |
+
ref_labels[i] = ref_labels[i][:ref_labels[i].find('min')] + ':' + ref_labels[i][ref_labels[i].find('min'):]
|
203 |
+
return ref_labels
|
204 |
+
|
205 |
+
def root_majmin_score_calculation(valid_dataset, config, mean, std, device, model, model_type, verbose=False):
    valid_song_names = valid_dataset.song_names
    paths = valid_dataset.preprocessor.get_all_files()

    metrics_ = metrics()
    song_length_list = list()
    for path in paths:
        song_name, lab_file_path, mp3_file_path, _ = path
        if not song_name in valid_song_names:
            continue
        try:
            n_timestep = config.model['timestep']
            feature, feature_per_second, song_length_second = audio_file_to_features(mp3_file_path, config)
            feature = feature.T
            feature = (feature - mean) / std
            time_unit = feature_per_second

            num_pad = n_timestep - (feature.shape[0] % n_timestep)
            feature = np.pad(feature, ((0, num_pad), (0, 0)), mode="constant", constant_values=0)
            num_instance = feature.shape[0] // n_timestep

            start_time = 0.0
            lines = []
            with torch.no_grad():
                model.eval()
                feature = torch.tensor(feature, dtype=torch.float32).unsqueeze(0).to(device)
                for t in range(num_instance):
                    if model_type == 'btc':
                        encoder_output, _ = model.self_attn_layers(feature[:, n_timestep * t:n_timestep * (t + 1), :])
                        prediction, _ = model.output_layer(encoder_output)
                        prediction = prediction.squeeze()
                    elif model_type == 'cnn' or model_type == 'crnn':
                        prediction, _, _, _ = model(feature[:, n_timestep * t:n_timestep * (t + 1), :], torch.randint(config.model['num_chords'], (n_timestep,)).to(device))
                    for i in range(n_timestep):
                        if t == 0 and i == 0:
                            prev_chord = prediction[i].item()
                            continue
                        if prediction[i].item() != prev_chord:
                            lines.append(
                                '%.6f %.6f %s\n' % (
                                    start_time, time_unit * (n_timestep * t + i), idx2chord[prev_chord]))
                            start_time = time_unit * (n_timestep * t + i)
                            prev_chord = prediction[i].item()
                        if t == num_instance - 1 and i + num_pad == n_timestep:
                            if start_time != time_unit * (n_timestep * t + i):
                                lines.append(
                                    '%.6f %.6f %s\n' % (
                                        start_time, time_unit * (n_timestep * t + i), idx2chord[prev_chord]))
                            break
            pid = os.getpid()
            tmp_path = 'tmp_' + str(pid) + '.lab'
            with open(tmp_path, 'w') as f:
                for line in lines:
                    f.write(line)

            root_majmin = ['root', 'majmin']
            for m in root_majmin:
                metrics_.score_list_dict[m].append(metrics_.score(metric=m, gt_path=lab_file_path, est_path=tmp_path))
            song_length_list.append(song_length_second)
            if verbose:
                for m in root_majmin:
                    print('song name %s, %s score : %.4f' % (song_name, m, metrics_.score_list_dict[m][-1]))
        except:
            print('song name %s\' lab file error' % song_name)

    tmp = song_length_list / np.sum(song_length_list)
    for m in root_majmin:
        metrics_.average_score[m] = np.sum(np.multiply(metrics_.score_list_dict[m], tmp))

    return metrics_.score_list_dict, song_length_list, metrics_.average_score


def root_majmin_score_calculation_crf(valid_dataset, config, mean, std, device, pre_model, model, model_type, verbose=False):
    valid_song_names = valid_dataset.song_names
    paths = valid_dataset.preprocessor.get_all_files()

    metrics_ = metrics()
    song_length_list = list()
    for path in paths:
        song_name, lab_file_path, mp3_file_path, _ = path
        if not song_name in valid_song_names:
            continue
        try:
            n_timestep = config.model['timestep']
            feature, feature_per_second, song_length_second = audio_file_to_features(mp3_file_path, config)
            feature = feature.T
            feature = (feature - mean) / std
            time_unit = feature_per_second

            num_pad = n_timestep - (feature.shape[0] % n_timestep)
            feature = np.pad(feature, ((0, num_pad), (0, 0)), mode="constant", constant_values=0)
            num_instance = feature.shape[0] // n_timestep

            start_time = 0.0
            lines = []
            with torch.no_grad():
                model.eval()
                feature = torch.tensor(feature, dtype=torch.float32).unsqueeze(0).to(device)
                for t in range(num_instance):
                    if (model_type == 'cnn') or (model_type == 'crnn') or (model_type == 'btc'):
                        logits = pre_model(feature[:, n_timestep * t:n_timestep * (t + 1), :], torch.randint(config.model['num_chords'], (n_timestep,)).to(device))
                        prediction, _ = model(logits, torch.randint(config.model['num_chords'], (n_timestep,)).to(device))
                    else:
                        raise NotImplementedError
                    for i in range(n_timestep):
                        if t == 0 and i == 0:
                            prev_chord = prediction[i].item()
                            continue
                        if prediction[i].item() != prev_chord:
                            lines.append(
                                '%.6f %.6f %s\n' % (
                                    start_time, time_unit * (n_timestep * t + i), idx2chord[prev_chord]))
                            start_time = time_unit * (n_timestep * t + i)
                            prev_chord = prediction[i].item()
                        if t == num_instance - 1 and i + num_pad == n_timestep:
                            if start_time != time_unit * (n_timestep * t + i):
                                lines.append(
                                    '%.6f %.6f %s\n' % (
                                        start_time, time_unit * (n_timestep * t + i), idx2chord[prev_chord]))
                            break
            pid = os.getpid()
            tmp_path = 'tmp_' + str(pid) + '.lab'
            with open(tmp_path, 'w') as f:
                for line in lines:
                    f.write(line)

            root_majmin = ['root', 'majmin']
            for m in root_majmin:
                metrics_.score_list_dict[m].append(metrics_.score(metric=m, gt_path=lab_file_path, est_path=tmp_path))
            song_length_list.append(song_length_second)
            if verbose:
                for m in root_majmin:
                    print('song name %s, %s score : %.4f' % (song_name, m, metrics_.score_list_dict[m][-1]))
        except:
            print('song name %s\' lab file error' % song_name)

    tmp = song_length_list / np.sum(song_length_list)
    for m in root_majmin:
        metrics_.average_score[m] = np.sum(np.multiply(metrics_.score_list_dict[m], tmp))

    return metrics_.score_list_dict, song_length_list, metrics_.average_score


def large_voca_score_calculation(valid_dataset, config, mean, std, device, model, model_type, verbose=False):
    idx2voca = idx2voca_chord()
    valid_song_names = valid_dataset.song_names
    paths = valid_dataset.preprocessor.get_all_files()

    metrics_ = metrics()
    song_length_list = list()
    for path in paths:
        song_name, lab_file_path, mp3_file_path, _ = path
        if not song_name in valid_song_names:
            continue
        try:
            n_timestep = config.model['timestep']
            feature, feature_per_second, song_length_second = audio_file_to_features(mp3_file_path, config)
            feature = feature.T
            feature = (feature - mean) / std
            time_unit = feature_per_second

            num_pad = n_timestep - (feature.shape[0] % n_timestep)
            feature = np.pad(feature, ((0, num_pad), (0, 0)), mode="constant", constant_values=0)
            num_instance = feature.shape[0] // n_timestep

            start_time = 0.0
            lines = []
            with torch.no_grad():
                model.eval()
                feature = torch.tensor(feature, dtype=torch.float32).unsqueeze(0).to(device)
                for t in range(num_instance):
                    if model_type == 'btc':
                        encoder_output, _ = model.self_attn_layers(feature[:, n_timestep * t:n_timestep * (t + 1), :])
                        prediction, _ = model.output_layer(encoder_output)
                        prediction = prediction.squeeze()
                    elif model_type == 'cnn' or model_type == 'crnn':
                        prediction, _, _, _ = model(feature[:, n_timestep * t:n_timestep * (t + 1), :], torch.randint(config.model['num_chords'], (n_timestep,)).to(device))
                    for i in range(n_timestep):
                        if t == 0 and i == 0:
                            prev_chord = prediction[i].item()
                            continue
                        if prediction[i].item() != prev_chord:
                            lines.append(
                                '%.6f %.6f %s\n' % (
                                    start_time, time_unit * (n_timestep * t + i), idx2voca[prev_chord]))
                            start_time = time_unit * (n_timestep * t + i)
                            prev_chord = prediction[i].item()
                        if t == num_instance - 1 and i + num_pad == n_timestep:
                            if start_time != time_unit * (n_timestep * t + i):
                                lines.append(
                                    '%.6f %.6f %s\n' % (
                                        start_time, time_unit * (n_timestep * t + i), idx2voca[prev_chord]))
                            break
            pid = os.getpid()
            tmp_path = 'tmp_' + str(pid) + '.lab'
            with open(tmp_path, 'w') as f:
                for line in lines:
                    f.write(line)

            for m in metrics_.score_metrics:
                metrics_.score_list_dict[m].append(metrics_.score(metric=m, gt_path=lab_file_path, est_path=tmp_path))
            song_length_list.append(song_length_second)
            if verbose:
                for m in metrics_.score_metrics:
                    print('song name %s, %s score : %.4f' % (song_name, m, metrics_.score_list_dict[m][-1]))
        except:
            print('song name %s\' lab file error' % song_name)

    tmp = song_length_list / np.sum(song_length_list)
    for m in metrics_.score_metrics:
        metrics_.average_score[m] = np.sum(np.multiply(metrics_.score_list_dict[m], tmp))

    return metrics_.score_list_dict, song_length_list, metrics_.average_score


def large_voca_score_calculation_crf(valid_dataset, config, mean, std, device, pre_model, model, model_type, verbose=False):
    idx2voca = idx2voca_chord()
    valid_song_names = valid_dataset.song_names
    paths = valid_dataset.preprocessor.get_all_files()

    metrics_ = metrics()
    song_length_list = list()
    for path in paths:
        song_name, lab_file_path, mp3_file_path, _ = path
        if not song_name in valid_song_names:
            continue
        try:
            n_timestep = config.model['timestep']
            feature, feature_per_second, song_length_second = audio_file_to_features(mp3_file_path, config)
            feature = feature.T
            feature = (feature - mean) / std
            time_unit = feature_per_second

            num_pad = n_timestep - (feature.shape[0] % n_timestep)
            feature = np.pad(feature, ((0, num_pad), (0, 0)), mode="constant", constant_values=0)
            num_instance = feature.shape[0] // n_timestep

            start_time = 0.0
            lines = []
            with torch.no_grad():
                model.eval()
                feature = torch.tensor(feature, dtype=torch.float32).unsqueeze(0).to(device)
                for t in range(num_instance):
                    if (model_type == 'cnn') or (model_type == 'crnn') or (model_type == 'btc'):
                        logits = pre_model(feature[:, n_timestep * t:n_timestep * (t + 1), :], torch.randint(config.model['num_chords'], (n_timestep,)).to(device))
                        prediction, _ = model(logits, torch.randint(config.model['num_chords'], (n_timestep,)).to(device))
                    else:
                        raise NotImplementedError
                    for i in range(n_timestep):
                        if t == 0 and i == 0:
                            prev_chord = prediction[i].item()
                            continue
                        if prediction[i].item() != prev_chord:
                            lines.append(
                                '%.6f %.6f %s\n' % (
                                    start_time, time_unit * (n_timestep * t + i), idx2voca[prev_chord]))
                            start_time = time_unit * (n_timestep * t + i)
                            prev_chord = prediction[i].item()
                        if t == num_instance - 1 and i + num_pad == n_timestep:
                            if start_time != time_unit * (n_timestep * t + i):
                                lines.append(
                                    '%.6f %.6f %s\n' % (
                                        start_time, time_unit * (n_timestep * t + i), idx2voca[prev_chord]))
                            break
            pid = os.getpid()
            tmp_path = 'tmp_' + str(pid) + '.lab'
            with open(tmp_path, 'w') as f:
                for line in lines:
                    f.write(line)

            for m in metrics_.score_metrics:
                metrics_.score_list_dict[m].append(metrics_.score(metric=m, gt_path=lab_file_path, est_path=tmp_path))
            song_length_list.append(song_length_second)
            if verbose:
                for m in metrics_.score_metrics:
                    print('song name %s, %s score : %.4f' % (song_name, m, metrics_.score_list_dict[m][-1]))
        except:
            print('song name %s\' lab file error' % song_name)

    tmp = song_length_list / np.sum(song_length_list)
    for m in metrics_.score_metrics:
        metrics_.average_score[m] = np.sum(np.multiply(metrics_.score_list_dict[m], tmp))

    return metrics_.score_list_dict, song_length_list, metrics_.average_score
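For orientation, every scoring method above follows the same mir_eval recipe: align the estimated intervals to the reference time span, merge both annotations onto a common segmentation, then weight per-segment comparisons by duration. A minimal standalone sketch of that recipe (the function name and file paths are hypothetical; only mir_eval calls already used above appear):

import mir_eval

def chord_score(gt_path, est_path, comparison=mir_eval.chord.majmin):
    # Load reference and estimated .lab files as (intervals, labels).
    ref_intervals, ref_labels = mir_eval.io.load_labeled_intervals(gt_path)
    est_intervals, est_labels = mir_eval.io.load_labeled_intervals(est_path)
    # Clip/pad the estimate to the reference span, filling gaps with "no chord".
    est_intervals, est_labels = mir_eval.util.adjust_intervals(
        est_intervals, est_labels, ref_intervals.min(), ref_intervals.max(),
        mir_eval.chord.NO_CHORD, mir_eval.chord.NO_CHORD)
    # Merge onto a common segmentation and weight comparisons by segment duration.
    intervals, ref_labels, est_labels = mir_eval.util.merge_labeled_intervals(
        ref_intervals, ref_labels, est_intervals, est_labels)
    durations = mir_eval.util.intervals_to_durations(intervals)
    return mir_eval.chord.weighted_accuracy(comparison(ref_labels, est_labels), durations)

# score = chord_score('song.lab', 'tmp_1234.lab', comparison=mir_eval.chord.mirex)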
utils/preprocess.py
ADDED
@@ -0,0 +1,466 @@
import os
import librosa
from utils.chords import Chords
import re
from enum import Enum
import pyrubberband as pyrb
import torch
import math

class FeatureTypes(Enum):
    cqt = 'cqt'

class Preprocess():
    def __init__(self, config, feature_to_use, dataset_names, root_dir):
        self.config = config
        self.dataset_names = dataset_names
        self.root_path = root_dir + '/'

        self.time_interval = config.feature["hop_length"]/config.mp3["song_hz"]
        self.no_of_chord_datapoints_per_sequence = math.ceil(config.mp3['inst_len'] / self.time_interval)
        self.Chord_class = Chords()

        # isophonic
        self.isophonic_directory = self.root_path + 'isophonic/'

        # uspop
        self.uspop_directory = self.root_path + 'uspop/'
        self.uspop_audio_path = 'audio/'
        self.uspop_lab_path = 'annotations/uspopLabels/'
        self.uspop_index_path = 'annotations/uspopLabels.txt'

        # robbie williams
        self.robbie_williams_directory = self.root_path + 'robbiewilliams/'
        self.robbie_williams_audio_path = 'audio/'
        self.robbie_williams_lab_path = 'chords/'

        self.feature_name = feature_to_use
        self.is_cut_last_chord = False

    def find_mp3_path(self, dirpath, word):
        for filename in os.listdir(dirpath):
            last_dir = dirpath.split("/")[-2]
            if ".mp3" in filename:
                tmp = filename.replace(".mp3", "")
                tmp = tmp.replace(last_dir, "")
                filename_lower = tmp.lower()
                filename_lower = " ".join(re.findall("[a-zA-Z]+", filename_lower))
                if word.lower().replace(" ", "") in filename_lower.replace(" ", ""):
                    return filename

    def find_mp3_path_robbiewilliams(self, dirpath, word):
        for filename in os.listdir(dirpath):
            if ".mp3" in filename:
                tmp = filename.replace(".mp3", "")
                filename_lower = tmp.lower()
                filename_lower = filename_lower.replace("robbie williams", "")
                filename_lower = " ".join(re.findall("[a-zA-Z]+", filename_lower))
                filename_lower = self.song_pre(filename_lower)
                if self.song_pre(word.lower()).replace(" ", "") in filename_lower.replace(" ", ""):
                    return filename

    def get_all_files(self):
        res_list = []

        # isophonic
        if "isophonic" in self.dataset_names:
            for dirpath, dirnames, filenames in os.walk(self.isophonic_directory):
                if not dirnames:
                    for filename in filenames:
                        if ".lab" in filename:
                            tmp = filename.replace(".lab", "")
                            song_name = " ".join(re.findall("[a-zA-Z]+", tmp)).replace("CD", "")
                            mp3_path = self.find_mp3_path(dirpath, song_name)
                            res_list.append([song_name, os.path.join(dirpath, filename), os.path.join(dirpath, mp3_path),
                                             os.path.join(self.root_path, "result", "isophonic")])

        # uspop
        if "uspop" in self.dataset_names:
            with open(os.path.join(self.uspop_directory, self.uspop_index_path)) as f:
                uspop_lab_list = f.readlines()
            uspop_lab_list = [x.strip() for x in uspop_lab_list]

            for lab_path in uspop_lab_list:
                spl = lab_path.split('/')
                lab_artist = self.uspop_pre(spl[2])
                lab_title = self.uspop_pre(spl[4][3:-4])
                lab_path = lab_path.replace('./uspopLabels/', '')
                lab_path = os.path.join(self.uspop_directory, self.uspop_lab_path, lab_path)

                for filename in os.listdir(os.path.join(self.uspop_directory, self.uspop_audio_path)):
                    if not '.csv' in filename:
                        spl = filename.split('-')
                        mp3_artist = self.uspop_pre(spl[0])
                        mp3_title = self.uspop_pre(spl[1][:-4])

                        if lab_artist == mp3_artist and lab_title == mp3_title:
                            res_list.append([mp3_artist + mp3_title, lab_path,
                                             os.path.join(self.uspop_directory, self.uspop_audio_path, filename),
                                             os.path.join(self.root_path, "result", "uspop")])
                            break

        # robbie williams
        if "robbiewilliams" in self.dataset_names:
            for dirpath, dirnames, filenames in os.walk(self.robbie_williams_directory):
                if not dirnames:
                    for filename in filenames:
                        if ".txt" in filename and (not 'README' in filename):
                            tmp = filename.replace(".txt", "")
                            song_name = " ".join(re.findall("[a-zA-Z]+", tmp)).replace("GTChords", "")
                            mp3_dir = dirpath.replace("chords", "audio")
                            mp3_path = self.find_mp3_path_robbiewilliams(mp3_dir, song_name)
                            res_list.append([song_name, os.path.join(dirpath, filename), os.path.join(mp3_dir, mp3_path),
                                             os.path.join(self.root_path, "result", "robbiewilliams")])
        return res_list

    def uspop_pre(self, text):
        text = text.lower()
        text = text.replace('_', '')
        text = text.replace(' ', '')
        text = " ".join(re.findall("[a-zA-Z]+", text))
        return text

    def song_pre(self, text):
        to_remove = ["'", '`', '(', ')', ' ', '&', 'and', 'And']

        for remove in to_remove:
            text = text.replace(remove, '')

        return text

    def config_to_folder(self):
        mp3_config = self.config.mp3
        feature_config = self.config.feature
        mp3_string = "%d_%.1f_%.1f" % \
                     (mp3_config['song_hz'], mp3_config['inst_len'],
                      mp3_config['skip_interval'])
        feature_string = "%s_%d_%d_%d" % \
                         (self.feature_name.value, feature_config['n_bins'], feature_config['bins_per_octave'], feature_config['hop_length'])

        return mp3_config, feature_config, mp3_string, feature_string

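    # Example (illustrative, not in the original file): with hypothetical config values
    # song_hz=22050, inst_len=10.0, skip_interval=5.0, n_bins=144, bins_per_octave=24 and
    # hop_length=2048, config_to_folder() would yield mp3_string = '22050_10.0_5.0' and
    # feature_string = 'cqt_144_24_2048', so instances are written under
    # <save_path>/22050_10.0_5.0/cqt_144_24_2048/<song_name>/.
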
    def generate_labels_features_new(self, all_list):
        pid = os.getpid()
        mp3_config, feature_config, mp3_str, feature_str = self.config_to_folder()

        i = 0  # number of songs
        j = 0  # number of impossible songs
        k = 0  # number of tried songs
        total = 0  # number of generated instances

        stretch_factors = [1.0]
        shift_factors = [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6]

        loop_broken = False
        for song_name, lab_path, mp3_path, save_path in all_list:

            # different song initialization
            if loop_broken:
                loop_broken = False

            i += 1
            print(pid, "generating features from ...", os.path.join(mp3_path))
            if i % 10 == 0:
                print(i, ' th song')

            original_wav, sr = librosa.load(os.path.join(mp3_path), sr=mp3_config['song_hz'])

            # make result path if not exists
            # save_path, mp3_string, feature_string, song_name, aug.pt
            result_path = os.path.join(save_path, mp3_str, feature_str, song_name.strip())
            if not os.path.exists(result_path):
                os.makedirs(result_path)

            # calculate result
            for stretch_factor in stretch_factors:
                if loop_broken:
                    loop_broken = False
                    break

                for shift_factor in shift_factors:
                    # for filename
                    idx = 0

                    chord_info = self.Chord_class.get_converted_chord(os.path.join(lab_path))

                    k += 1
                    # stretch original sound and chord info
                    x = pyrb.time_stretch(original_wav, sr, stretch_factor)
                    x = pyrb.pitch_shift(x, sr, shift_factor)
                    audio_length = x.shape[0]
                    chord_info['start'] = chord_info['start'] * 1/stretch_factor
                    chord_info['end'] = chord_info['end'] * 1/stretch_factor

                    last_sec = chord_info.iloc[-1]['end']
                    last_sec_hz = int(last_sec * mp3_config['song_hz'])

                    if audio_length + mp3_config['skip_interval'] < last_sec_hz:
                        print('loaded song is too short :', song_name)
                        loop_broken = True
                        j += 1
                        break
                    elif audio_length > last_sec_hz:
                        x = x[:last_sec_hz]

                    origin_length = last_sec_hz
                    origin_length_in_sec = origin_length / mp3_config['song_hz']

                    current_start_second = 0

                    # get chord list between current_start_second and current+song_length
                    while current_start_second + mp3_config['inst_len'] < origin_length_in_sec:
                        inst_start_sec = current_start_second
                        curSec = current_start_second

                        chord_list = []
                        # extract chord per 1/self.time_interval
                        while curSec < inst_start_sec + mp3_config['inst_len']:
                            try:
                                available_chords = chord_info.loc[(chord_info['start'] <= curSec) & (
                                        chord_info['end'] > curSec + self.time_interval)].copy()
                                if len(available_chords) == 0:
                                    available_chords = chord_info.loc[((chord_info['start'] >= curSec) & (
                                            chord_info['start'] <= curSec + self.time_interval)) | (
                                            (chord_info['end'] >= curSec) & (
                                            chord_info['end'] <= curSec + self.time_interval))].copy()
                                if len(available_chords) == 1:
                                    chord = available_chords['chord_id'].iloc[0]
                                elif len(available_chords) > 1:
                                    max_starts = available_chords.apply(lambda row: max(row['start'], curSec),
                                                                        axis=1)
                                    available_chords['max_start'] = max_starts
                                    min_ends = available_chords.apply(
                                        lambda row: min(row.end, curSec + self.time_interval), axis=1)
                                    available_chords['min_end'] = min_ends
                                    chords_lengths = available_chords['min_end'] - available_chords['max_start']
                                    available_chords['chord_length'] = chords_lengths
                                    # .loc replaces the long-removed pandas DataFrame.ix indexer
                                    chord = available_chords.loc[available_chords['chord_length'].idxmax()]['chord_id']
                                else:
                                    chord = 24
                            except Exception as e:
                                chord = 24
                                print(e)
                                print(pid, "no chord")
                                raise RuntimeError()
                            finally:
                                # convert chord by shift factor
                                if chord != 24:
                                    chord += shift_factor * 2
                                    chord = chord % 24

                                chord_list.append(chord)
                                curSec += self.time_interval

                        if len(chord_list) == self.no_of_chord_datapoints_per_sequence:
                            try:
                                sequence_start_time = current_start_second
                                sequence_end_time = current_start_second + mp3_config['inst_len']

                                start_index = int(sequence_start_time * mp3_config['song_hz'])
                                end_index = int(sequence_end_time * mp3_config['song_hz'])

                                song_seq = x[start_index:end_index]

                                etc = '%.1f_%.1f' % (
                                    current_start_second, current_start_second + mp3_config['inst_len'])
                                aug = '%.2f_%i' % (stretch_factor, shift_factor)

                                if self.feature_name == FeatureTypes.cqt:
                                    # print(pid, "make feature")
                                    feature = librosa.cqt(song_seq, sr=sr, n_bins=feature_config['n_bins'],
                                                          bins_per_octave=feature_config['bins_per_octave'],
                                                          hop_length=feature_config['hop_length'])
                                else:
                                    raise NotImplementedError

                                if feature.shape[1] > self.no_of_chord_datapoints_per_sequence:
                                    feature = feature[:, :self.no_of_chord_datapoints_per_sequence]

                                if feature.shape[1] != self.no_of_chord_datapoints_per_sequence:
                                    print('loaded features length is too short :', song_name)
                                    loop_broken = True
                                    j += 1
                                    break

                                result = {
                                    'feature': feature,
                                    'chord': chord_list,
                                    'etc': etc
                                }

                                # save_path, mp3_string, feature_string, song_name, aug.pt
                                filename = aug + "_" + str(idx) + ".pt"
                                torch.save(result, os.path.join(result_path, filename))
                                idx += 1
                                total += 1
                            except Exception as e:
                                print(e)
                                print(pid, "feature error")
                                raise RuntimeError()
                        else:
                            print("invalid number of chord datapoints in sequence :", len(chord_list))
                        current_start_second += mp3_config['skip_interval']
        print(pid, "total instances: %d" % total)

    def generate_labels_features_voca(self, all_list):
        pid = os.getpid()
        mp3_config, feature_config, mp3_str, feature_str = self.config_to_folder()

        i = 0  # number of songs
        j = 0  # number of impossible songs
        k = 0  # number of tried songs
        total = 0  # number of generated instances
        stretch_factors = [1.0]
        shift_factors = [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6]

        loop_broken = False
        for song_name, lab_path, mp3_path, save_path in all_list:
            save_path = save_path + '_voca'

            # different song initialization
            if loop_broken:
                loop_broken = False

            i += 1
            print(pid, "generating features from ...", os.path.join(mp3_path))
            if i % 10 == 0:
                print(i, ' th song')

            original_wav, sr = librosa.load(os.path.join(mp3_path), sr=mp3_config['song_hz'])

            # save_path, mp3_string, feature_string, song_name, aug.pt
            result_path = os.path.join(save_path, mp3_str, feature_str, song_name.strip())
            if not os.path.exists(result_path):
                os.makedirs(result_path)

            # calculate result
            for stretch_factor in stretch_factors:
                if loop_broken:
                    loop_broken = False
                    break

                for shift_factor in shift_factors:
                    # for filename
                    idx = 0

                    try:
                        chord_info = self.Chord_class.get_converted_chord_voca(os.path.join(lab_path))
                    except Exception as e:
                        print(e)
                        print(pid, " chord lab file error : %s" % song_name)
                        loop_broken = True
                        j += 1
                        break

                    k += 1
                    # stretch original sound and chord info
                    x = pyrb.time_stretch(original_wav, sr, stretch_factor)
                    x = pyrb.pitch_shift(x, sr, shift_factor)
                    audio_length = x.shape[0]
                    chord_info['start'] = chord_info['start'] * 1/stretch_factor
                    chord_info['end'] = chord_info['end'] * 1/stretch_factor

                    last_sec = chord_info.iloc[-1]['end']
                    last_sec_hz = int(last_sec * mp3_config['song_hz'])

                    if audio_length + mp3_config['skip_interval'] < last_sec_hz:
                        print('loaded song is too short :', song_name)
                        loop_broken = True
                        j += 1
                        break
                    elif audio_length > last_sec_hz:
                        x = x[:last_sec_hz]

                    origin_length = last_sec_hz
                    origin_length_in_sec = origin_length / mp3_config['song_hz']

                    current_start_second = 0

                    # get chord list between current_start_second and current+song_length
                    while current_start_second + mp3_config['inst_len'] < origin_length_in_sec:
                        inst_start_sec = current_start_second
                        curSec = current_start_second

                        chord_list = []
                        # extract chord per 1/self.time_interval
                        while curSec < inst_start_sec + mp3_config['inst_len']:
                            try:
                                available_chords = chord_info.loc[(chord_info['start'] <= curSec) & (chord_info['end'] > curSec + self.time_interval)].copy()
                                if len(available_chords) == 0:
                                    available_chords = chord_info.loc[((chord_info['start'] >= curSec) & (chord_info['start'] <= curSec + self.time_interval)) | ((chord_info['end'] >= curSec) & (chord_info['end'] <= curSec + self.time_interval))].copy()

                                if len(available_chords) == 1:
                                    chord = available_chords['chord_id'].iloc[0]
                                elif len(available_chords) > 1:
                                    max_starts = available_chords.apply(lambda row: max(row['start'], curSec), axis=1)
                                    available_chords['max_start'] = max_starts
                                    min_ends = available_chords.apply(lambda row: min(row.end, curSec + self.time_interval), axis=1)
                                    available_chords['min_end'] = min_ends
                                    chords_lengths = available_chords['min_end'] - available_chords['max_start']
                                    available_chords['chord_length'] = chords_lengths
                                    # .loc replaces the long-removed pandas DataFrame.ix indexer
                                    chord = available_chords.loc[available_chords['chord_length'].idxmax()]['chord_id']
                                else:
                                    chord = 169
                            except Exception as e:
                                chord = 169
                                print(e)
                                print(pid, "no chord")
                                raise RuntimeError()
                            finally:
                                # convert chord by shift factor
                                if chord != 169 and chord != 168:
                                    chord += shift_factor * 14
                                    chord = chord % 168

                                chord_list.append(chord)
                                curSec += self.time_interval

                        if len(chord_list) == self.no_of_chord_datapoints_per_sequence:
                            try:
                                sequence_start_time = current_start_second
                                sequence_end_time = current_start_second + mp3_config['inst_len']

                                start_index = int(sequence_start_time * mp3_config['song_hz'])
                                end_index = int(sequence_end_time * mp3_config['song_hz'])

                                song_seq = x[start_index:end_index]

                                etc = '%.1f_%.1f' % (
                                    current_start_second, current_start_second + mp3_config['inst_len'])
                                aug = '%.2f_%i' % (stretch_factor, shift_factor)

                                if self.feature_name == FeatureTypes.cqt:
                                    feature = librosa.cqt(song_seq, sr=sr, n_bins=feature_config['n_bins'],
                                                          bins_per_octave=feature_config['bins_per_octave'],
                                                          hop_length=feature_config['hop_length'])
                                else:
                                    raise NotImplementedError

                                if feature.shape[1] > self.no_of_chord_datapoints_per_sequence:
                                    feature = feature[:, :self.no_of_chord_datapoints_per_sequence]

                                if feature.shape[1] != self.no_of_chord_datapoints_per_sequence:
                                    print('loaded features length is too short :', song_name)
                                    loop_broken = True
                                    j += 1
                                    break

                                result = {
                                    'feature': feature,
                                    'chord': chord_list,
                                    'etc': etc
                                }

                                # save_path, mp3_string, feature_string, song_name, aug.pt
                                filename = aug + "_" + str(idx) + ".pt"
                                torch.save(result, os.path.join(result_path, filename))
                                idx += 1
                                total += 1
                            except Exception as e:
                                print(e)
                                print(pid, "feature error")
                                raise RuntimeError()
                        else:
                            print("invalid number of chord datapoints in sequence :", len(chord_list))
                        current_start_second += mp3_config['skip_interval']
        print(pid, "total instances: %d" % total)
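A rough sketch of how this preprocessor is typically driven. The config file name, dataset list, and data root below are assumptions, and HParams is assumed to expose the load() helper from utils/hparams.py in this upload:

from utils.hparams import HParams
from utils.preprocess import Preprocess, FeatureTypes

config = HParams.load('run_config.yaml')                       # assumed config file name
pre = Preprocess(config, FeatureTypes.cqt, ['isophonic'], root_dir='./data')
all_files = pre.get_all_files()                                # [song_name, lab_path, mp3_path, save_path] per song
pre.generate_labels_features_voca(all_files)                   # writes <stretch>_<shift>_<idx>.pt instances per song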
utils/pytorch_utils.py
ADDED
@@ -0,0 +1,33 @@
import torch
import numpy as np
import os
import math
from utils import logger

use_cuda = torch.cuda.is_available()


# optimization
# reference: http://pytorch.org/docs/master/_modules/torch/optim/lr_scheduler.html#ReduceLROnPlateau
def adjusting_learning_rate(optimizer, factor=.5, min_lr=0.00001):
    for i, param_group in enumerate(optimizer.param_groups):
        old_lr = float(param_group['lr'])
        new_lr = max(old_lr * factor, min_lr)
        param_group['lr'] = new_lr
        logger.info('adjusting learning rate from %.6f to %.6f' % (old_lr, new_lr))


# model save and loading
def load_model(asset_path, model, optimizer, restore_epoch=0):
    checkpoint_path = os.path.join(asset_path, 'model', 'checkpoint_%d.pth.tar' % restore_epoch)
    if os.path.isfile(checkpoint_path):
        # map_location belongs to torch.load, not os.path.isfile; load onto CPU first so
        # checkpoints restore on machines without a GPU.
        checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        current_step = checkpoint['current_step']
        logger.info("restore model with %d epoch" % restore_epoch)
    else:
        logger.info("no checkpoint with %d epoch" % restore_epoch)
        current_step = 0

    return model, optimizer, current_step
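The saving side is not part of this upload; a minimal counterpart sketch, assuming the same directory layout and the three checkpoint keys that load_model reads ('model', 'optimizer', 'current_step'). The function name save_model is hypothetical:

import os
import torch

def save_model(asset_path, model, optimizer, current_step, epoch):
    # Write a checkpoint that load_model() above can restore.
    os.makedirs(os.path.join(asset_path, 'model'), exist_ok=True)
    checkpoint = {'model': model.state_dict(),
                  'optimizer': optimizer.state_dict(),
                  'current_step': current_step}
    torch.save(checkpoint, os.path.join(asset_path, 'model', 'checkpoint_%d.pth.tar' % epoch))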
utils/tf_logger.py
ADDED
@@ -0,0 +1,70 @@
import tensorflow as tf
import numpy as np
import scipy.misc

try:
    from StringIO import StringIO  # Python 2.7
except ImportError:
    from io import BytesIO  # Python 3.x


class TF_Logger(object):
    def __init__(self, log_dir):
        """Create a summary writer logging to log_dir."""
        self.writer = tf.summary.FileWriter(log_dir)

    def scalar_summary(self, tag, value, step):
        """Log a scalar variable."""
        summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)])
        self.writer.add_summary(summary, step)

    def image_summary(self, tag, images, step):
        """Log a list of images."""

        img_summaries = []
        for i, img in enumerate(images):
            # Write the image to a string
            try:
                s = StringIO()
            except:
                s = BytesIO()
            scipy.misc.toimage(img).save(s, format="png")

            # Create an Image object
            img_sum = tf.Summary.Image(encoded_image_string=s.getvalue(),
                                       height=img.shape[0],
                                       width=img.shape[1])
            # Create a Summary value
            img_summaries.append(tf.Summary.Value(tag='%s/%d' % (tag, i), image=img_sum))

        # Create and write Summary
        summary = tf.Summary(value=img_summaries)
        self.writer.add_summary(summary, step)

    def histo_summary(self, tag, values, step, bins=1000):
        """Log a histogram of the tensor of values."""

        # Create a histogram using numpy
        counts, bin_edges = np.histogram(values, bins=bins)

        # Fill the fields of the histogram proto
        hist = tf.HistogramProto()
        hist.min = float(np.min(values))
        hist.max = float(np.max(values))
        hist.num = int(np.prod(values.shape))
        hist.sum = float(np.sum(values))
        hist.sum_squares = float(np.sum(values ** 2))

        # Drop the start of the first bin
        bin_edges = bin_edges[1:]

        # Add bin edges and counts
        for edge in bin_edges:
            hist.bucket_limit.append(edge)
        for c in counts:
            hist.bucket.append(c)

        # Create and write Summary
        summary = tf.Summary(value=[tf.Summary.Value(tag=tag, histo=hist)])
        self.writer.add_summary(summary, step)
        self.writer.flush()
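A minimal usage sketch for this logger. Note that it relies on the TensorFlow 1.x summary API (tf.summary.FileWriter, tf.Summary) and the deprecated scipy.misc.toimage, so the log directory and loop below are illustrative only:

logger = TF_Logger('./logs/run1')          # hypothetical log directory
for step in range(100):
    loss = 1.0 / (step + 1)                # placeholder value
    logger.scalar_summary('train/loss', loss, step)
# Then inspect the resulting event files with: tensorboard --logdir ./logs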
utils/transformer_modules.py
ADDED
@@ -0,0 +1,274 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math

def _gen_bias_mask(max_length):
    """
    Generates bias values (-Inf) to mask future timesteps during attention
    """
    np_mask = np.triu(np.full([max_length, max_length], -np.inf), 1)
    torch_mask = torch.from_numpy(np_mask).type(torch.FloatTensor)
    return torch_mask.unsqueeze(0).unsqueeze(1)

def _gen_timing_signal(length, channels, min_timescale=1.0, max_timescale=1.0e4):
    """
    Generates a [1, length, channels] timing signal consisting of sinusoids
    Adapted from:
    https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/layers/common_attention.py
    """
    position = np.arange(length)
    num_timescales = channels // 2
    log_timescale_increment = (
            math.log(float(max_timescale) / float(min_timescale)) /
            (float(num_timescales) - 1))
    inv_timescales = min_timescale * np.exp(
            np.arange(num_timescales).astype(np.float64) * -log_timescale_increment)
    scaled_time = np.expand_dims(position, 1) * np.expand_dims(inv_timescales, 0)

    signal = np.concatenate([np.sin(scaled_time), np.cos(scaled_time)], axis=1)
    signal = np.pad(signal, [[0, 0], [0, channels % 2]],
                    'constant', constant_values=[0.0, 0.0])
    signal = signal.reshape([1, length, channels])

    return torch.from_numpy(signal).type(torch.FloatTensor)

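# Example (illustrative, not in the original file): the timing signal is added to the input
# embeddings, broadcasting over the batch dimension.
#   timing = _gen_timing_signal(108, 128)            # shape [1, 108, 128]
#   x = torch.randn(16, 108, 128)                    # [batch, timestep, hidden]
#   x = x + timing[:, :x.shape[1], :].type_as(x)
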
class LayerNorm(nn.Module):
    # Borrowed from jekbradbury
    # https://github.com/pytorch/pytorch/issues/1959
    def __init__(self, features, eps=1e-6):
        super(LayerNorm, self).__init__()
        self.gamma = nn.Parameter(torch.ones(features))
        self.beta = nn.Parameter(torch.zeros(features))
        self.eps = eps

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True)
        return self.gamma * (x - mean) / (std + self.eps) + self.beta

class OutputLayer(nn.Module):
    """
    Abstract base class for output layer.
    Handles projection to output labels
    """
    def __init__(self, hidden_size, output_size, probs_out=False):
        super(OutputLayer, self).__init__()
        self.output_size = output_size
        self.output_projection = nn.Linear(hidden_size, output_size)
        self.probs_out = probs_out
        self.lstm = nn.LSTM(input_size=hidden_size, hidden_size=int(hidden_size/2), batch_first=True, bidirectional=True)
        self.hidden_size = hidden_size

    def loss(self, hidden, labels):
        raise NotImplementedError('Must implement {}.loss'.format(self.__class__.__name__))

class SoftmaxOutputLayer(OutputLayer):
    """
    Implements a softmax based output layer
    """
    def forward(self, hidden):
        logits = self.output_projection(hidden)
        probs = F.softmax(logits, -1)
        # _, predictions = torch.max(probs, dim=-1)
        topk, indices = torch.topk(probs, 2)
        predictions = indices[:, :, 0]
        second = indices[:, :, 1]
        if self.probs_out is True:
            return logits
            # return probs
        return predictions, second

    def loss(self, hidden, labels):
        logits = self.output_projection(hidden)
        log_probs = F.log_softmax(logits, -1)
        return F.nll_loss(log_probs.view(-1, self.output_size), labels.view(-1))

class MultiHeadAttention(nn.Module):
    """
    Multi-head attention as per https://arxiv.org/pdf/1706.03762.pdf
    Refer Figure 2
    """

    def __init__(self, input_depth, total_key_depth, total_value_depth, output_depth,
                 num_heads, bias_mask=None, dropout=0.0, attention_map=False):
        """
        Parameters:
            input_depth: Size of last dimension of input
            total_key_depth: Size of last dimension of keys. Must be divisible by num_head
            total_value_depth: Size of last dimension of values. Must be divisible by num_head
            output_depth: Size last dimension of the final output
            num_heads: Number of attention heads
            bias_mask: Masking tensor to prevent connections to future elements
            dropout: Dropout probability (Should be non-zero only during training)
        """
        super(MultiHeadAttention, self).__init__()
        # Checks borrowed from
        # https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/layers/common_attention.py
        if total_key_depth % num_heads != 0:
            raise ValueError("Key depth (%d) must be divisible by the number of "
                             "attention heads (%d)." % (total_key_depth, num_heads))
        if total_value_depth % num_heads != 0:
            raise ValueError("Value depth (%d) must be divisible by the number of "
                             "attention heads (%d)." % (total_value_depth, num_heads))

        self.attention_map = attention_map

        self.num_heads = num_heads
        self.query_scale = (total_key_depth // num_heads) ** -0.5
        self.bias_mask = bias_mask

        # Key and query depth will be same
        self.query_linear = nn.Linear(input_depth, total_key_depth, bias=False)
        self.key_linear = nn.Linear(input_depth, total_key_depth, bias=False)
        self.value_linear = nn.Linear(input_depth, total_value_depth, bias=False)
        self.output_linear = nn.Linear(total_value_depth, output_depth, bias=False)

        self.dropout = nn.Dropout(dropout)

    def _split_heads(self, x):
        """
        Split x such to add an extra num_heads dimension
        Input:
            x: a Tensor with shape [batch_size, seq_length, depth]
        Returns:
            A Tensor with shape [batch_size, num_heads, seq_length, depth/num_heads]
        """
        if len(x.shape) != 3:
            raise ValueError("x must have rank 3")
        shape = x.shape
        return x.view(shape[0], shape[1], self.num_heads, shape[2] // self.num_heads).permute(0, 2, 1, 3)

    def _merge_heads(self, x):
        """
        Merge the extra num_heads into the last dimension
        Input:
            x: a Tensor with shape [batch_size, num_heads, seq_length, depth/num_heads]
        Returns:
            A Tensor with shape [batch_size, seq_length, depth]
        """
        if len(x.shape) != 4:
            raise ValueError("x must have rank 4")
        shape = x.shape
        return x.permute(0, 2, 1, 3).contiguous().view(shape[0], shape[2], shape[3] * self.num_heads)

    def forward(self, queries, keys, values):

        # Do a linear for each component
        queries = self.query_linear(queries)
        keys = self.key_linear(keys)
        values = self.value_linear(values)

        # Split into multiple heads
        queries = self._split_heads(queries)
        keys = self._split_heads(keys)
        values = self._split_heads(values)

        # Scale queries
        queries *= self.query_scale

        # Combine queries and keys
        logits = torch.matmul(queries, keys.permute(0, 1, 3, 2))

        # Add bias to mask future values
        if self.bias_mask is not None:
            logits += self.bias_mask[:, :, :logits.shape[-2], :logits.shape[-1]].type_as(logits.data)

        # Convert to probabilites
        weights = nn.functional.softmax(logits, dim=-1)

        # Dropout
        weights = self.dropout(weights)

        # Combine with values to get context
        contexts = torch.matmul(weights, values)

        # Merge heads
        contexts = self._merge_heads(contexts)
        # contexts = torch.tanh(contexts)

        # Linear to get output
        outputs = self.output_linear(contexts)

        if self.attention_map is True:
            return outputs, weights

        return outputs


class Conv(nn.Module):
    """
    Convenience class that does padding and convolution for inputs in the format
    [batch_size, sequence length, hidden size]
    """

    def __init__(self, input_size, output_size, kernel_size, pad_type):
        """
        Parameters:
            input_size: Input feature size
            output_size: Output feature size
            kernel_size: Kernel width
            pad_type: left -> pad on the left side (to mask future data_loader),
                      both -> pad on both sides
        """
        super(Conv, self).__init__()
        padding = (kernel_size - 1, 0) if pad_type == 'left' else (kernel_size // 2, (kernel_size - 1) // 2)
        self.pad = nn.ConstantPad1d(padding, 0)
        self.conv = nn.Conv1d(input_size, output_size, kernel_size=kernel_size, padding=0)

    def forward(self, inputs):
        inputs = self.pad(inputs.permute(0, 2, 1))
        outputs = self.conv(inputs).permute(0, 2, 1)

        return outputs


class PositionwiseFeedForward(nn.Module):
    """
    Does a Linear + RELU + Linear on each of the timesteps
    """

    def __init__(self, input_depth, filter_size, output_depth, layer_config='ll', padding='left', dropout=0.0):
        """
        Parameters:
            input_depth: Size of last dimension of input
            filter_size: Hidden size of the middle layer
            output_depth: Size last dimension of the final output
            layer_config: ll -> linear + ReLU + linear
                          cc -> conv + ReLU + conv etc.
            padding: left -> pad on the left side (to mask future data_loader),
                     both -> pad on both sides
            dropout: Dropout probability (Should be non-zero only during training)
        """
        super(PositionwiseFeedForward, self).__init__()

        layers = []
        sizes = ([(input_depth, filter_size)] +
                 [(filter_size, filter_size)] * (len(layer_config) - 2) +
                 [(filter_size, output_depth)])

        for lc, s in zip(list(layer_config), sizes):
            if lc == 'l':
                layers.append(nn.Linear(*s))
            elif lc == 'c':
                layers.append(Conv(*s, kernel_size=3, pad_type=padding))
            else:
                raise ValueError("Unknown layer type {}".format(lc))

        self.layers = nn.ModuleList(layers)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, inputs):
        x = inputs
        for i, layer in enumerate(self.layers):
            x = layer(x)
            if i < len(self.layers):
                x = self.relu(x)
                x = self.dropout(x)

        return x
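A short self-attention usage sketch built only from names defined in this file; the sizes (128-dim features, 4 heads, sequence length 108) are illustrative:

import torch
from utils.transformer_modules import MultiHeadAttention, _gen_bias_mask

# Causal self-attention over a [batch, seq_len, depth] tensor.
attn = MultiHeadAttention(input_depth=128, total_key_depth=128, total_value_depth=128,
                          output_depth=128, num_heads=4, bias_mask=_gen_bias_mask(108))
x = torch.randn(2, 108, 128)
out = attn(x, x, x)        # queries = keys = values -> self-attention; shape [2, 108, 128]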