"Fastai Language Model Databunch modified to work with music" from fastai.basics import * # from fastai.basic_data import DataBunch from fastai.text.data import LMLabelList from .transform import * from ..vocab import MusicVocab class MusicDataBunch(DataBunch): "Create a `TextDataBunch` suitable for training a language model." @classmethod def create(cls, train_ds, valid_ds, test_ds=None, path:PathOrStr='.', no_check:bool=False, bs=64, val_bs:int=None, num_workers:int=0, device:torch.device=None, collate_fn:Callable=data_collate, dl_tfms:Optional[Collection[Callable]]=None, bptt:int=70, preloader_cls=None, shuffle_dl=False, transpose_range=(0,12), **kwargs) -> DataBunch: "Create a `TextDataBunch` in `path` from the `datasets` for language modelling." datasets = cls._init_ds(train_ds, valid_ds, test_ds) preloader_cls = MusicPreloader if preloader_cls is None else preloader_cls val_bs = ifnone(val_bs, bs) datasets = [preloader_cls(ds, shuffle=(i==0), bs=(bs if i==0 else val_bs), bptt=bptt, transpose_range=transpose_range, **kwargs) for i,ds in enumerate(datasets)] val_bs = bs dl_tfms = [partially_apply_vocab(tfm, train_ds.vocab) for tfm in listify(dl_tfms)] dls = [DataLoader(d, b, shuffle=shuffle_dl) for d,b in zip(datasets, (bs,val_bs,val_bs,val_bs)) if d is not None] return cls(*dls, path=path, device=device, dl_tfms=dl_tfms, collate_fn=collate_fn, no_check=no_check) @classmethod def from_folder(cls, path:PathOrStr, extensions='.npy', **kwargs): files = get_files(path, extensions=extensions, recurse=True); return cls.from_files(files, path, **kwargs) @classmethod def from_files(cls, files, path, processors=None, split_pct=0.1, vocab=None, list_cls=None, **kwargs): if vocab is None: vocab = MusicVocab.create() if list_cls is None: list_cls = MusicItemList src = (list_cls(items=files, path=path, processor=processors, vocab=vocab) .split_by_rand_pct(split_pct, seed=6) .label_const(label_cls=LMLabelList)) return src.databunch(**kwargs) @classmethod def empty(cls, path, **kwargs): vocab = MusicVocab.create() src = MusicItemList([], path=path, vocab=vocab, ignore_empty=True).split_none() return src.label_const(label_cls=LMLabelList).databunch() def partially_apply_vocab(tfm, vocab): if 'vocab' in inspect.getfullargspec(tfm).args: return partial(tfm, vocab=vocab) return tfm class MusicItemList(ItemList): _bunch = MusicDataBunch def __init__(self, items:Iterator, vocab:MusicVocab=None, **kwargs): super().__init__(items, **kwargs) self.vocab = vocab self.copy_new += ['vocab'] def get(self, i): o = super().get(i) if is_pos_enc(o): return MusicItem.from_idx(o, self.vocab) return MusicItem(o, self.vocab) def is_pos_enc(idxenc): if len(idxenc.shape) == 2 and idxenc.shape[0] == 2: return True return idxenc.dtype == np.object and idxenc.shape == (2,) class MusicItemProcessor(PreProcessor): "`PreProcessor` that transforms numpy files to indexes for training" def process_one(self,item): item = MusicItem.from_npenc(item, vocab=self.vocab) return item.to_idx() def process(self, ds): self.vocab = ds.vocab super().process(ds) class OpenNPFileProcessor(PreProcessor): "`PreProcessor` that opens the filenames and read the texts." def process_one(self,item): return np.load(item, allow_pickle=True) if isinstance(item, Path) else item class Midi2ItemProcessor(PreProcessor): "Skips midi preprocessing step. And encodes midi files to MusicItems" def process_one(self,item): item = MusicItem.from_file(item, vocab=self.vocab) return item.to_idx() def process(self, ds): self.vocab = ds.vocab super().process(ds) ## For npenc dataset class MusicPreloader(Callback): "Transforms the tokens in `dataset` to a stream of contiguous batches for language modelling." class CircularIndex(): "Handles shuffle, direction of indexing, wraps around to head tail in the ragged array as needed" def __init__(self, length:int, forward:bool): self.idx, self.forward = np.arange(length), forward def __getitem__(self, i): return self.idx[ i%len(self.idx) if self.forward else len(self.idx)-1-i%len(self.idx)] def __len__(self) -> int: return len(self.idx) def shuffle(self): np.random.shuffle(self.idx) def __init__(self, dataset:LabelList, lengths:Collection[int]=None, bs:int=32, bptt:int=70, backwards:bool=False, shuffle:bool=False, y_offset:int=1, transpose_range=None, transpose_p=0.5, encode_position=True, **kwargs): self.dataset,self.bs,self.bptt,self.shuffle,self.backwards,self.lengths = dataset,bs,bptt,shuffle,backwards,lengths self.vocab = self.dataset.vocab self.bs *= num_distrib() or 1 self.totalToks,self.ite_len,self.idx = int(0),None,None self.y_offset = y_offset self.transpose_range,self.transpose_p = transpose_range,transpose_p self.encode_position = encode_position self.bptt_len = self.bptt self.allocate_buffers() # needed for valid_dl on distributed training - otherwise doesn't get initialized on first epoch def __len__(self): if self.ite_len is None: if self.lengths is None: self.lengths = np.array([len(item) for item in self.dataset.x]) self.totalToks = self.lengths.sum() self.ite_len = self.bs*int( math.ceil( self.totalToks/(self.bptt*self.bs) )) if self.item is None else 1 return self.ite_len def __getattr__(self,k:str)->Any: return getattr(self.dataset, k) def allocate_buffers(self): "Create the ragged array that will be filled when we ask for items." if self.ite_len is None: len(self) self.idx = MusicPreloader.CircularIndex(len(self.dataset.x), not self.backwards) # batch shape = (bs, bptt, 2 - [index, pos]) if encode_position. Else - (bs, bptt) buffer_len = (2,) if self.encode_position else () self.batch = np.zeros((self.bs, self.bptt+self.y_offset) + buffer_len, dtype=np.int64) self.batch_x, self.batch_y = self.batch[:,0:self.bptt], self.batch[:,self.y_offset:self.bptt+self.y_offset] #ro: index of the text we're at inside our datasets for the various batches self.ro = np.zeros(self.bs, dtype=np.int64) #ri: index of the token we're at inside our current text for the various batches self.ri = np.zeros(self.bs, dtype=np.int) # allocate random transpose values. Need to allocate this before hand. self.transpose_values = self.get_random_transpose_values() def get_random_transpose_values(self): if self.transpose_range is None: return None n = len(self.dataset) rt_arr = torch.randint(*self.transpose_range, (n,))-self.transpose_range[1]//2 mask = torch.rand(rt_arr.shape) > self.transpose_p rt_arr[mask] = 0 return rt_arr def on_epoch_begin(self, **kwargs): if self.idx is None: self.allocate_buffers() elif self.shuffle: self.ite_len = None self.idx.shuffle() self.transpose_values = self.get_random_transpose_values() self.bptt_len = self.bptt self.idx.forward = not self.backwards step = self.totalToks / self.bs ln_rag, countTokens, i_rag = 0, 0, -1 for i in range(0,self.bs): #Compute the initial values for ro and ri while ln_rag + countTokens <= int(step * i): countTokens += ln_rag i_rag += 1 ln_rag = self.lengths[self.idx[i_rag]] self.ro[i] = i_rag self.ri[i] = ( ln_rag - int(step * i - countTokens) ) if self.backwards else int(step * i - countTokens) #Training dl gets on_epoch_begin called, val_dl, on_epoch_end def on_epoch_end(self, **kwargs): self.on_epoch_begin() def __getitem__(self, k:int): j = k % self.bs if j==0: if self.item is not None: return self.dataset[0] if self.idx is None: self.on_epoch_begin() self.ro[j],self.ri[j] = self.fill_row(not self.backwards, self.dataset.x, self.idx, self.batch[j][:self.bptt_len+self.y_offset], self.ro[j], self.ri[j], overlap=1, lengths=self.lengths) return self.batch_x[j][:self.bptt_len], self.batch_y[j][:self.bptt_len] def fill_row(self, forward, items, idx, row, ro, ri, overlap, lengths): "Fill the row with tokens from the ragged array. --OBS-- overlap != 1 has not been implemented" ibuf = n = 0 ro -= 1 while ibuf < row.shape[0]: ro += 1 ix = idx[ro] item = items[ix] if self.transpose_values is not None: item = item.transpose(self.transpose_values[ix].item()) if self.encode_position: # Positions are colomn stacked with indexes. This makes it easier to keep in sync rag = np.stack([item.data, item.position], axis=1) else: rag = item.data if forward: ri = 0 if ibuf else ri n = min(lengths[ix] - ri, row.shape[0] - ibuf) row[ibuf:ibuf+n] = rag[ri:ri+n] else: ri = lengths[ix] if ibuf else ri n = min(ri, row.size - ibuf) row[ibuf:ibuf+n] = rag[ri-n:ri][::-1] ibuf += n return ro, ri + ((n-overlap) if forward else -(n-overlap)) def batch_position_tfm(b): "Batch transform for training with positional encoding" x,y = b x = { 'x': x[...,0], 'pos': x[...,1] } return x, y[...,0]