# https://github.com/LAION-AI/CLAP/blob/df65ca0f6c3062dc554132cb40e74f4915084b21/src/training/data.py#L469
from functools import partial
import soundfile as sf
import io
import numpy as np
import torch
import torchaudio
import torchvision
import torch.nn.functional as F

AUDIO_CFG = {
    "sample_rate": 48000,
    "audio_length": 1024,
    "clip_samples": 480000,
    "mel_bins": 64,
    "window_size": 1024,
    "hop_size": 480,
    "fmin": 50,
    "fmax": 14000,
    "class_num": 527,
}


class dotdict(dict):
    """dot.notation access to dictionary attributes"""
    __getattr__ = dict.get
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__


class Map(dict):
    """
    Example:
    m = Map({'first_name': 'Eduardo'}, last_name='Pool', age=24, sports=['Soccer'])
    """

    def __init__(self, *args, **kwargs):
        super(Map, self).__init__(*args, **kwargs)
        for arg in args:
            if isinstance(arg, dict):
                for k, v in arg.items():  # dict.iteritems() is Python 2 only
                    self[k] = v
        if kwargs:
            for k, v in kwargs.items():
                self[k] = v

    def __getattr__(self, attr):
        return self.get(attr)

    def __setattr__(self, key, value):
        self.__setitem__(key, value)

    def __setitem__(self, key, value):
        super(Map, self).__setitem__(key, value)
        self.__dict__.update({key: value})

    def __delattr__(self, item):
        self.__delitem__(item)

    def __delitem__(self, key):
        super(Map, self).__delitem__(key)
        del self.__dict__[key]


def int16_to_float32(x):
    return (x / 32767.0).astype(np.float32)


def float32_to_int16(x):
    x = np.clip(x, a_min=-1., a_max=1.)
    return (x * 32767.).astype(np.int16)


def get_mel(audio_data, audio_cfg):
    # mel shape: (n_mels, T)
    mel = torchaudio.transforms.MelSpectrogram(
        sample_rate=audio_cfg['sample_rate'],
        n_fft=audio_cfg['window_size'],
        win_length=audio_cfg['window_size'],
        hop_length=audio_cfg['hop_size'],
        center=True,
        pad_mode="reflect",
        power=2.0,
        norm=None,
        onesided=True,
        n_mels=audio_cfg['mel_bins'],
        f_min=audio_cfg['fmin'],
        f_max=audio_cfg['fmax']
    )(audio_data)
    # we use log mel spectrogram as input
    mel = torchaudio.transforms.AmplitudeToDB(top_db=None)(mel)
    return mel.T  # (T, n_mels)


def get_audio_features(sample, audio_data, max_len, data_truncating, data_filling, audio_cfg):
    """
    Calculate and add audio features to sample.
    sample: a dict containing all the data of the current sample.
    audio_data: a tensor of shape (T) containing audio data.
    max_len: the maximum length of audio data.
    data_truncating: the method of truncating data.
    data_filling: the method of filling data.
    audio_cfg: a dict containing audio configuration. Comes from model_cfg['audio_cfg'].
    """
    with torch.no_grad():
        if len(audio_data) > max_len:
            if data_truncating == "rand_trunc":
                longer = torch.tensor([True])
            elif data_truncating == "fusion":
                # fusion
                mel = get_mel(audio_data, audio_cfg)
                # split to three parts
                # the +1 comes from center=True padding: a clip of max_len samples
                # yields max_len // hop_size + 1 spectrogram frames
                chunk_frames = max_len // audio_cfg['hop_size'] + 1
                total_frames = mel.shape[0]
                if chunk_frames == total_frames:
                    # there is a corner case where the audio length is
                    # larger than max_len but smaller than max_len+hop_size.
                    # In this case, we just use the whole audio.
                    mel_fusion = torch.stack([mel, mel, mel, mel], dim=0)
                    sample["mel_fusion"] = mel_fusion
                    longer = torch.tensor([False])
                else:
                    ranges = np.array_split(list(range(0, total_frames - chunk_frames + 1)), 3)
                    # print('total_frames-chunk_frames:', total_frames-chunk_frames,
                    #       'len(audio_data):', len(audio_data),
                    #       'chunk_frames:', chunk_frames,
                    #       'total_frames:', total_frames)
                    if len(ranges[1]) == 0:
                        # if the audio is too short, we just use the first chunk
                        ranges[1] = [0]
                    if len(ranges[2]) == 0:
                        # if the audio is too short, we just use the first chunk
                        ranges[2] = [0]
                    # randomly choose index for each part
                    idx_front = np.random.choice(ranges[0])
                    idx_middle = np.random.choice(ranges[1])
                    idx_back = np.random.choice(ranges[2])
                    # select mel
                    mel_chunk_front = mel[idx_front:idx_front + chunk_frames, :]
                    mel_chunk_middle = mel[idx_middle:idx_middle + chunk_frames, :]
                    mel_chunk_back = mel[idx_back:idx_back + chunk_frames, :]
                    # shrink the mel
                    mel_shrink = torchvision.transforms.Resize(size=[chunk_frames, 64])(mel[None])[0]
                    # logging.info(f"mel_shrink.shape: {mel_shrink.shape}")
                    # stack
                    mel_fusion = torch.stack([mel_chunk_front, mel_chunk_middle, mel_chunk_back, mel_shrink], dim=0)
                    sample["mel_fusion"] = mel_fusion
                    longer = torch.tensor([True])
            else:
                raise NotImplementedError(
                    f"data_truncating {data_truncating} not implemented"
                )
            # random crop to max_len (for compatibility)
            overflow = len(audio_data) - max_len
            idx = np.random.randint(0, overflow + 1)
            audio_data = audio_data[idx: idx + max_len]
        else:  # padding if too short
            if len(audio_data) < max_len:  # do nothing if equal
                if data_filling == "repeatpad":
                    n_repeat = int(max_len / len(audio_data))
                    audio_data = audio_data.repeat(n_repeat)
                    # audio_data = audio_data.unsqueeze(0).unsqueeze(0).unsqueeze(0)
                    # audio_data = F.interpolate(audio_data,size=max_len,mode="bicubic")[0,0,0]
                    audio_data = F.pad(
                        audio_data,
                        (0, max_len - len(audio_data)),
                        mode="constant",
                        value=0,
                    )
                elif data_filling == "pad":
                    audio_data = F.pad(
                        audio_data,
                        (0, max_len - len(audio_data)),
                        mode="constant",
                        value=0,
                    )
                elif data_filling == "repeat":
                    n_repeat = int(max_len / len(audio_data))
                    audio_data = audio_data.repeat(n_repeat + 1)[:max_len]
                else:
                    raise NotImplementedError(
                        f"data_filling {data_filling} not implemented"
                    )
            if data_truncating == 'fusion':
                mel = get_mel(audio_data, audio_cfg)
                mel_fusion = torch.stack([mel, mel, mel, mel], dim=0)
                sample["mel_fusion"] = mel_fusion
            longer = torch.tensor([False])

        sample["longer"] = longer
        sample["waveform"] = audio_data

    return sample
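

# --- Usage sketch (not part of the original data.py) ---
# A minimal example of calling get_audio_features, assuming synthetic mono waveforms
# at 48 kHz and the AUDIO_CFG defined above. The "fusion"/"repeatpad" settings and the
# specific clip durations are illustrative choices, not values taken from the source file.
if __name__ == "__main__":
    audio_cfg = AUDIO_CFG
    max_len = audio_cfg["clip_samples"]  # 480000 samples = 10 s at 48 kHz

    # A 12 s clip is longer than max_len, so it takes the truncation path;
    # "fusion" stacks three random mel chunks plus a resized global mel.
    long_wave = torch.rand(12 * audio_cfg["sample_rate"]) * 2 - 1
    long_sample = get_audio_features(
        {}, long_wave, max_len,
        data_truncating="fusion", data_filling="repeatpad", audio_cfg=audio_cfg,
    )
    print(long_sample["waveform"].shape)    # torch.Size([480000])
    print(long_sample["mel_fusion"].shape)  # torch.Size([4, 1001, 64])
    print(long_sample["longer"])            # tensor([True])

    # A 3 s clip is shorter than max_len, so it is repeated and zero-padded,
    # and the four fusion channels are identical copies of its mel.
    short_wave = torch.rand(3 * audio_cfg["sample_rate"]) * 2 - 1
    short_sample = get_audio_features(
        {}, short_wave, max_len,
        data_truncating="fusion", data_filling="repeatpad", audio_cfg=audio_cfg,
    )
    print(short_sample["waveform"].shape)   # torch.Size([480000])
    print(short_sample["longer"])           # tensor([False])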