Spaces:
Running
Running
import os | |
import sys | |
import math | |
import torch | |
import inspect | |
import functools | |
sys.path.append(os.getcwd()) | |
from main.library.speaker_diarization.speechbrain import MAIN_PROC_ONLY, is_distributed_initialized, main_process_only | |
# Substring renames applied to state-dict keys when a checkpoint is loaded
# (see hook_on_loading_state_dict_checkpoint): each fragment on the left is
# replaced by the fragment on the right.
KEYS_MAPPING = {
    ".mutihead_attn": ".multihead_attn",
    ".convs_intermedite": ".convs_intermediate",
}
def map_old_state_dict_weights(state_dict, mapping):
    """Rename keys of ``state_dict`` in place using substring replacements.

    Arguments
    ---------
    state_dict : dict
        Mapping whose keys are strings; it is mutated and also returned.
    mapping : dict
        ``{old_fragment: new_fragment}``. Every key containing
        ``old_fragment`` is re-inserted under the key obtained with
        ``str.replace(old_fragment, new_fragment)``.

    Returns
    -------
    dict
        The same ``state_dict`` object that was passed in.
    """
    for stale, fresh in mapping.items():
        # Take a snapshot of the keys first: the dict is modified while
        # the affected entries are moved to their new names.
        affected = [key for key in list(state_dict.keys()) if stale in key]
        for key in affected:
            value = state_dict.pop(key)
            state_dict[key.replace(stale, fresh)] = value
    return state_dict
def hook_on_loading_state_dict_checkpoint(state_dict):
    """Apply the ``KEYS_MAPPING`` key renames to a freshly loaded state dict.

    The renaming is delegated to ``map_old_state_dict_weights``; whatever
    that function returns is returned unchanged.
    """
    renamed = map_old_state_dict_weights(state_dict, mapping=KEYS_MAPPING)
    return renamed
def torch_patched_state_dict_load(path, device="cpu"):
    """Load a state dict with ``torch.load`` and patch its keys.

    Arguments
    ---------
    path : str or file-like
        Passed straight to ``torch.load``.
    device : str
        Forwarded as ``map_location``.

    Returns
    -------
    The loaded object after ``hook_on_loading_state_dict_checkpoint`` has
    processed it.
    """
    raw_state = torch.load(path, map_location=device)
    return hook_on_loading_state_dict_checkpoint(raw_state)
def torch_save(obj, path):
    """Write ``obj.state_dict()`` to ``path`` with ``torch.save``."""
    torch.save(obj.state_dict(), path)
def torch_recovery(obj, path, end_of_epoch):
    """Restore ``obj`` from the state dict stored at ``path``.

    Arguments
    ---------
    obj : object
        Anything exposing ``load_state_dict``.
    path : str
        Checkpoint file, read on CPU through
        ``torch_patched_state_dict_load``.
    end_of_epoch : bool
        Accepted for signature compatibility and ignored.
    """
    del end_of_epoch  # part of the loader signature, not used here
    restored_state = torch_patched_state_dict_load(path, "cpu")
    try:
        obj.load_state_dict(restored_state, strict=True)
    except TypeError:
        # Fall back to the one-argument call when the strict call raises
        # TypeError (e.g. the object does not accept ``strict``).
        obj.load_state_dict(restored_state)
def torch_parameter_transfer(obj, path):
    """Load the parameters stored at ``path`` into ``obj`` non-strictly.

    The state dict is read on CPU via ``torch_patched_state_dict_load`` and
    applied with ``strict=False``. Nothing is returned.
    """
    transferred_state = torch_patched_state_dict_load(path, "cpu")
    report = obj.load_state_dict(transferred_state, strict=False)
    # The missing / unexpected keys are walked but nothing is done with
    # them; both loops are intentionally side-effect free.
    for _ignored in report.missing_keys:
        continue
    for _ignored in report.unexpected_keys:
        continue
# Placeholder stored in place of a non-None "_scale_fn_ref" entry when a
# scheduler state dict is written by _cycliclrsaver.
WEAKREF_MARKER = "WEAKREF"


def _cycliclrsaver(obj, path):
    """Save ``obj.state_dict()`` to ``path``, masking ``_scale_fn_ref``.

    If the state dict holds a non-None value under ``"_scale_fn_ref"`` it is
    replaced by ``WEAKREF_MARKER`` before the dict is handed to
    ``torch.save``.
    """
    snapshot = obj.state_dict()
    has_scale_fn_ref = snapshot.get("_scale_fn_ref") is not None
    if has_scale_fn_ref:
        snapshot["_scale_fn_ref"] = WEAKREF_MARKER
    torch.save(snapshot, path)
def _cycliclrloader(obj, path, end_of_epoch):
    """Load the state dict stored at ``path`` into ``obj``.

    ``end_of_epoch`` is accepted for signature compatibility and ignored.
    The strict call is tried first; on ``TypeError`` the file is read again
    and ``load_state_dict`` is called without ``strict``.
    """
    del end_of_epoch  # part of the loader signature, not used here

    def _read_state():
        # Each attempt gets its own freshly loaded dict.
        return torch.load(path, map_location="cpu")

    try:
        obj.load_state_dict(_read_state(), strict=True)
    except TypeError:
        obj.load_state_dict(_read_state())
# Class -> function tables. ``register_checkpoint_hooks`` adds entries for
# decorated classes; the entries below are the built-in defaults.
# Key order matches the original insertion order.
DEFAULT_LOAD_HOOKS = {
    torch.nn.Module: torch_recovery,
    torch.optim.Optimizer: torch_recovery,
    torch.optim.lr_scheduler.ReduceLROnPlateau: torch_recovery,
    torch.cuda.amp.grad_scaler.GradScaler: torch_recovery,
    torch.optim.lr_scheduler.LRScheduler: torch_recovery,
    # CyclicLR gets a dedicated loader.
    torch.optim.lr_scheduler.CyclicLR: _cycliclrloader,
}
DEFAULT_SAVE_HOOKS = {
    torch.nn.Module: torch_save,
    torch.optim.Optimizer: torch_save,
    torch.optim.lr_scheduler.ReduceLROnPlateau: torch_save,
    torch.cuda.amp.grad_scaler.GradScaler: torch_save,
    torch.optim.lr_scheduler.LRScheduler: torch_save,
    # CyclicLR gets a dedicated saver that masks "_scale_fn_ref".
    torch.optim.lr_scheduler.CyclicLR: _cycliclrsaver,
}
DEFAULT_TRANSFER_HOOKS = {torch.nn.Module: torch_parameter_transfer}
def register_checkpoint_hooks(cls, save_on_main_only=True):
    """Class decorator that registers marked methods as checkpoint hooks.

    Every attribute defined directly on ``cls`` is inspected:

    * ``_speechbrain_saver``    -> ``DEFAULT_SAVE_HOOKS[cls]``
      (wrapped with ``main_process_only`` when ``save_on_main_only``)
    * ``_speechbrain_loader``   -> ``DEFAULT_LOAD_HOOKS[cls]``
    * ``_speechbrain_transfer`` -> ``DEFAULT_TRANSFER_HOOKS[cls]``

    Returns
    -------
    type
        ``cls`` itself, so the function can be used as a decorator.
    """
    # The hook tables are only mutated, never rebound, so no ``global``
    # declaration is required.
    for member in cls.__dict__.values():
        if hasattr(member, "_speechbrain_saver"):
            saver = main_process_only(member) if save_on_main_only else member
            DEFAULT_SAVE_HOOKS[cls] = saver
        if hasattr(member, "_speechbrain_loader"):
            DEFAULT_LOAD_HOOKS[cls] = member
        if hasattr(member, "_speechbrain_transfer"):
            DEFAULT_TRANSFER_HOOKS[cls] = member
    return cls
def mark_as_saver(method):
    """Mark ``method`` as a checkpoint saver.

    The callable must be bindable as ``method(instance, path)``; this is
    verified with ``inspect.signature`` without calling the method.

    Arguments
    ---------
    method : callable
        Function to mark.

    Returns
    -------
    callable
        The same object, with ``_speechbrain_saver = True`` set on it.

    Raises
    ------
    TypeError
        If ``method`` cannot be called with two positional arguments.
    """
    signature = inspect.signature(method)
    try:
        # Dry-run the binding with placeholder arguments.
        signature.bind(object(), "testpath")
    except TypeError as err:
        raise TypeError(
            "Checkpoint saver must be callable as (instance, path)"
        ) from err
    method._speechbrain_saver = True
    return method
def mark_as_transfer(method):
    """Mark ``method`` as a parameter-transfer hook.

    The callable must be bindable as ``method(instance, path)``; this is
    verified with ``inspect.signature`` without calling the method.

    Arguments
    ---------
    method : callable
        Function to mark.

    Returns
    -------
    callable
        The same object, with ``_speechbrain_transfer = True`` set on it.

    Raises
    ------
    TypeError
        If ``method`` cannot be called with two positional arguments.
    """
    signature = inspect.signature(method)
    try:
        # Dry-run the binding with placeholder arguments.
        signature.bind(object(), "testpath")
    except TypeError as err:
        raise TypeError(
            "Transfer hook must be callable as (instance, path)"
        ) from err
    method._speechbrain_transfer = True
    return method
def mark_as_loader(method):
    """Mark ``method`` as a checkpoint loader.

    The callable must be bindable as
    ``method(instance, path, end_of_epoch)``; this is verified with
    ``inspect.signature`` without calling the method.

    Arguments
    ---------
    method : callable
        Function to mark.

    Returns
    -------
    callable
        The same object, with ``_speechbrain_loader = True`` set on it.

    Raises
    ------
    TypeError
        If ``method`` cannot be called with three positional arguments.
    """
    signature = inspect.signature(method)
    try:
        # Dry-run the binding with placeholder arguments.
        signature.bind(object(), "testpath", True)
    except TypeError as err:
        raise TypeError(
            "Checkpoint loader must be callable as (instance, path, end_of_epoch)"
        ) from err
    method._speechbrain_loader = True
    return method
def ddp_all_reduce(communication_object, reduce_op):
    """All-reduce ``communication_object`` across processes when applicable.

    The reduction is skipped when ``MAIN_PROC_ONLY >= 1`` or when
    ``is_distributed_initialized()`` is false; otherwise
    ``torch.distributed.all_reduce`` is called with ``op=reduce_op``.

    Returns
    -------
    The ``communication_object`` that was passed in.
    """
    skip_reduction = MAIN_PROC_ONLY >= 1 or not is_distributed_initialized()
    if not skip_reduction:
        torch.distributed.all_reduce(communication_object, op=reduce_op)
    return communication_object
def fwd_default_precision(fwd=None, cast_inputs=torch.float32):
    """Decorate a forward function with ``torch.cuda.amp.custom_fwd``.

    Can be used bare (``@fwd_default_precision``) or with arguments
    (``@fwd_default_precision(cast_inputs=...)``).

    Arguments
    ---------
    fwd : callable or None
        Function to wrap. When ``None`` a partial of this decorator with
        ``cast_inputs`` bound is returned.
    cast_inputs : torch.dtype
        Forwarded to ``torch.cuda.amp.custom_fwd``.

    Returns
    -------
    callable
        A wrapper that accepts an extra keyword-only flag
        ``force_allow_autocast``: when true the undecorated ``fwd`` is
        called, otherwise the ``custom_fwd``-wrapped version is called.
    """
    if fwd is None:
        return functools.partial(fwd_default_precision, cast_inputs=cast_inputs)

    wrapped_fwd = torch.cuda.amp.custom_fwd(fwd, cast_inputs=cast_inputs)

    # Preserve the name, docstring and other metadata of the wrapped function.
    @functools.wraps(fwd)
    def wrapper(*args, force_allow_autocast=False, **kwargs):
        if force_allow_autocast:
            return fwd(*args, **kwargs)
        return wrapped_fwd(*args, **kwargs)

    return wrapper
def spectral_magnitude(stft, power=1, log=False, eps=1e-14):
    """Reduce the last axis of ``stft`` to a (power) spectrum.

    Arguments
    ---------
    stft : torch.Tensor
        Tensor whose last axis is squared and summed (presumably the
        real/imaginary pair of an STFT; confirm against the caller).
    power : float
        Exponent applied to the summed squares. For ``power < 1``, ``eps``
        is added before exponentiation.
    log : bool
        When true, ``torch.log(result + eps)`` is returned.
    eps : float
        Small additive constant.

    Returns
    -------
    torch.Tensor
        Tensor with the last axis of ``stft`` removed.
    """
    energy = stft.pow(2).sum(-1)
    if power < 1:
        energy = energy + eps
    energy = energy.pow(power)
    return torch.log(energy + eps) if log else energy
class Filterbank(torch.nn.Module):
    """Apply a bank of mel-spaced filters to a spectrogram.

    Filter centres are spaced evenly on the mel scale between ``f_min`` and
    ``f_max``. The filter shape is ``"triangular"``, ``"rectangular"`` or
    (for any other value) gaussian. With ``freeze=False`` the centres and
    bandwidths become ``torch.nn.Parameter`` objects.

    Arguments
    ---------
    n_mels : int
        Number of filters.
    log_mel : bool
        When true the output is converted with ``_amplitude_to_DB``.
    filter_shape : str
        See above.
    f_min, f_max : float
        Frequency range in Hz.
    n_fft : int
        FFT size; the filters span ``n_fft // 2 + 1`` bins.
    sample_rate : int
        Sampling rate in Hz.
    power_spectrogram : float
        ``2`` selects a dB multiplier of 10, anything else 20.
    amin, ref_value, top_db : float
        Clamp floor, reference value and dynamic range used by
        ``_amplitude_to_DB``.
    param_change_factor, param_rand_factor : float
        Scaling of learnable parameters / random perturbation of frozen
        ones during training.
    freeze : bool
        When false, centres and bandwidths are learnable.
    """

    def __init__(self, n_mels=40, log_mel=True, filter_shape="triangular", f_min=0, f_max=8000, n_fft=400, sample_rate=16000, power_spectrogram=2, amin=1e-10, ref_value=1.0, top_db=80.0, param_change_factor=1.0, param_rand_factor=0.0, freeze=True):
        super().__init__()
        self.n_mels = n_mels
        self.log_mel = log_mel
        self.filter_shape = filter_shape
        self.f_min = f_min
        self.f_max = f_max
        self.n_fft = n_fft
        self.sample_rate = sample_rate
        self.power_spectrogram = power_spectrogram
        self.amin = amin
        self.ref_value = ref_value
        self.top_db = top_db
        self.freeze = freeze
        self.n_stft = self.n_fft // 2 + 1
        self.db_multiplier = math.log10(max(self.amin, self.ref_value))
        self.device_inp = torch.device("cpu")
        self.param_change_factor = param_change_factor
        self.param_rand_factor = param_rand_factor
        self.multiplier = 10 if self.power_spectrogram == 2 else 20

        # n_mels + 2 points evenly spaced in mel, converted back to Hz; the
        # two outer points only delimit the first and last filter.
        mel_points = torch.linspace(self._to_mel(self.f_min), self._to_mel(self.f_max), self.n_mels + 2)
        hz = self._to_hz(mel_points)
        band = hz[1:] - hz[:-1]
        self.band = band[:-1]
        self.f_central = hz[1:-1]

        if not self.freeze:
            scale = self.sample_rate * self.param_change_factor
            self.f_central = torch.nn.Parameter(self.f_central / scale)
            self.band = torch.nn.Parameter(self.band / scale)

        # One row of FFT-bin frequencies per filter.
        all_freqs = torch.linspace(0, self.sample_rate // 2, self.n_stft)
        self.all_freqs_mat = all_freqs.repeat(self.f_central.shape[0], 1)

    def forward(self, spectrogram):
        """Project ``spectrogram`` onto the filterbank.

        3-D inputs are multiplied by the filter matrix directly. 4-D inputs
        have their last axis folded into the first one before the
        multiplication and restored afterwards.
        """
        n_bins = self.all_freqs_mat.shape[1]
        f_central_mat = self.f_central.repeat(n_bins, 1).transpose(0, 1)
        band_mat = self.band.repeat(n_bins, 1).transpose(0, 1)

        if not self.freeze:
            rescale = self.sample_rate * self.param_change_factor * self.param_change_factor
            f_central_mat = f_central_mat * rescale
            band_mat = band_mat * rescale
        elif self.param_rand_factor != 0 and self.training:
            # Random multiplicative jitter in [1 - r, 1 + r) for centres
            # and bandwidths.
            rand_change = (1.0 + torch.rand(2) * 2 * self.param_rand_factor - self.param_rand_factor)
            f_central_mat = f_central_mat * rand_change[0]
            band_mat = band_mat * rand_change[1]

        fbank_matrix = self._create_fbank_matrix(f_central_mat, band_mat).to(spectrogram.device)

        sp_shape = spectrogram.shape
        is_4d = len(sp_shape) == 4
        if is_4d:
            spectrogram = spectrogram.permute(0, 3, 1, 2).reshape(sp_shape[0] * sp_shape[3], sp_shape[1], sp_shape[2])

        fbanks = torch.matmul(spectrogram, fbank_matrix)
        if self.log_mel:
            fbanks = self._amplitude_to_DB(fbanks)

        if is_4d:
            fb_shape = fbanks.shape
            fbanks = fbanks.reshape(sp_shape[0], sp_shape[3], fb_shape[1], fb_shape[2]).permute(0, 2, 3, 1)
        return fbanks

    # These two helpers take no ``self`` and are invoked as
    # ``self._to_mel(...)`` / ``self._to_hz(...)``, so they must be static.
    @staticmethod
    def _to_mel(hz):
        """Convert a frequency in Hz (Python number) to mel."""
        return 2595 * math.log10(1 + hz / 700)

    @staticmethod
    def _to_hz(mel):
        """Convert mel values back to Hz."""
        return 700 * (10 ** (mel / 2595) - 1)

    def _triangular_filters(self, all_freqs, f_central, band):
        """Triangular responses, clipped at zero and transposed."""
        slope = (all_freqs - f_central) / band
        left_side = slope + 1.0
        right_side = -slope + 1.0
        floor = torch.zeros(1, device=self.device_inp)
        return torch.max(floor, torch.min(left_side, right_side)).transpose(0, 1)

    def _rectangular_filters(self, all_freqs, f_central, band):
        """Responses equal to 1 within ``f_central +/- band``, else 0."""
        left_side = all_freqs.ge(f_central - band)
        right_side = all_freqs.le(f_central + band)
        return (left_side * right_side).float().transpose(0, 1)

    def _gaussian_filters(self, all_freqs, f_central, band, smooth_factor=torch.tensor(2)):
        """Gaussian responses with standard deviation ``band / smooth_factor``."""
        return torch.exp(-0.5 * ((all_freqs - f_central) / (band / smooth_factor)) ** 2).transpose(0, 1)

    def _create_fbank_matrix(self, f_central_mat, band_mat):
        """Build the filter matrix for the configured ``filter_shape``."""
        if self.filter_shape == "triangular":
            return self._triangular_filters(self.all_freqs_mat, f_central_mat, band_mat)
        if self.filter_shape == "rectangular":
            return self._rectangular_filters(self.all_freqs_mat, f_central_mat, band_mat)
        return self._gaussian_filters(self.all_freqs_mat, f_central_mat, band_mat)

    def _amplitude_to_DB(self, x):
        """Convert to decibels and limit the range to ``top_db`` below the max.

        The maximum is taken over the last two axes and reshaped to
        ``(x.shape[0], 1, 1)``, so ``x`` is expected to be 3-D here.
        """
        x_db = self.multiplier * torch.log10(torch.clamp(x, min=self.amin))
        x_db -= self.multiplier * self.db_multiplier
        floor = (x_db.amax(dim=(-2, -1)) - self.top_db).view(x_db.shape[0], 1, 1)
        return torch.max(x_db, floor)
class ContextWindow(torch.nn.Module):
    """Stack each frame with its neighbouring frames.

    For every feature channel the output holds ``left_frames`` past frames,
    the current frame and ``right_frames`` future frames, obtained with a
    grouped 1-D convolution against a shifted identity kernel. Frames beyond
    the edges are zero (convolution padding).
    """

    def __init__(self, left_frames=0, right_frames=0):
        super().__init__()
        self.left_frames = left_frames
        self.right_frames = right_frames
        widest_side = max(self.left_frames, self.right_frames)
        self.context_len = self.left_frames + self.right_frames + 1
        self.kernel_len = 2 * widest_side + 1
        identity = torch.eye(self.context_len, self.kernel_len)
        if self.right_frames > self.left_frames:
            # Shift the identity so the window is anchored on the right.
            identity = torch.roll(identity, self.right_frames - self.left_frames, 1)
        self.kernel = identity
        # The kernel is expanded to the channel count on the first forward.
        self.first_call = True

    def forward(self, x):
        """Return ``x`` with the context frames concatenated per channel."""
        x = x.transpose(1, 2)
        if self.first_call:
            self.first_call = False
            n_channels = x.shape[1]
            tiled = self.kernel.repeat(n_channels, 1, 1)
            self.kernel = tiled.view(n_channels * self.context_len, self.kernel_len).unsqueeze(1)

        in_shape = x.shape
        is_4d = len(in_shape) == 4
        if is_4d:
            x = x.reshape(in_shape[0] * in_shape[2], in_shape[1], in_shape[3])

        stacked = torch.nn.functional.conv1d(
            x,
            self.kernel.to(x.device),
            groups=x.shape[1],
            padding=max(self.left_frames, self.right_frames),
        )

        if is_4d:
            stacked = stacked.reshape(in_shape[0], stacked.shape[1], in_shape[2], stacked.shape[-1])
        return stacked.transpose(1, 2)
class FilterProperties:
    """Window size, stride, dilation and causality of a filter.

    Arguments
    ---------
    window_size : int
        Size of the filter window.
    stride : int
        Stride of the filter.
    dilation : int
        Spacing between window elements.
    causal : bool
        Whether the filter only looks at past inputs.
    """

    def __init__(self, window_size=0, stride=1, dilation=1, causal=False):
        self.window_size = window_size
        self.stride = stride
        self.dilation = dilation
        self.causal = causal

    def __post_init__(self):
        # NOTE(review): ``__init__`` above does not call this method, so the
        # checks only run if a caller invokes it explicitly. Wiring it into
        # ``__init__`` would reject the default ``window_size=0``; confirm
        # the intended behaviour before doing so.
        assert self.window_size > 0
        assert self.stride > 0
        assert (self.dilation > 0)

    @staticmethod
    def pointwise_filter():
        """Return the properties of a size-1, stride-1 filter."""
        return FilterProperties(window_size=1, stride=1)

    def get_effective_size(self):
        """Return the span of the window once dilation is accounted for."""
        return 1 + ((self.window_size - 1) * self.dilation)

    def get_convolution_padding(self):
        """Return the padding matching this filter.

        Causal filters get ``effective_size - 1``; non-causal filters get
        half of that (integer division).

        Raises
        ------
        ValueError
            If ``window_size`` is even.
        """
        if self.window_size % 2 == 0:
            raise ValueError("Cannot determine padding for an even window size")
        span = self.get_effective_size() - 1
        return span if self.causal else span // 2

    def get_noncausal_equivalent(self):
        """Return ``self`` if non-causal, else a non-causal filter with a
        window of ``(window_size - 1) * 2 + 1``."""
        if not self.causal:
            return self
        return FilterProperties(
            window_size=(self.window_size - 1) * 2 + 1,
            stride=self.stride,
            dilation=self.dilation,
            causal=False,
        )

    def with_on_top(self, other, allow_approximate=True):
        """Return the properties of ``other`` applied on top of ``self``.

        Arguments
        ---------
        other : FilterProperties
            Filter applied to the output of this one.
        allow_approximate : bool
            When true, an even ``other.window_size`` is rounded up by one
            and a causal / non-causal mix is resolved through the
            non-causal equivalents of both filters.

        Raises
        ------
        ValueError
            If an approximation is needed but ``allow_approximate`` is false.
        """
        self_size = self.window_size

        if other.window_size % 2 == 0:
            if not allow_approximate:
                raise ValueError("The filter on top has an even window size and approximation is disabled")
            other_size = other.window_size + 1
        else:
            other_size = other.window_size

        # Exactly one of the two filters is causal.
        if (self.causal or other.causal) and not (self.causal and other.causal):
            if not allow_approximate:
                raise ValueError("Cannot combine causal and non-causal filters when approximation is disabled")
            return self.get_noncausal_equivalent().with_on_top(other.get_noncausal_equivalent())

        return FilterProperties(
            self_size + (self.stride * (other_size - 1)),
            self.stride * other.stride,
            self.dilation * other.dilation,
            self.causal,
        )
class STFT(torch.nn.Module):
    """Short-time Fourier transform built on ``torch.stft``.

    ``win_length`` and ``hop_length`` are given in milliseconds and are
    converted to samples with ``sample_rate`` in the constructor. The
    window tensor is created once with ``window_fn``.
    """

    def __init__(self, sample_rate, win_length=25, hop_length=10, n_fft=400, window_fn=torch.hamming_window, normalized_stft=False, center=True, pad_mode="constant", onesided=True):
        super().__init__()
        self.sample_rate = sample_rate
        self.n_fft = n_fft
        self.normalized_stft = normalized_stft
        self.center = center
        self.pad_mode = pad_mode
        self.onesided = onesided
        # Milliseconds -> samples.
        samples_per_ms = self.sample_rate / 1000.0
        self.win_length = int(round(samples_per_ms * win_length))
        self.hop_length = int(round(samples_per_ms * hop_length))
        self.window = window_fn(self.win_length)

    def forward(self, x):
        """Return the STFT of ``x`` with real and imaginary parts on a
        trailing axis of size 2.

        2-D input yields a 4-D result with the frame axis before the
        frequency axis. 3-D input has its last axis folded into the batch
        for the transform and moved to the end of the 5-D result.
        """
        in_shape = x.shape
        multi_channel = len(in_shape) == 3
        if multi_channel:
            x = x.transpose(1, 2).reshape(in_shape[0] * in_shape[2], in_shape[1])

        spectrum = torch.stft(
            x,
            self.n_fft,
            hop_length=self.hop_length,
            win_length=self.win_length,
            window=self.window.to(x.device),
            center=self.center,
            pad_mode=self.pad_mode,
            normalized=self.normalized_stft,
            onesided=self.onesided,
            return_complex=True,
        )
        stft = torch.view_as_real(spectrum)

        if multi_channel:
            stft = stft.reshape(in_shape[0], in_shape[2], stft.shape[1], stft.shape[2], stft.shape[3])
            return stft.permute(0, 3, 2, 4, 1)
        return stft.transpose(2, 1)

    def get_filter_properties(self):
        """Return a ``FilterProperties`` with this transform's window and hop.

        Raises
        ------
        ValueError
            If the transform was built with ``center=False``.
        """
        if not self.center:
            raise ValueError
        return FilterProperties(window_size=self.win_length, stride=self.hop_length)
class Deltas(torch.nn.Module):
    """Delta (time-derivative) coefficients of a feature sequence.

    A regression kernel ``[-n, ..., n]`` with ``n = (window_length - 1) // 2``
    is applied per feature channel with a grouped 1-D convolution; the
    input is edge-replicated by ``n`` frames on both sides first.
    """

    def __init__(self, input_size, window_length=5):
        super().__init__()
        self.n = (window_length - 1) // 2
        # Normalisation term: 2 * sum(k ** 2 for k in 1..n).
        self.denom = self.n * (self.n + 1) * (2 * self.n + 1) / 3
        taps = torch.arange(-self.n, self.n + 1, dtype=torch.float32)
        self.register_buffer("kernel", taps.repeat(input_size, 1, 1))

    def forward(self, x):
        """Return the delta coefficients of ``x`` (same layout as ``x``)."""
        x = x.transpose(1, 2).transpose(2, -1)
        in_shape = x.shape
        is_4d = len(in_shape) == 4
        if is_4d:
            x = x.reshape(in_shape[0] * in_shape[2], in_shape[1], in_shape[3])

        padded = torch.nn.functional.pad(x, (self.n, self.n), mode="replicate")
        filtered = torch.nn.functional.conv1d(padded, self.kernel.to(padded.device), groups=padded.shape[1])
        delta_coeff = filtered / self.denom

        if is_4d:
            delta_coeff = delta_coeff.reshape(in_shape[0], in_shape[1], in_shape[2], in_shape[3])
        return delta_coeff.transpose(1, -1).transpose(2, -1)
class Fbank(torch.nn.Module):
    """Filterbank feature pipeline: STFT -> magnitude -> Filterbank.

    Optionally appends first and second order deltas (``deltas=True``) and
    a context window (``context=True``). ``f_max`` defaults to half the
    sample rate, and the filterbank is learnable when ``requires_grad`` is
    true.
    """

    def __init__(self, deltas=False, context=False, requires_grad=False, sample_rate=16000, f_min=0, f_max=None, n_fft=400, n_mels=40, filter_shape="triangular", param_change_factor=1.0, param_rand_factor=0.0, left_frames=5, right_frames=5, win_length=25, hop_length=10):
        super().__init__()
        self.deltas = deltas
        self.context = context
        self.requires_grad = requires_grad

        if f_max is None:
            f_max = sample_rate / 2

        self.compute_STFT = STFT(
            sample_rate=sample_rate,
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
        )
        self.compute_fbanks = Filterbank(
            sample_rate=sample_rate,
            n_fft=n_fft,
            n_mels=n_mels,
            f_min=f_min,
            f_max=f_max,
            freeze=not requires_grad,
            filter_shape=filter_shape,
            param_change_factor=param_change_factor,
            param_rand_factor=param_rand_factor,
        )
        self.compute_deltas = Deltas(input_size=n_mels)
        self.context_window = ContextWindow(left_frames=left_frames, right_frames=right_frames)

    def forward(self, wav):
        """Compute the features for ``wav``."""
        spectrum = self.compute_STFT(wav)
        features = self.compute_fbanks(spectral_magnitude(spectrum))

        if self.deltas:
            first_order = self.compute_deltas(features)
            second_order = self.compute_deltas(first_order)
            features = torch.cat([features, first_order, second_order], dim=2)

        if self.context:
            features = self.context_window(features)
        return features

    def get_filter_properties(self):
        """Return the filter properties reported by the STFT stage."""
        return self.compute_STFT.get_filter_properties()
class InputNormalization(torch.nn.Module):
    """Mean / standard-deviation normalisation of input features.

    ``norm_type`` selects where the statistics come from:

    * ``"sentence"`` - each item of the batch uses its own statistics;
    * ``"speaker"``  - running statistics kept per speaker id;
    * ``"batch"``    - average of the per-item statistics of the batch;
    * ``"global"``   - running average over the batches seen in training.

    Arguments
    ---------
    mean_norm, std_norm : bool
        Enable mean subtraction / division by the standard deviation.
    norm_type : str
        One of the modes listed above.
    avg_factor : float or None
        Fixed weight of the newest statistics in the running averages;
        ``None`` selects ``1 / count``.
    requires_grad : bool
        Stored on the instance; not otherwise used in this class.
    update_until_epoch : int
        In ``"global"`` mode the running statistics stop being updated once
        ``epoch`` reaches this value.

    NOTE(review): ``_save`` / ``_load`` carry no ``_speechbrain_saver`` /
    ``_speechbrain_loader`` markers in this file - confirm whether they are
    meant to be registered as checkpoint hooks.
    """

    def __init__(self, mean_norm=True, std_norm=True, norm_type="global", avg_factor=None, requires_grad=False, update_until_epoch=3):
        super().__init__()
        self.mean_norm = mean_norm
        self.std_norm = std_norm
        self.norm_type = norm_type
        self.avg_factor = avg_factor
        self.requires_grad = requires_grad
        self.glob_mean = torch.tensor([0])
        self.glob_std = torch.tensor([0])
        self.spk_dict_mean = {}
        self.spk_dict_std = {}
        self.spk_dict_count = {}
        self.weight = 1.0
        self.count = 0
        self.eps = 1e-10
        self.update_until_epoch = update_until_epoch

    def forward(self, x, lengths, spk_ids=torch.tensor([]), epoch=0):
        """Normalise ``x``.

        Arguments
        ---------
        x : torch.Tensor
            Batch of features; axis 0 indexes the items, axis 1 the frames.
        lengths : torch.Tensor
            Relative length of every item; ``round(length * x.shape[1])``
            frames are used to compute its statistics.
        spk_ids : torch.Tensor
            Speaker ids, read as ``int(spk_ids[i][0])`` in ``"speaker"`` mode.
        epoch : int or None
            Current epoch, used by the ``"global"`` mode.

        Returns
        -------
        torch.Tensor
            The normalised batch.
        """
        n_items = x.shape[0]
        current_means, current_stds = [], []

        if self.norm_type == "sentence" or self.norm_type == "speaker":
            out = torch.empty_like(x)

        for snt_id in range(n_items):
            # Ignore padded frames when computing the statistics.
            actual_size = torch.round(lengths[snt_id] * x.shape[1]).int()
            current_mean, current_std = self._compute_current_stats(x[snt_id, 0:actual_size, ...])
            current_means.append(current_mean)
            current_stds.append(current_std)

            if self.norm_type == "sentence":
                out[snt_id] = (x[snt_id] - current_mean.data) / current_std.data

            if self.norm_type == "speaker":
                spk_id = int(spk_ids[snt_id][0])
                if self.training:
                    if spk_id not in self.spk_dict_mean:
                        # First utterance of this speaker: initialise.
                        self.spk_dict_mean[spk_id] = current_mean
                        self.spk_dict_std[spk_id] = current_std
                        self.spk_dict_count[spk_id] = 1
                    else:
                        self.spk_dict_count[spk_id] = self.spk_dict_count[spk_id] + 1
                        if self.avg_factor is None:
                            self.weight = 1 / self.spk_dict_count[spk_id]
                        else:
                            self.weight = self.avg_factor
                        self.spk_dict_mean[spk_id] = (1 - self.weight) * self.spk_dict_mean[spk_id].to(current_mean) + self.weight * current_mean
                        self.spk_dict_std[spk_id] = (1 - self.weight) * self.spk_dict_std[spk_id].to(current_std) + self.weight * current_std
                        # Keep the stored statistics out of any autograd
                        # graph (the result of detach() must be assigned).
                        self.spk_dict_mean[spk_id] = self.spk_dict_mean[spk_id].detach()
                        self.spk_dict_std[spk_id] = self.spk_dict_std[spk_id].detach()
                    speaker_mean = self.spk_dict_mean[spk_id].data
                    speaker_std = self.spk_dict_std[spk_id].data
                elif spk_id in self.spk_dict_mean:
                    speaker_mean = self.spk_dict_mean[spk_id].data
                    speaker_std = self.spk_dict_std[spk_id].data
                else:
                    # Unknown speaker at evaluation time: use the
                    # statistics of the current item.
                    speaker_mean = current_mean.data
                    speaker_std = current_std.data
                out[snt_id] = (x[snt_id] - speaker_mean) / speaker_std

        if self.norm_type == "batch" or self.norm_type == "global":
            current_mean = ddp_all_reduce(torch.mean(torch.stack(current_means), dim=0), torch.distributed.ReduceOp.AVG)
            current_std = ddp_all_reduce(torch.mean(torch.stack(current_stds), dim=0), torch.distributed.ReduceOp.AVG)

            if self.norm_type == "batch":
                out = (x - current_mean.data) / (current_std.data)

            if self.norm_type == "global":
                if self.training:
                    if self.count == 0:
                        self.glob_mean = current_mean
                        self.glob_std = current_std
                    elif epoch is None or epoch < self.update_until_epoch:
                        if self.avg_factor is None:
                            self.weight = 1 / (self.count + 1)
                        else:
                            self.weight = self.avg_factor
                        self.glob_mean = (1 - self.weight) * self.glob_mean.to(current_mean) + self.weight * current_mean
                        self.glob_std = (1 - self.weight) * self.glob_std.to(current_std) + self.weight * current_std
                    # Assign the detached tensors; a bare detach() call
                    # would discard its result.
                    self.glob_mean = self.glob_mean.detach()
                    self.glob_std = self.glob_std.detach()
                    self.count = self.count + 1
                out = (x - self.glob_mean.data.to(x)) / (self.glob_std.data.to(x))

        return out

    def _compute_current_stats(self, x):
        """Return ``(mean, std)`` of ``x`` along axis 0.

        Disabled statistics are replaced by ``0.0`` / ``1.0``; the standard
        deviation is floored at ``self.eps``.
        """
        if self.std_norm:
            current_std = torch.std(x, dim=0).detach().data
        else:
            current_std = torch.tensor([1.0], device=x.device)

        if self.mean_norm:
            current_mean = torch.mean(x, dim=0).detach().data
        else:
            current_mean = torch.tensor([0.0], device=x.device)

        return current_mean, torch.max(current_std, self.eps * torch.ones_like(current_std))

    def _statistics_dict(self):
        """Return the running statistics as a plain dictionary."""
        return {
            "count": self.count,
            "glob_mean": self.glob_mean,
            "glob_std": self.glob_std,
            "spk_dict_mean": self.spk_dict_mean,
            "spk_dict_std": self.spk_dict_std,
            "spk_dict_count": self.spk_dict_count,
        }

    def _load_statistics_dict(self, state):
        """Restore the statistics produced by ``_statistics_dict``.

        Returns ``state`` unchanged.
        """
        self.count = state["count"]
        # Integer and tensor statistics are stored the same way.
        self.glob_mean = state["glob_mean"]
        self.glob_std = state["glob_std"]
        # Per-speaker tables are copied entry by entry.
        self.spk_dict_mean = {spk: state["spk_dict_mean"][spk] for spk in state["spk_dict_mean"]}
        self.spk_dict_std = {spk: state["spk_dict_std"][spk] for spk in state["spk_dict_std"]}
        self.spk_dict_count = state["spk_dict_count"]
        return state

    def to(self, device):
        """Move the module and its statistics tensors to ``device``."""
        module = super().to(device)
        module.glob_mean = module.glob_mean.to(device)
        module.glob_std = module.glob_std.to(device)
        for spk in module.spk_dict_mean:
            module.spk_dict_mean[spk] = module.spk_dict_mean[spk].to(device)
            module.spk_dict_std[spk] = module.spk_dict_std[spk].to(device)
        return module

    def _save(self, path):
        """Write the statistics dictionary to ``path`` with ``torch.save``."""
        torch.save(self._statistics_dict(), path)

    def _load(self, path, end_of_epoch=False):
        """Load statistics previously written by ``_save``.

        ``end_of_epoch`` is accepted for signature compatibility and ignored.
        """
        del end_of_epoch  # part of the loader signature, not used here
        stats = torch.load(path, map_location="cpu")
        self._load_statistics_dict(stats)