Spaces:
Build error
Build error
from ..numpy_encode import * | |
import numpy as np | |
from enum import Enum | |
import torch | |
from ..vocab import * | |
from functools import partial | |
# Kind of token sequence; seq_prefix() uses it to pick the start token.
# Members are numbered from 1 in the order listed.
SEQType = Enum('SEQType', ['Mask', 'Sentence', 'Melody', 'Chords', 'Empty'])
class MusicItem():
    """Index-encoded piece of music bundled with the vocab used to encode it.

    ``data`` holds the token indexes. The music21 stream and the positional
    encoding are derived lazily from ``data`` and cached on first access
    (see the ``stream`` and ``position`` properties). Transform methods
    return new items and leave ``self.data`` untouched.
    """

    def __init__(self, data, vocab, stream=None, position=None):
        """Store the encoded data, its vocab and any precomputed caches.

        Args:
            data: index-encoded token array (must expose ``.shape``).
            vocab: vocabulary object used for encoding/decoding.
            stream: optional already-built stream to use as the cache.
            position: optional already-built positional encoding.
        """
        self.data = data
        self.vocab = vocab
        self._stream = stream
        self._position = position

    def __repr__(self):
        return '\n'.join([
            f'\n{self.__class__.__name__} - {self.data.shape}',
            f'{self.vocab.textify(self.data[:10])}...'])

    def __len__(self):
        return len(self.data)

    # ---- alternate constructors -------------------------------------------

    @classmethod
    def from_file(cls, midi_file, vocab):
        """Build an item from a midi file via ``file2stream``."""
        return cls.from_stream(file2stream(midi_file), vocab)

    @classmethod
    def from_stream(cls, stream, vocab):
        """Build an item from a music21 stream, keeping the stream cached."""
        # Anything that is not already a Score is converted with voicesToParts().
        if not isinstance(stream, music21.stream.Score):
            stream = stream.voicesToParts()
        chordarr = stream2chordarr(stream)  # 2.
        npenc = chordarr2npenc(chordarr)  # 3.
        return cls.from_npenc(npenc, vocab, stream)

    @classmethod
    def from_npenc(cls, npenc, vocab, stream=None):
        """Build an item from a numpy (note, duration) encoding."""
        return cls(npenc2idxenc(npenc, vocab), vocab, stream)

    @classmethod
    def from_idx(cls, item, vocab):
        """Build an item from an ``(idx, pos)`` pair as returned by ``to_idx``."""
        idx, pos = item
        return cls(idx, vocab=vocab, position=pos)

    def to_idx(self):
        """Return ``(data, position)``."""
        return self.data, self.position

    @classmethod
    def empty(cls, vocab, seq_type=SEQType.Sentence):
        """Build an item containing only the prefix tokens for ``seq_type``."""
        return cls(seq_prefix(seq_type, vocab), vocab)

    # ---- conversions ------------------------------------------------------

    @property
    def stream(self):
        """Stream for this item; built from ``data`` on first access, then cached."""
        self._stream = self.to_stream() if self._stream is None else self._stream
        return self._stream

    def to_stream(self, bpm=120):
        """Decode ``data`` into a stream (always recomputed, never cached)."""
        return idxenc2stream(self.data, self.vocab, bpm=bpm)

    def to_tensor(self, device=None):
        """Return ``data`` as a tensor via the module-level ``to_tensor``."""
        return to_tensor(self.data, device)

    def to_text(self, sep=' '):
        """Return the textual form of ``data`` using ``vocab.textify``."""
        return self.vocab.textify(self.data, sep)

    @property
    def position(self):
        """Positional encoding; computed from ``data`` on first access, then cached."""
        self._position = position_enc(self.data, self.vocab) if self._position is None else self._position
        return self._position

    def get_pos_tensor(self, device=None):
        """Return the positional encoding as a tensor."""
        return to_tensor(self.position, device)

    def to_npenc(self):
        """Decode ``data`` back into the numpy (note, duration) encoding."""
        return idxenc2npenc(self.data, self.vocab)

    def show(self, format:str=None):
        """Display the stream using ``stream.show(format)``."""
        return self.stream.show(format)

    def play(self):
        """Display the stream with the ``'midi'`` format."""
        self.stream.show('midi')

    # ---- transforms (each returns a new item) -----------------------------

    @property
    def new(self):
        """Constructor for this item's class, pre-bound to the same vocab."""
        return partial(type(self), vocab=self.vocab)

    def trim_to_beat(self, beat, include_last_sep=False):
        """Return a new item cut off at ``beat``."""
        return self.new(trim_to_beat(self.data, self.position, self.vocab, beat, include_last_sep))

    def transpose(self, interval):
        """Return a new item with note tokens shifted by ``interval``."""
        return self.new(tfm_transpose(self.data, interval, self.vocab), position=self._position)

    def append(self, item):
        """Return a new item with ``item.data`` concatenated after this data."""
        return self.new(np.concatenate((self.data, item.data), axis=0))

    def mask_pitch(self, section=None):
        """Return a new item with note tokens masked, optionally within ``section``."""
        return self.new(self.mask(self.vocab.note_range, section), position=self.position)

    def mask_duration(self, section=None, keep_position_enc=True):
        """Return a new item with duration tokens masked.

        When ``keep_position_enc`` is true the current positional encoding is
        carried over; otherwise the new item recomputes it from masked data.
        """
        masked_data = self.mask(self.vocab.dur_range, section)
        if keep_position_enc:
            return self.new(masked_data, position=self.position)
        return self.new(masked_data)

    def mask(self, token_range, section_range=None):
        """Return a masked copy of ``data`` (an array, not a MusicItem)."""
        return mask_section(self.data, self.position, token_range, self.vocab.mask_idx, section_range=section_range)

    def pad_to(self, bptt):
        """Return a new item whose data and position are padded/cut to ``bptt``."""
        data = pad_seq(self.data, bptt, self.vocab.pad_idx)
        pos = pad_seq(self.position, bptt, 0)
        return self.new(data, stream=self._stream, position=pos)

    def split_stream_parts(self):
        """Replace the cached stream with its melody/chord split and return it."""
        self._stream = separate_melody_chord(self.stream)
        return self.stream

    def remove_eos(self):
        """Return an item without a trailing EOS token.

        Returns ``self`` when the data is empty or does not end with EOS;
        otherwise a new item with the final token dropped.
        """
        if len(self.data) > 0 and self.data[-1] == self.vocab.stoi[EOS]:
            # Drop the EOS token itself; previously the data was passed through unchanged.
            return self.new(self.data[:-1], stream=self.stream)
        return self

    def split_parts(self):
        """Return a new item whose stream is the melody/chord split of this one."""
        return self.new(self.data, stream=separate_melody_chord(self.stream), position=self.position)
def pad_seq(seq, bptt, value):
    """Return ``seq`` right-padded with ``value`` and cut to at most ``bptt`` items."""
    shortfall = bptt - seq.shape[0]
    if shortfall < 0:
        shortfall = 0  # already long enough; only the final slice applies
    padded = np.pad(seq, (0, shortfall), 'constant', constant_values=value)
    return padded[:bptt]
def to_tensor(t, device=None):
    """Convert ``t`` to a ``torch.long`` tensor on the requested device.

    Args:
        t: a ``torch.Tensor`` or anything accepted by ``torch.tensor``.
        device: target device. When ``None`` the tensor is moved to CUDA if
            CUDA is available and otherwise left where it is.

    Returns:
        The tensor cast with ``.long()``.
    """
    t = t if isinstance(t, torch.Tensor) else torch.tensor(t)
    if device is not None:
        # Tensor.to() returns a new tensor; the result must be rebound or the
        # requested device is silently ignored.
        t = t.to(device)
    elif torch.cuda.is_available():
        t = t.cuda()
    return t.long()
def midi2idxenc(midi_file, vocab):
    """Convert a midi file to index encoding for training.

    The file is first turned into the numpy encoding with ``midi2npenc`` and
    then index-encoded with ``npenc2idxenc`` using its default arguments.
    """
    return npenc2idxenc(midi2npenc(midi_file), vocab)
def idxenc2stream(arr, vocab, bpm=120):
    """Convert an index encoding to a music21 stream.

    Decodes ``arr`` with ``idxenc2npenc`` and hands the result to
    ``npenc2stream`` together with ``bpm``.
    """
    return npenc2stream(idxenc2npenc(arr, vocab), bpm=bpm)
# single stream instead of note,dur | |
def npenc2idxenc(t, vocab, seq_type=SEQType.Sentence, add_eos=False):
    """Flatten a 2-column (note, duration) matrix into a 1-D index encoding.

    ``[[n1, d1], [n2, d2], ...] -> [prefix..., n1, d1, n2, d2, ..., (EOS)]``

    Args:
        t: 2-column numpy array, or a list/tuple of exactly two such arrays,
            in which case each is encoded separately and a list is returned.
        vocab: vocabulary supplying ``note_range``, ``dur_range`` and ``stoi``.
        seq_type: selects the prefix tokens (see ``seq_prefix``).
        add_eos: append the EOS token index when true.

    Returns:
        A 1-D numpy array, or a list of two such arrays for paired input.
    """
    if isinstance(t, (list, tuple)) and len(t) == 2:
        # Previously this branch referenced an undefined name and raised
        # NameError; forward the caller's options to each element instead.
        return [npenc2idxenc(x, vocab, seq_type, add_eos) for x in t]
    t = t.copy()  # never mutate the caller's array

    # Shift raw note/duration values into their vocab index ranges.
    t[:, 0] = t[:, 0] + vocab.note_range[0]
    t[:, 1] = t[:, 1] + vocab.dur_range[0]

    prefix = seq_prefix(seq_type, vocab)
    suffix = np.array([vocab.stoi[EOS]]) if add_eos else np.empty(0, dtype=int)
    return np.concatenate([prefix, t.reshape(-1), suffix])
def seq_prefix(seq_type, vocab):
    """Return the prefix tokens for a sequence of the given ``seq_type``.

    ``SEQType.Empty`` yields an empty int array. Every other type yields
    ``[start_token, vocab.pad_idx]``, where the start token is the chord or
    melody marker for those types and ``vocab.bos_idx`` otherwise.
    """
    if seq_type == SEQType.Empty:
        return np.empty(0, dtype=int)

    start_token = vocab.bos_idx
    if seq_type == SEQType.Melody:
        start_token = vocab.stoi[MSEQ]
    elif seq_type == SEQType.Chords:
        start_token = vocab.stoi[CSEQ]
    return np.array([start_token, vocab.pad_idx])
def idxenc2npenc(t, vocab, validate=True):
    """Decode a flat index encoding into a 2-column (note, duration) array.

    With ``validate`` set, tokens outside ``vocab.npenc_range`` are filtered
    out first and the decoded rows are passed through ``to_valid_npenc``.
    The input array is not modified.
    """
    if validate:
        t = to_valid_idxenc(t, vocab.npenc_range)
    pairs = t.copy().reshape(-1, 2)
    if pairs.shape[0] == 0:
        return pairs

    # Undo the per-column offsets that were added when encoding.
    for col, token_range in ((0, vocab.note_range), (1, vocab.dur_range)):
        pairs[:, col] = pairs[:, col] - token_range[0]

    return to_valid_npenc(pairs) if validate else pairs
def to_valid_idxenc(t, valid_range):
    """Keep only tokens inside ``[valid_range[0], valid_range[1])``.

    If an odd number of tokens survives, the last one is dropped so the
    result can be reshaped into (note, duration) pairs.
    """
    lower, upper = valid_range[0], valid_range[1]
    in_range = (t >= lower) & (t < upper)
    kept = t[in_range]
    has_dangling_token = kept.shape[-1] % 2 == 1
    return kept[..., :-1] if has_dangling_token else kept
def to_valid_npenc(t):
    """Truncate a (note, duration) array at its first invalid row.

    A row is invalid when its note is below ``VALTSEP`` or not below
    ``NOTE_SIZE``, or when its duration is negative.

    Args:
        t: 2-column numpy array of (note, duration) rows.

    Returns:
        ``t`` unchanged when it is empty or every row is valid; otherwise the
        leading rows before the first invalid one (possibly none).
    """
    if t.shape[0] == 0:
        return t  # argmax on an empty array would raise

    invalid_note = (t[:, 0] < VALTSEP) | (t[:, 0] >= NOTE_SIZE)
    invalid_dur = t[:, 1] < 0
    invalid = invalid_note | invalid_dur

    # argmax() returns 0 both for "nothing invalid" and for "row 0 invalid",
    # so check any() explicitly instead of relying on the index being > 0.
    if not invalid.any():
        return t

    invalid_idx = invalid.argmax()  # index of the first invalid row
    print('Non midi note detected. Only returning valid portion. Index, seed', invalid_idx, t.shape)
    return t[:invalid_idx]
def position_enc(idxenc, vocab):
    """Calculate the positional beat encoding.

    For every separator token, the value that follows it (its duration, offset
    by ``vocab.dur_range[0]``) is added to the running position starting two
    slots after the separator. Masked durations count as zero. Returns the
    cumulative sum, one entry per token in ``idxenc``; the input is not
    modified.
    """
    n_tokens = idxenc.shape[0]
    dur_offset = vocab.dur_range[0]

    seps = np.nonzero(idxenc == vocab.sep_idx)[0]
    # Drop separators whose target slot (sep + 2) would fall out of bounds.
    seps = seps[seps + 2 < n_tokens]

    steps = idxenc[seps + 1]  # fancy indexing yields a copy
    steps[steps == vocab.mask_idx] = dur_offset  # masked durations become 0 below
    steps -= dur_offset

    increments = np.zeros_like(idxenc)
    increments[seps + 2] = steps
    return increments.cumsum()
def beat2index(idxenc, pos, vocab, beat, include_last_sep=False):
    """Return the index in ``idxenc`` at which ``beat`` starts.

    The result is never below 2, so the starter tokens are always kept.
    Unless ``include_last_sep`` is set, a separator sitting two slots before
    the cutoff is excluded by moving the cutoff back by two.
    """
    cutoff = find_beat(pos, beat)
    if cutoff < 2:
        return 2  # always leave starter tokens

    keep_cutoff = len(idxenc) < 2 or include_last_sep
    if not keep_cutoff and idxenc[cutoff - 2] == vocab.sep_idx:
        return cutoff - 2
    return cutoff
def find_beat(pos, beat, sample_freq=SAMPLE_FREQ, side='left'):
    """Return where ``beat * sample_freq`` would be inserted into ``pos``.

    ``pos`` is searched with ``np.searchsorted``; ``side`` is passed through.
    """
    target = beat * sample_freq
    return np.searchsorted(pos, target, side=side)
# TRANSFORMS | |
def tfm_transpose(x, value, vocab):
    """Return a copy of ``x`` with every note token shifted by ``value``.

    Note tokens are those inside ``[vocab.note_range[0], vocab.note_range[1])``;
    all other tokens are left as they are.
    """
    shifted = x.copy()
    low, high = vocab.note_range[0], vocab.note_range[1]
    note_tokens = (shifted >= low) & (shifted < high)
    shifted[note_tokens] += value
    return shifted
def trim_to_beat(idxenc, pos, vocab, to_beat=None, include_last_sep=True):
    """Cut ``idxenc`` off at ``to_beat``.

    Returns ``idxenc`` itself when ``to_beat`` is ``None``; otherwise the
    leading slice up to the index given by ``beat2index``.
    """
    if to_beat is not None:
        end = beat2index(idxenc, pos, vocab, to_beat, include_last_sep=include_last_sep)
        return idxenc[:end]
    return idxenc
def mask_input(xb, mask_range, replacement_idx):
    """Return a copy of ``xb`` with tokens in ``mask_range`` replaced.

    Every token inside ``[mask_range[0], mask_range[1])`` becomes
    ``replacement_idx``; the input array is left untouched.
    """
    masked = xb.copy()
    low, high = mask_range[0], mask_range[1]
    to_replace = (masked >= low) & (masked < high)
    masked[to_replace] = replacement_idx
    return masked
def mask_section(xb, pos, token_range, replacement_idx, section_range=None):
    """Return a copy of ``xb`` with tokens masked inside a beat section.

    A token is replaced by ``replacement_idx`` when it lies inside
    ``[token_range[0], token_range[1])`` and its index falls within the
    section. ``section_range`` is a ``(start_beat, end_beat)`` pair; a
    ``None`` bound (or ``section_range=None``) leaves that side open.
    """
    masked = xb.copy()
    in_token_range = (masked >= token_range[0]) & (masked < token_range[1])

    if section_range is None:
        start_beat, end_beat = None, None
    else:
        start_beat, end_beat = section_range[0], section_range[1]

    # Translate beat bounds into index bounds; open bounds span to the edge.
    first = 0 if start_beat is None else find_beat(pos, start_beat)
    last = masked.shape[0] if end_beat is None else find_beat(pos, end_beat)

    in_section = np.zeros_like(masked, dtype=bool)
    in_section[first:last] = True

    masked[in_token_range & in_section] = replacement_idx
    return masked