import random

import torch
import torch.nn as nn
import torchaudio
from torch.utils.data import Dataset


class DataCollator:
    """Pads a batch of (waveform, transcript, accent, gender) examples and
    optionally applies a random ffmpeg audio effect to each waveform."""

    def __init__(self, processor, padding, device, augment):
        self.processor = processor
        self.padding = padding
        self.device = device
        self.sampling_rate = 16000
        self.augment = augment
        # Playback-speed factors for the ffmpeg atempo filter.
        atempos = (0.8, 1.0, 1.25)
        audio_effects = (
            ("highpass=frequency=1500",),
            (
                "vibrato=f=5:d=0.4",
                "volume=1.5",
            ),
            (
                "aecho=0.8:0.88:30:0.3",
                "volume=1.5",
            ),
        )
        # None means "no augmentation"; the rest are tempo x effect combinations.
        self.effectors = [None]
        for atempo in atempos:
            for audio_effect in audio_effects:
                effect = f"atempo={atempo}," + ",".join(audio_effect)
                self.effectors.append(torchaudio.io.AudioEffector(effect=effect))

    def __call__(self, data):
        waveforms, lm_labels, accent_labels, gender_labels = zip(*data)
        accent_labels = torch.tensor(accent_labels, device=self.device)
        gender_labels = torch.tensor(gender_labels, device=self.device)
        input_features = [
            {"input_values": self.random_augment(waveform).squeeze()}
            for waveform in waveforms
        ]
        label_features = [{"input_ids": lm_label} for lm_label in lm_labels]
        padded_waveforms = self.processor.pad(
            input_features,
            padding=self.padding,
            return_tensors="pt",
        )["input_values"]
        padded_waveforms = padded_waveforms.to(self.device)
        with self.processor.as_target_processor():
            padded_lm_labels = self.processor.pad(
                label_features,
                padding=self.padding,
                return_tensors="pt",
            )
        # Replace padding with -100 so it is ignored by the CTC loss.
        padded_lm_labels = padded_lm_labels["input_ids"].masked_fill(
            padded_lm_labels.attention_mask.ne(1), -100
        )
        padded_lm_labels = padded_lm_labels.to(self.device)
        return padded_waveforms, padded_lm_labels, accent_labels, gender_labels

    def random_augment(self, waveform):
        if not self.augment:
            return waveform
        # AudioEffector expects a (time, channel) tensor.
        waveform = torch.tensor(waveform)
        waveform = torch.transpose(waveform, 0, 1)
        effector = random.choice(self.effectors)
        if effector is None:
            return waveform
        augmented_waveform = effector.apply(waveform, self.sampling_rate)
        # Fall back to the clean waveform if the effect produced NaN/Inf samples.
        if augmented_waveform.isnan().any() or augmented_waveform.isinf().any():
            return waveform
        return augmented_waveform


class L2ArcticDataset(Dataset):
    """Loads L2-ARCTIC utterances, resamples them from 44.1 kHz to 16 kHz, and
    tokenizes the transcripts with the Wav2Vec2 processor."""

    def __init__(self, processor, audio_paths, lm_labels, accent_labels, gender_labels):
        orig_sampling_rate = 44100  # L2-ARCTIC recordings are 44.1 kHz
        new_sampling_rate = 16000  # wav2vec2 expects 16 kHz input
        resample_transform = torchaudio.transforms.Resample(
            orig_sampling_rate, new_sampling_rate
        )
        self.waveforms = []
        self.lm_labels = []
        self.accent_labels = accent_labels
        self.gender_labels = gender_labels
        for audio_path in audio_paths:
            waveform, _ = torchaudio.load(audio_path)
            waveform = resample_transform(waveform)
            self.waveforms.append(
                processor(waveform, sampling_rate=new_sampling_rate).input_values[0]
            )
        with processor.as_target_processor():
            for lm_label in lm_labels:
                self.lm_labels.append(processor(lm_label).input_ids)

    def __getitem__(self, index):
        return (
            self.waveforms[index],
            self.lm_labels[index],
            self.accent_labels[index],
            self.gender_labels[index],
        )

    def __len__(self):
        return len(self.waveforms)
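
# A minimal wiring sketch for the two classes above, kept as a comment so the
# module stays import-safe. The checkpoint name, device, batch size, and the
# audio_paths / transcripts / accent_ids / gender_ids variables are
# illustrative assumptions, not part of this module.
#
#   from torch.utils.data import DataLoader
#   from transformers import Wav2Vec2Processor
#
#   processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
#   dataset = L2ArcticDataset(
#       processor, audio_paths, transcripts, accent_ids, gender_ids
#   )
#   collator = DataCollator(processor, padding=True, device="cuda", augment=True)
#   loader = DataLoader(dataset, batch_size=8, shuffle=True, collate_fn=collator)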

class MultiTaskWav2Vec2(nn.Module):
    """Wraps a Hugging Face Wav2Vec2ForCTC backbone with two extra heads for
    accent and gender classification. The backbone must be loaded with
    output_hidden_states=True so the last hidden state is available."""

    def __init__(
        self,
        wav2vec2_backbone,
        backbone_hidden_size,
        projection_hidden_size,
        num_accent_class,
    ):
        super().__init__()
        self.wav2vec2 = wav2vec2_backbone
        self.accent_projector = nn.Linear(backbone_hidden_size, projection_hidden_size)
        self.accent_classifier = nn.Linear(projection_hidden_size, num_accent_class)
        self.gender_projector = nn.Linear(backbone_hidden_size, projection_hidden_size)
        self.gender_classifier = nn.Linear(projection_hidden_size, 2)

    def forward(self, waveform, lm_labels=None):
        if lm_labels is not None:
            # Run the Hugging Face Wav2Vec2 backbone with labels so it also
            # computes the CTC loss of its LM head.
            wav2vec2_output = self.wav2vec2(input_values=waveform, labels=lm_labels)
            ctc_loss = wav2vec2_output.loss
        else:
            # No labels given (inference), hence no CTC loss.
            wav2vec2_output = self.wav2vec2(input_values=waveform)
            ctc_loss = None
        # Last hidden state of the backbone, shape (batch, frames, hidden).
        features = wav2vec2_output.hidden_states[-1]
        # Character-level logits from the backbone's LM head.
        lm_logits = wav2vec2_output.logits
        # Accent head: project, mean-pool over time, classify.
        accent_projected = self.accent_projector(features)
        accent_projected = accent_projected.mean(dim=1)
        accent_logits = self.accent_classifier(accent_projected)
        # Gender head: same structure as the accent head.
        gender_projected = self.gender_projector(features)
        gender_projected = gender_projected.mean(dim=1)
        gender_logits = self.gender_classifier(gender_projected)
        return ctc_loss, lm_logits, accent_logits, gender_logits
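
# A minimal smoke test for the multi-task model above, assuming the
# "facebook/wav2vec2-base-960h" checkpoint (hidden size 768); the projection
# size and accent-class count are illustrative values, not fixed by this module.
if __name__ == "__main__":
    from transformers import Wav2Vec2ForCTC

    backbone = Wav2Vec2ForCTC.from_pretrained(
        "facebook/wav2vec2-base-960h",
        output_hidden_states=True,  # forward() reads hidden_states[-1]
    )
    model = MultiTaskWav2Vec2(
        backbone,
        backbone_hidden_size=768,
        projection_hidden_size=256,
        num_accent_class=6,
    )
    dummy_batch = torch.randn(2, 16000)  # two one-second clips at 16 kHz
    ctc_loss, lm_logits, accent_logits, gender_logits = model(dummy_batch)
    # During training, ctc_loss would be summed with cross-entropy losses on
    # accent_logits and gender_logits against the collator's label tensors.
    print(lm_logits.shape, accent_logits.shape, gender_logits.shape)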