|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import copy |
|
import logging |
|
import math |
|
from dataclasses import dataclass, field |
|
from typing import Optional |
|
|
|
import torch |
|
import torch.distributed as dist |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from omegaconf import II |
|
|
|
from s3prl.upstream.wav2vec2.wav2vec2_model import ( |
|
ConvFeatureExtractionModel, |
|
GradMultiply, |
|
LayerNorm, |
|
TransformerEncoder, |
|
Wav2Vec2Config, |
|
compute_mask_indices, |
|
index_put, |
|
) |
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
class EMAModuleConfig: |
|
ema_decay: float = field( |
|
default=0.9999, metadata={"help": "decay for exponential moving average model"} |
|
) |
|
ema_fp32: bool = field( |
|
default=False, |
|
metadata={"help": "If true, store EMA model in fp32 even if model is in fp16"}, |
|
) |
|
|
|
|
|
class EMAModule: |
|
"""Exponential Moving Average of Fairseq Models""" |
|
|
|
def __init__(self, model, config: EMAModuleConfig, device=None, skip_keys=None): |
|
""" |
|
@param model model to initialize the EMA with |
|
@param config EMAConfig object with configuration like |
|
ema_decay, ema_update_freq, ema_fp32 |
|
@param device If provided, copy EMA to this device (e.g. gpu). |
|
Otherwise EMA is in the same device as the model. |
|
""" |
|
|
|
self.decay = config.ema_decay |
|
self.model = copy.deepcopy(model) |
|
self.model.requires_grad_(False) |
|
self.config = config |
|
self.skip_keys = skip_keys or set() |
|
self.fp32_params = {} |
|
|
|
if device is not None: |
|
logging.info(f"Copying EMA model to device {device}") |
|
self.model = self.model.to(device=device) |
|
|
|
if self.config.ema_fp32: |
|
self.build_fp32_params() |
|
|
|
self.update_freq_counter = 0 |
|
|
|
def build_fp32_params(self, state_dict=None): |
|
""" |
|
Store a copy of the EMA params in fp32. |
|
If state dict is passed, the EMA params is copied from |
|
the provided state dict. Otherwise, it is copied from the |
|
current EMA model parameters. |
|
""" |
|
if not self.config.ema_fp32: |
|
raise RuntimeError( |
|
"build_fp32_params should not be called if ema_fp32=False. " |
|
"Use ema_fp32=True if this is really intended." |
|
) |
|
|
|
if state_dict is None: |
|
state_dict = self.model.state_dict() |
|
|
|
def _to_float(t): |
|
return t.float() if torch.is_floating_point(t) else t |
|
|
|
for param_key in state_dict: |
|
if param_key in self.fp32_params: |
|
self.fp32_params[param_key].copy_(state_dict[param_key]) |
|
else: |
|
self.fp32_params[param_key] = _to_float(state_dict[param_key]) |
|
|
|
def restore(self, state_dict, build_fp32_params=False): |
|
"""Load data from a model spec into EMA model""" |
|
self.model.load_state_dict(state_dict, strict=False) |
|
if build_fp32_params: |
|
self.build_fp32_params(state_dict) |
|
|
|
def set_decay(self, decay): |
|
self.decay = decay |
|
|
|
def get_decay(self): |
|
return self.decay |
|
|
|
def _step_internal(self, new_model): |
|
"""One update of the EMA model based on new model weights""" |
|
decay = self.decay |
|
|
|
ema_state_dict = {} |
|
ema_params = ( |
|
self.fp32_params if self.config.ema_fp32 else self.model.state_dict() |
|
) |
|
for key, param in new_model.named_parameters(): |
|
if isinstance(param, dict): |
|
continue |
|
try: |
|
ema_param = ema_params[key] |
|
except KeyError: |
|
ema_param = ( |
|
param.float().clone() if param.ndim == 1 else copy.deepcopy(param) |
|
) |
|
ema_params[key] = ema_param |
|
|
|
if param.shape != ema_param.shape: |
|
raise ValueError( |
|
"incompatible tensor shapes between model param and ema param" |
|
+ "{} vs. {}".format(param.shape, ema_param.shape) |
|
) |
|
|
|
if "version" in key: |
|
|
|
continue |
|
|
|
if key in self.skip_keys or not param.requires_grad: |
|
ema_params[key].copy_(param.to(dtype=ema_param.dtype).data) |
|
ema_param = ema_params[key] |
|
else: |
|
ema_param.mul_(decay) |
|
ema_param.add_(param.data.to(dtype=ema_param.dtype), alpha=1 - decay) |
|
|
|
ema_state_dict[key] = ema_param |
|
|
|
for key, param in new_model.named_buffers(): |
|
ema_state_dict[key] = param |
|
|
|
self.restore(ema_state_dict, build_fp32_params=False) |
|
|
|
@torch.no_grad() |
|
def step(self, new_model): |
|
self._step_internal(new_model) |
|
|
|
def reverse(self, model): |
|
""" |
|
Load the model parameters from EMA model. |
|
Useful for inference or fine-tuning from the EMA model. |
|
""" |
|
d = self.model.state_dict() |
|
if "_ema" in d: |
|
del d["_ema"] |
|
|
|
model.load_state_dict(d, strict=False) |
|
return model |
|
|
|
|
|
@dataclass |
|
class Data2VecAudioConfig(Wav2Vec2Config): |
|
loss_beta: float = field( |
|
default=0, metadata={"help": "beta for smooth l1 loss. 0 means use l2 loss"} |
|
) |
|
loss_scale: Optional[float] = field( |
|
default=None, |
|
metadata={ |
|
"help": "scale the reconstruction loss by this constant. if None then scales by 1/sqrt(dim)" |
|
}, |
|
) |
|
average_top_k_layers: int = field( |
|
default=8, metadata={"help": "how many layers to average"} |
|
) |
|
|
|
layer_norm_target_layer: bool = False |
|
instance_norm_target_layer: bool = False |
|
instance_norm_targets: bool = False |
|
layer_norm_targets: bool = False |
|
batch_norm_target_layer: bool = False |
|
group_norm_target_layer: bool = False |
|
|
|
ema_decay: float = field(default=0.999, metadata={"help": "initial ema decay rate"}) |
|
ema_end_decay: float = field( |
|
default=0.9999, metadata={"help": "final ema decay rate"} |
|
) |
|
|
|
|
|
ema_anneal_end_step: int = None |
|
|
|
ema_transformer_only: bool = field( |
|
default=True, |
|
metadata={"help": "whether to momentum update only the transformer"}, |
|
) |
|
ema_layers_only: bool = field( |
|
default=True, |
|
metadata={"help": "whether to momentum update only the transformer layers"}, |
|
) |
|
|
|
max_update: int = None |
|
|
|
min_target_var: float = field( |
|
default=0.1, metadata={"help": "stop training if target var falls below this"} |
|
) |
|
min_pred_var: float = field( |
|
default=0.01, |
|
metadata={"help": "stop training if prediction var falls below this"}, |
|
) |
|
|
|
|
|
def get_annealed_rate(start, end, curr_step, total_steps): |
|
r = end - start |
|
pct_remaining = 1 - curr_step / total_steps |
|
return end - r * pct_remaining |
|
|
|
|
|
class Data2VecAudioModel(torch.nn.Module): |
|
def __init__(self, cfg: Data2VecAudioConfig): |
|
super().__init__() |
|
self.cfg = cfg |
|
|
|
feature_enc_layers = eval(cfg.conv_feature_layers) |
|
self.extractor_embed = feature_enc_layers[-1][0] |
|
|
|
self.ema = None |
|
self.embed = cfg.encoder_embed_dim |
|
|
|
self.average_top_k_layers = cfg.average_top_k_layers |
|
self.loss_beta = cfg.loss_beta |
|
self.loss_scale = cfg.loss_scale |
|
|
|
self.feature_extractor = ConvFeatureExtractionModel( |
|
conv_layers=feature_enc_layers, |
|
dropout=0.0, |
|
mode=cfg.extractor_mode, |
|
conv_bias=cfg.conv_bias, |
|
) |
|
|
|
self.post_extract_proj = nn.Linear(self.extractor_embed, cfg.encoder_embed_dim) |
|
|
|
self.mask_prob = cfg.mask_prob |
|
self.mask_selection = cfg.mask_selection |
|
self.mask_other = cfg.mask_other |
|
self.mask_length = cfg.mask_length |
|
self.no_mask_overlap = cfg.no_mask_overlap |
|
self.mask_min_space = cfg.mask_min_space |
|
|
|
self.mask_channel_prob = cfg.mask_channel_prob |
|
self.mask_channel_before = cfg.mask_channel_before |
|
self.mask_channel_selection = cfg.mask_channel_selection |
|
self.mask_channel_other = cfg.mask_channel_other |
|
self.mask_channel_length = cfg.mask_channel_length |
|
self.no_mask_channel_overlap = cfg.no_mask_channel_overlap |
|
self.mask_channel_min_space = cfg.mask_channel_min_space |
|
|
|
self.dropout_input = nn.Dropout(cfg.dropout_input) |
|
self.dropout_features = nn.Dropout(cfg.dropout_features) |
|
|
|
self.feature_grad_mult = cfg.feature_grad_mult |
|
|
|
self.mask_emb = nn.Parameter( |
|
torch.FloatTensor(cfg.encoder_embed_dim).uniform_() |
|
) |
|
|
|
self.encoder = TransformerEncoder(cfg) |
|
self.layer_norm = LayerNorm(self.extractor_embed) |
|
|
|
self.final_proj = nn.Linear(self.embed, self.embed) |
|
|
|
self.num_updates = 0 |
|
|
|
def make_ema_teacher(self): |
|
ema_config = EMAModuleConfig( |
|
ema_decay=self.cfg.ema_decay, |
|
ema_fp32=True, |
|
) |
|
skip_keys = set() |
|
if self.cfg.ema_layers_only: |
|
self.cfg.ema_transformer_only = True |
|
for k, _ in self.encoder.pos_conv.named_parameters(): |
|
skip_keys.add(f"pos_conv.{k}") |
|
|
|
self.ema = EMAModule( |
|
self.encoder if self.cfg.ema_transformer_only else self, |
|
ema_config, |
|
skip_keys=skip_keys, |
|
) |
|
|
|
def set_num_updates(self, num_updates): |
|
super().set_num_updates(num_updates) |
|
|
|
if self.ema is None and self.final_proj is not None: |
|
logger.info(f"making ema teacher") |
|
self.make_ema_teacher() |
|
elif self.training and self.ema is not None: |
|
if self.cfg.ema_decay != self.cfg.ema_end_decay: |
|
if num_updates >= self.cfg.ema_anneal_end_step: |
|
decay = self.cfg.ema_end_decay |
|
else: |
|
decay = get_annealed_rate( |
|
self.cfg.ema_decay, |
|
self.cfg.ema_end_decay, |
|
num_updates, |
|
self.cfg.ema_anneal_end_step, |
|
) |
|
self.ema.set_decay(decay) |
|
if self.ema.get_decay() < 1: |
|
self.ema.step(self.encoder if self.cfg.ema_transformer_only else self) |
|
|
|
self.num_updates = num_updates |
|
|
|
def state_dict(self, destination=None, prefix="", keep_vars=False): |
|
state = super().state_dict(destination, prefix, keep_vars) |
|
|
|
if self.ema is not None: |
|
state[prefix + "_ema"] = self.ema.fp32_params |
|
|
|
return state |
|
|
|
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs): |
|
if self.ema is not None: |
|
k = prefix + "_ema" |
|
assert k in state_dict |
|
self.ema.restore(state_dict[k], True) |
|
del state_dict[k] |
|
return super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) |
|
|
|
@classmethod |
|
def build_model(cls, cfg: Data2VecAudioConfig, task=None): |
|
"""Build a new model instance.""" |
|
|
|
return cls(cfg) |
|
|
|
def apply_mask( |
|
self, |
|
x, |
|
padding_mask, |
|
mask_indices=None, |
|
mask_channel_indices=None, |
|
): |
|
B, T, C = x.shape |
|
|
|
if self.mask_channel_prob > 0 and self.mask_channel_before: |
|
mask_channel_indices = compute_mask_indices( |
|
(B, C), |
|
None, |
|
self.mask_channel_prob, |
|
self.mask_channel_length, |
|
self.mask_channel_selection, |
|
self.mask_channel_other, |
|
no_overlap=self.no_mask_channel_overlap, |
|
min_space=self.mask_channel_min_space, |
|
) |
|
mask_channel_indices = ( |
|
torch.from_numpy(mask_channel_indices) |
|
.to(x.device) |
|
.unsqueeze(1) |
|
.expand(-1, T, -1) |
|
) |
|
x[mask_channel_indices] = 0 |
|
|
|
if self.mask_prob > 0: |
|
if mask_indices is None: |
|
mask_indices = compute_mask_indices( |
|
(B, T), |
|
padding_mask, |
|
self.mask_prob, |
|
self.mask_length, |
|
self.mask_selection, |
|
self.mask_other, |
|
min_masks=1, |
|
no_overlap=self.no_mask_overlap, |
|
min_space=self.mask_min_space, |
|
require_same_masks=self.cfg.require_same_masks, |
|
mask_dropout=self.cfg.mask_dropout, |
|
) |
|
mask_indices = torch.from_numpy(mask_indices).to(x.device) |
|
x = index_put(x, mask_indices, self.mask_emb) |
|
else: |
|
mask_indices = None |
|
|
|
if self.mask_channel_prob > 0 and not self.mask_channel_before: |
|
if mask_channel_indices is None: |
|
mask_channel_indices = compute_mask_indices( |
|
(B, C), |
|
None, |
|
self.mask_channel_prob, |
|
self.mask_channel_length, |
|
self.mask_channel_selection, |
|
self.mask_channel_other, |
|
no_overlap=self.no_mask_channel_overlap, |
|
min_space=self.mask_channel_min_space, |
|
) |
|
mask_channel_indices = ( |
|
torch.from_numpy(mask_channel_indices) |
|
.to(x.device) |
|
.unsqueeze(1) |
|
.expand(-1, T, -1) |
|
) |
|
x = index_put(x, mask_channel_indices, 0) |
|
|
|
return x, mask_indices |
|
|
|
def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor): |
|
""" |
|
Computes the output length of the convolutional layers |
|
""" |
|
|
|
def _conv_out_length(input_length, kernel_size, stride): |
|
return torch.floor((input_length - kernel_size) / stride + 1) |
|
|
|
conv_cfg_list = eval(self.cfg.conv_feature_layers) |
|
|
|
for i in range(len(conv_cfg_list)): |
|
input_lengths = _conv_out_length( |
|
input_lengths, conv_cfg_list[i][1], conv_cfg_list[i][2] |
|
) |
|
|
|
return input_lengths.to(torch.long) |
|
|
|
def forward( |
|
self, |
|
source, |
|
padding_mask=None, |
|
mask=True, |
|
features_only=False, |
|
layer=None, |
|
mask_indices=None, |
|
mask_channel_indices=None, |
|
padding_count=None, |
|
): |
|
features = source |
|
|
|
if self.feature_grad_mult > 0: |
|
features = self.feature_extractor(features) |
|
if self.feature_grad_mult != 1.0: |
|
features = GradMultiply.apply(features, self.feature_grad_mult) |
|
else: |
|
with torch.no_grad(): |
|
features = self.feature_extractor(features) |
|
|
|
features = features.transpose(1, 2) |
|
|
|
features = self.layer_norm(features) |
|
|
|
orig_padding_mask = padding_mask |
|
|
|
if padding_mask is not None and padding_mask.any(): |
|
input_lengths = (1 - padding_mask.long()).sum(-1) |
|
|
|
output_lengths = self._get_feat_extract_output_lengths(input_lengths) |
|
|
|
padding_mask = torch.zeros( |
|
features.shape[:2], dtype=features.dtype, device=features.device |
|
) |
|
|
|
|
|
|
|
padding_mask[ |
|
( |
|
torch.arange(padding_mask.shape[0], device=padding_mask.device), |
|
output_lengths - 1, |
|
) |
|
] = 1 |
|
padding_mask = (1 - padding_mask.flip([-1]).cumsum(-1).flip([-1])).bool() |
|
else: |
|
padding_mask = None |
|
|
|
if self.post_extract_proj is not None: |
|
features = self.post_extract_proj(features) |
|
|
|
pre_encoder_features = None |
|
if self.cfg.ema_transformer_only: |
|
pre_encoder_features = features.clone() |
|
|
|
features = self.dropout_input(features) |
|
|
|
if mask: |
|
x, mask_indices = self.apply_mask( |
|
features, |
|
padding_mask, |
|
mask_indices=mask_indices, |
|
mask_channel_indices=mask_channel_indices, |
|
) |
|
else: |
|
x = features |
|
mask_indices = None |
|
|
|
x, layer_results = self.encoder( |
|
x, |
|
padding_mask=padding_mask, |
|
layer=layer, |
|
) |
|
|
|
if features_only: |
|
return { |
|
"x": x, |
|
"padding_mask": padding_mask, |
|
"layer_results": layer_results, |
|
} |
|
|
|
result = { |
|
"losses": {}, |
|
} |
|
|
|
with torch.no_grad(): |
|
self.ema.model.eval() |
|
|
|
if self.cfg.ema_transformer_only: |
|
y, layer_results = self.ema.model.extract_features( |
|
pre_encoder_features, |
|
padding_mask=padding_mask, |
|
min_layer=self.cfg.encoder_layers - self.average_top_k_layers, |
|
) |
|
y = { |
|
"x": y, |
|
"padding_mask": padding_mask, |
|
"layer_results": layer_results, |
|
} |
|
else: |
|
y = self.ema.model.extract_features( |
|
source=source, |
|
padding_mask=orig_padding_mask, |
|
mask=False, |
|
) |
|
|
|
target_layer_results = [l[2] for l in y["layer_results"]] |
|
|
|
permuted = False |
|
if self.cfg.instance_norm_target_layer or self.cfg.batch_norm_target_layer: |
|
target_layer_results = [ |
|
tl.permute(1, 2, 0) for tl in target_layer_results |
|
] |
|
permuted = True |
|
|
|
if self.cfg.batch_norm_target_layer: |
|
target_layer_results = [ |
|
F.batch_norm( |
|
tl.float(), running_mean=None, running_var=None, training=True |
|
) |
|
for tl in target_layer_results |
|
] |
|
|
|
if self.cfg.instance_norm_target_layer: |
|
target_layer_results = [ |
|
F.instance_norm(tl.float()) for tl in target_layer_results |
|
] |
|
|
|
if permuted: |
|
target_layer_results = [ |
|
tl.transpose(1, 2) for tl in target_layer_results |
|
] |
|
|
|
if self.cfg.group_norm_target_layer: |
|
target_layer_results = [ |
|
F.layer_norm(tl.float(), tl.shape[-2:]) |
|
for tl in target_layer_results |
|
] |
|
|
|
if self.cfg.layer_norm_target_layer: |
|
target_layer_results = [ |
|
F.layer_norm(tl.float(), tl.shape[-1:]) |
|
for tl in target_layer_results |
|
] |
|
|
|
y = sum(target_layer_results) / len(target_layer_results) |
|
|
|
if self.cfg.layer_norm_targets: |
|
y = F.layer_norm(y.float(), y.shape[-1:]) |
|
|
|
if self.cfg.instance_norm_targets: |
|
y = F.instance_norm(y.float().transpose(1, 2)).transpose(1, 2) |
|
|
|
if not permuted: |
|
y = y.transpose(0, 1) |
|
|
|
y = y[mask_indices] |
|
|
|
x = x[mask_indices] |
|
x = self.final_proj(x) |
|
|
|
sz = x.size(-1) |
|
|
|
if self.loss_beta == 0: |
|
loss = F.mse_loss(x.float(), y.float(), reduction="none").sum(dim=-1) |
|
else: |
|
loss = F.smooth_l1_loss( |
|
x.float(), y.float(), reduction="none", beta=self.loss_beta |
|
).sum(dim=-1) |
|
|
|
if self.loss_scale is not None: |
|
scale = self.loss_scale |
|
else: |
|
scale = 1 / math.sqrt(sz) |
|
|
|
result["losses"]["regression"] = loss.sum() * scale |
|
|
|
if "sample_size" not in result: |
|
result["sample_size"] = loss.numel() |
|
|
|
with torch.no_grad(): |
|
result["target_var"] = self.compute_var(y) |
|
result["pred_var"] = self.compute_var(x.float()) |
|
|
|
if self.num_updates > 5000 and result["target_var"] < self.cfg.min_target_var: |
|
logger.error( |
|
f"target var is {result['target_var'].item()} < {self.cfg.min_target_var}, exiting" |
|
) |
|
raise Exception( |
|
f"target var is {result['target_var'].item()} < {self.cfg.min_target_var}, exiting" |
|
) |
|
if self.num_updates > 5000 and result["pred_var"] < self.cfg.min_pred_var: |
|
logger.error( |
|
f"pred var is {result['pred_var'].item()} < {self.cfg.min_pred_var}, exiting" |
|
) |
|
raise Exception( |
|
f"pred var is {result['pred_var'].item()} < {self.cfg.min_pred_var}, exiting" |
|
) |
|
|
|
if self.ema is not None: |
|
result["ema_decay"] = self.ema.get_decay() * 1000 |
|
|
|
return result |
|
|
|
@staticmethod |
|
def compute_var(y): |
|
y = y.view(-1, y.size(-1)) |
|
if dist.is_initialized(): |
|
zc = torch.tensor(y.size(0)).cuda() |
|
zs = y.sum(dim=0) |
|
zss = (y**2).sum(dim=0) |
|
|
|
dist.all_reduce(zc) |
|
dist.all_reduce(zs) |
|
dist.all_reduce(zss) |
|
|
|
var = zss / (zc - 1) - (zs**2) / (zc * (zc - 1)) |
|
return torch.sqrt(var + 1e-6).mean() |
|
else: |
|
return torch.sqrt(y.var(dim=0) + 1e-6).mean() |
|
|
|
def extract_features(self, source, padding_mask, mask=False, layer=None): |
|
res = self.forward( |
|
source, |
|
padding_mask, |
|
mask=mask, |
|
features_only=True, |
|
layer=layer, |
|
) |
|
return res |
|
|
|
def remove_pretraining_modules(self, last_layer=None): |
|
self.final_proj = None |
|
self.ema = None |
|
if last_layer is not None: |
|
self.encoder.layers = nn.ModuleList( |
|
l for i, l in enumerate(self.encoder.layers) if i <= last_layer |
|
) |
|
|