import logging
import math
from dataclasses import dataclass, field
from typing import Optional

from omegaconf import II

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist

from fairseq.modules import EMAModule, EMAModuleConfig
from fairseq.data.data_utils import compute_mask_indices
from fairseq.models import BaseFairseqModel, register_model
from fairseq.models.wav2vec import (
    ConvFeatureExtractionModel,
    Wav2Vec2Config,
    TransformerEncoder,
)
from fairseq.modules import (
    GradMultiply,
    LayerNorm,
)
from fairseq.utils import index_put


logger = logging.getLogger(__name__)


@dataclass
class Data2VecAudioConfig(Wav2Vec2Config):

    loss_beta: float = field(
        default=0, metadata={"help": "beta for smooth l1 loss. 0 means use l2 loss"}
    )
    loss_scale: Optional[float] = field(
        default=None,
        metadata={
            "help": "scale the reconstruction loss by this constant. if None then scales by 1/sqrt(dim)"
        },
    )
    average_top_k_layers: int = field(
        default=8, metadata={"help": "how many layers to average"}
    )

    layer_norm_target_layer: bool = False
    instance_norm_target_layer: bool = False
    instance_norm_targets: bool = False
    layer_norm_targets: bool = False
    batch_norm_target_layer: bool = False
    group_norm_target_layer: bool = False

    ema_decay: float = field(
        default=0.999, metadata={"help": "initial ema decay rate"}
    )
    ema_end_decay: float = field(
        default=0.9999, metadata={"help": "final ema decay rate"}
    )

    # when to finish annealing ema decay rate
    ema_anneal_end_step: int = II("optimization.max_update")

    ema_transformer_only: bool = field(
        default=True,
        metadata={"help": "whether to momentum update only the transformer"},
    )
    ema_layers_only: bool = field(
        default=True,
        metadata={"help": "whether to momentum update only the transformer layers"},
    )

    max_update: int = II("optimization.max_update")

    min_target_var: float = field(
        default=0.1, metadata={"help": "stop training if target var falls below this"}
    )
    min_pred_var: float = field(
        default=0.01,
        metadata={"help": "stop training if prediction var falls below this"},
    )


def get_annealed_rate(start, end, curr_step, total_steps):
    r = end - start
    pct_remaining = 1 - curr_step / total_steps
    return end - r * pct_remaining


@register_model("data2vec_audio", dataclass=Data2VecAudioConfig)
class Data2VecAudioModel(BaseFairseqModel):
    def __init__(self, cfg: Data2VecAudioConfig):
        super().__init__()
        self.cfg = cfg

        feature_enc_layers = eval(cfg.conv_feature_layers)
        self.extractor_embed = feature_enc_layers[-1][0]

        self.ema = None
        self.embed = cfg.encoder_embed_dim

        self.average_top_k_layers = cfg.average_top_k_layers
        self.loss_beta = cfg.loss_beta
        self.loss_scale = cfg.loss_scale

        self.feature_extractor = ConvFeatureExtractionModel(
            conv_layers=feature_enc_layers,
            dropout=0.0,
            mode=cfg.extractor_mode,
            conv_bias=cfg.conv_bias,
        )

        self.post_extract_proj = nn.Linear(self.extractor_embed, cfg.encoder_embed_dim)

        self.mask_prob = cfg.mask_prob
        self.mask_selection = cfg.mask_selection
        self.mask_other = cfg.mask_other
        self.mask_length = cfg.mask_length
        self.no_mask_overlap = cfg.no_mask_overlap
        self.mask_min_space = cfg.mask_min_space

        self.mask_channel_prob = cfg.mask_channel_prob
        self.mask_channel_before = cfg.mask_channel_before
        self.mask_channel_selection = cfg.mask_channel_selection
        self.mask_channel_other = cfg.mask_channel_other
        self.mask_channel_length = cfg.mask_channel_length
        self.no_mask_channel_overlap = cfg.no_mask_channel_overlap
        self.mask_channel_min_space = cfg.mask_channel_min_space

        self.dropout_input = nn.Dropout(cfg.dropout_input)
        self.dropout_features = nn.Dropout(cfg.dropout_features)

        self.feature_grad_mult = cfg.feature_grad_mult

        self.mask_emb = nn.Parameter(
            torch.FloatTensor(cfg.encoder_embed_dim).uniform_()
        )

        self.encoder = TransformerEncoder(cfg)
        self.layer_norm = LayerNorm(self.extractor_embed)

        self.final_proj = nn.Linear(self.embed, self.embed)

        self.num_updates = 0

    def make_ema_teacher(self):
        ema_config = EMAModuleConfig(
            ema_decay=self.cfg.ema_decay,
            ema_fp32=True,
        )
        skip_keys = set()
        if self.cfg.ema_layers_only:
            self.cfg.ema_transformer_only = True
            for k, _ in self.encoder.pos_conv.named_parameters():
                skip_keys.add(f"pos_conv.{k}")

        self.ema = EMAModule(
            self.encoder if self.cfg.ema_transformer_only else self,
            ema_config,
            skip_keys=skip_keys,
        )

    def set_num_updates(self, num_updates):
        super().set_num_updates(num_updates)

        if self.ema is None and self.final_proj is not None:
            logger.info(f"making ema teacher")
            self.make_ema_teacher()
        elif self.training and self.ema is not None:
            if self.cfg.ema_decay != self.cfg.ema_end_decay:
                if num_updates >= self.cfg.ema_anneal_end_step:
                    decay = self.cfg.ema_end_decay
                else:
                    decay = get_annealed_rate(
                        self.cfg.ema_decay,
                        self.cfg.ema_end_decay,
                        num_updates,
                        self.cfg.ema_anneal_end_step,
                    )
                self.ema.set_decay(decay)
            if self.ema.get_decay() < 1:
                self.ema.step(self.encoder if self.cfg.ema_transformer_only else self)

        self.num_updates = num_updates

    def state_dict(self, destination=None, prefix="", keep_vars=False):
        state = super().state_dict(destination, prefix, keep_vars)

        if self.ema is not None:
            state[prefix + "_ema"] = self.ema.fp32_params

        return state

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        if self.ema is not None:
            k = prefix + "_ema"
            assert k in state_dict
            self.ema.restore(state_dict[k], True)
            del state_dict[k]
        return super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)

    @classmethod
    def build_model(cls, cfg: Data2VecAudioConfig, task=None):
        """Build a new model instance."""

        return cls(cfg)

    def apply_mask(
        self,
        x,
        padding_mask,
        mask_indices=None,
        mask_channel_indices=None,
    ):
        B, T, C = x.shape

        if self.mask_channel_prob > 0 and self.mask_channel_before:
            mask_channel_indices = compute_mask_indices(
                (B, C),
                None,
                self.mask_channel_prob,
                self.mask_channel_length,
                self.mask_channel_selection,
                self.mask_channel_other,
                no_overlap=self.no_mask_channel_overlap,
                min_space=self.mask_channel_min_space,
            )
            mask_channel_indices = (
                torch.from_numpy(mask_channel_indices)
                .to(x.device)
                .unsqueeze(1)
                .expand(-1, T, -1)
            )
            x[mask_channel_indices] = 0

        if self.mask_prob > 0:
            if mask_indices is None:
                mask_indices = compute_mask_indices(
                    (B, T),
                    padding_mask,
                    self.mask_prob,
                    self.mask_length,
                    self.mask_selection,
                    self.mask_other,
                    min_masks=1,
                    no_overlap=self.no_mask_overlap,
                    min_space=self.mask_min_space,
                    require_same_masks=self.cfg.require_same_masks,
                    mask_dropout=self.cfg.mask_dropout,
                )
                mask_indices = torch.from_numpy(mask_indices).to(x.device)
            x = index_put(x, mask_indices, self.mask_emb)
        else:
            mask_indices = None

        if self.mask_channel_prob > 0 and not self.mask_channel_before:
            if mask_channel_indices is None:
                mask_channel_indices = compute_mask_indices(
                    (B, C),
                    None,
                    self.mask_channel_prob,
                    self.mask_channel_length,
                    self.mask_channel_selection,
                    self.mask_channel_other,
                    no_overlap=self.no_mask_channel_overlap,
                    min_space=self.mask_channel_min_space,
                )
                mask_channel_indices = (
                    torch.from_numpy(mask_channel_indices)
                    .to(x.device)
                    .unsqueeze(1)
                    .expand(-1, T, -1)
                )
            x = index_put(x, mask_channel_indices, 0)

        return x, mask_indices
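    # Note (added for clarity, not in the original source): compute_mask_indices
    # returns a boolean numpy array of shape (B, T) for time masking and (B, C)
    # for channel masking; masked time steps are replaced with the learned
    # self.mask_emb vector, while masked channels are broadcast across all time
    # steps and simply zeroed out.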
    def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor):
        """
        Computes the output length of the convolutional layers
        """

        def _conv_out_length(input_length, kernel_size, stride):
            return torch.floor((input_length - kernel_size) / stride + 1)

        conv_cfg_list = eval(self.cfg.conv_feature_layers)

        for i in range(len(conv_cfg_list)):
            input_lengths = _conv_out_length(
                input_lengths, conv_cfg_list[i][1], conv_cfg_list[i][2]
            )

        return input_lengths.to(torch.long)

    def forward(
        self,
        source,
        padding_mask=None,
        mask=True,
        features_only=False,
        layer=None,
        mask_indices=None,
        mask_channel_indices=None,
        padding_count=None,
    ):
        features = source

        if self.feature_grad_mult > 0:
            features = self.feature_extractor(features)
            if self.feature_grad_mult != 1.0:
                features = GradMultiply.apply(features, self.feature_grad_mult)
        else:
            with torch.no_grad():
                features = self.feature_extractor(features)

        features = features.transpose(1, 2)

        features = self.layer_norm(features)

        orig_padding_mask = padding_mask

        if padding_mask is not None and padding_mask.any():
            input_lengths = (1 - padding_mask.long()).sum(-1)
            # apply conv formula to get real output_lengths
            output_lengths = self._get_feat_extract_output_lengths(input_lengths)

            padding_mask = torch.zeros(
                features.shape[:2], dtype=features.dtype, device=features.device
            )

            # these two operations make sure that all values
            # before the output lengths indices are attended to
            padding_mask[
                (
                    torch.arange(padding_mask.shape[0], device=padding_mask.device),
                    output_lengths - 1,
                )
            ] = 1
            padding_mask = (1 - padding_mask.flip([-1]).cumsum(-1).flip([-1])).bool()
        else:
            padding_mask = None

        if self.post_extract_proj is not None:
            features = self.post_extract_proj(features)

        pre_encoder_features = None
        if self.cfg.ema_transformer_only:
            pre_encoder_features = features.clone()

        features = self.dropout_input(features)

        if mask:
            x, mask_indices = self.apply_mask(
                features,
                padding_mask,
                mask_indices=mask_indices,
                mask_channel_indices=mask_channel_indices,
            )
        else:
            x = features
            mask_indices = None

        x, layer_results = self.encoder(
            x,
            padding_mask=padding_mask,
            layer=layer,
        )

        if features_only:
            return {
                "x": x,
                "padding_mask": padding_mask,
                "layer_results": layer_results,
            }

        result = {
            "losses": {},
        }

        with torch.no_grad():
            self.ema.model.eval()

            if self.cfg.ema_transformer_only:
                y, layer_results = self.ema.model.extract_features(
                    pre_encoder_features,
                    padding_mask=padding_mask,
                    min_layer=self.cfg.encoder_layers - self.average_top_k_layers,
                )
                y = {
                    "x": y,
                    "padding_mask": padding_mask,
                    "layer_results": layer_results,
                }
            else:
                y = self.ema.model.extract_features(
                    source=source,
                    padding_mask=orig_padding_mask,
                    mask=False,
                )

            target_layer_results = [l[2] for l in y["layer_results"]]

            permuted = False
            if self.cfg.instance_norm_target_layer or self.cfg.batch_norm_target_layer:
                target_layer_results = [
                    tl.permute(1, 2, 0) for tl in target_layer_results  # TBC -> BCT
                ]
                permuted = True

            if self.cfg.batch_norm_target_layer:
                target_layer_results = [
                    F.batch_norm(
                        tl.float(), running_mean=None, running_var=None, training=True
                    )
                    for tl in target_layer_results
                ]

            if self.cfg.instance_norm_target_layer:
                target_layer_results = [
                    F.instance_norm(tl.float()) for tl in target_layer_results
                ]

            if permuted:
                target_layer_results = [
                    tl.transpose(1, 2) for tl in target_layer_results  # BCT -> BTC
                ]

            if self.cfg.group_norm_target_layer:
                target_layer_results = [
                    F.layer_norm(tl.float(), tl.shape[-2:])
                    for tl in target_layer_results
                ]

            if self.cfg.layer_norm_target_layer:
                target_layer_results = [
                    F.layer_norm(tl.float(), tl.shape[-1:])
                    for tl in target_layer_results
                ]
            # average the (optionally normalized) top-k teacher layers to build
            # the regression targets
            y = sum(target_layer_results) / len(target_layer_results)

            if self.cfg.layer_norm_targets:
                y = F.layer_norm(y.float(), y.shape[-1:])

            if self.cfg.instance_norm_targets:
                y = F.instance_norm(y.float().transpose(1, 2)).transpose(1, 2)

            if not permuted:
                y = y.transpose(0, 1)

            y = y[mask_indices]

        x = x[mask_indices]
        x = self.final_proj(x)

        sz = x.size(-1)

        if self.loss_beta == 0:
            loss = F.mse_loss(x.float(), y.float(), reduction="none").sum(dim=-1)
        else:
            loss = F.smooth_l1_loss(
                x.float(), y.float(), reduction="none", beta=self.loss_beta
            ).sum(dim=-1)

        if self.loss_scale is not None:
            scale = self.loss_scale
        else:
            scale = 1 / math.sqrt(sz)

        result["losses"]["regression"] = loss.sum() * scale

        if "sample_size" not in result:
            result["sample_size"] = loss.numel()

        with torch.no_grad():
            result["target_var"] = self.compute_var(y)
            result["pred_var"] = self.compute_var(x.float())

        if self.num_updates > 5000 and result["target_var"] < self.cfg.min_target_var:
            logger.error(
                f"target var is {result['target_var'].item()} < {self.cfg.min_target_var}, exiting"
            )
            raise Exception(
                f"target var is {result['target_var'].item()} < {self.cfg.min_target_var}, exiting"
            )
        if self.num_updates > 5000 and result["pred_var"] < self.cfg.min_pred_var:
            logger.error(
                f"pred var is {result['pred_var'].item()} < {self.cfg.min_pred_var}, exiting"
            )
            raise Exception(
                f"pred var is {result['pred_var'].item()} < {self.cfg.min_pred_var}, exiting"
            )

        if self.ema is not None:
            result["ema_decay"] = self.ema.get_decay() * 1000

        return result

    @staticmethod
    def compute_var(y):
        y = y.view(-1, y.size(-1))
        if dist.is_initialized():
            zc = torch.tensor(y.size(0)).cuda()
            zs = y.sum(dim=0)
            zss = (y ** 2).sum(dim=0)

            dist.all_reduce(zc)
            dist.all_reduce(zs)
            dist.all_reduce(zss)

            var = zss / (zc - 1) - (zs ** 2) / (zc * (zc - 1))
            return torch.sqrt(var + 1e-6).mean()
        else:
            return torch.sqrt(y.var(dim=0) + 1e-6).mean()

    def extract_features(self, source, padding_mask, mask=False, layer=None):
        res = self.forward(
            source,
            padding_mask,
            mask=mask,
            features_only=True,
            layer=layer,
        )
        return res

    def remove_pretraining_modules(self, last_layer=None):
        self.final_proj = None
        self.ema = None
        if last_layer is not None:
            self.encoder.layers = nn.ModuleList(
                l for i, l in enumerate(self.encoder.layers) if i <= last_layer
            )


import logging

import torch
import torch.nn.functional as F

from fairseq import tasks
from fairseq.checkpoint_utils import load_checkpoint_to_cpu
from fairseq.data.audio.audio_utils import get_features_or_waveform
from omegaconf import OmegaConf

logger = logging.getLogger("dump_feature")


class Data2vecFeatureReader(object):
    def __init__(self, ckpt_path: str, layer: int, device: str, max_chunk=1600000):
        state = load_checkpoint_to_cpu(ckpt_path)
        cfg = state["cfg"]

        # load task
        task = tasks.setup_task(cfg.task, from_checkpoint=True)
        task.load_state_dict(state["task_state"])

        # load model config
        if "layer_type" not in cfg.model:
            # fix a missing key
            model_config = {k: v for k, v in cfg.model.items()}
            model_config["layer_type"] = "transformer"
            model_config = OmegaConf.create(model_config)
        else:
            model_config = cfg.model

        # fix param name in the state
        state["model"]["final_proj.weight"] = state["model"].pop("final_proj.0.weight")
        state["model"]["final_proj.bias"] = state["model"].pop("final_proj.0.bias")
        del state["model"]["_ema"]

        # load model
        model = Data2VecAudioModel.build_model(model_config)
        model.load_state_dict(
            state["model"], strict=True, model_cfg=model_config
        )

        self.device = device
        logger.info(f"device = {self.device}")
{self.device}") self.model = model.eval().to(self.device) self.task = task self.layer = layer - 1 # make it 1-based self.max_chunk = max_chunk logger.info(f"TASK CONFIG:\n{self.task.cfg}") logger.info(f" max_chunk = {self.max_chunk}") def read_audio(self, path, ref_len=None): wav = get_features_or_waveform(path, need_waveform=True, use_sample_rate=self.task.cfg.sample_rate) if wav.ndim == 2: wav = wav.mean(-1) assert wav.ndim == 1, wav.ndim if ref_len is not None and abs(ref_len - len(wav)) > 160: logger.warning(f"ref {ref_len} != read {len(wav)} ({path})") return wav def get_feats(self, path, ref_len=None): x = self.read_audio(path, ref_len=ref_len) with torch.no_grad(): x = torch.from_numpy(x).float().to(self.device) if self.task.cfg.normalize: x = F.layer_norm(x, x.shape) x = x.view(1, -1) feat = [] for start in range(0, x.size(1), self.max_chunk): x_chunk = x[:, start: start + self.max_chunk] res = self.model.extract_features( source=x_chunk, padding_mask=None, mask=False, layer=self.layer, ) feat_chunk = res["x"] feat.append(feat_chunk) return torch.cat(feat, 1).squeeze(0)