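# Checkpoint save/load hooks and speech feature extraction modules (STFT, mel
# filterbank, deltas, context window, input normalization), adapted from SpeechBrain.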
import os
import sys
import math
import torch
import inspect
import functools
sys.path.append(os.getcwd())
from main.library.speaker_diarization.speechbrain import MAIN_PROC_ONLY, is_distributed_initialized, main_process_only
KEYS_MAPPING = {".mutihead_attn": ".multihead_attn", ".convs_intermedite": ".convs_intermediate"}
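# Renames legacy state_dict keys (e.g. the ".mutihead_attn" typo) to their current
# names, so checkpoints written before the module renames keep loading.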
def map_old_state_dict_weights(state_dict, mapping):
for replacement_old, replacement_new in mapping.items():
for old_key in list(state_dict.keys()):
if replacement_old in old_key: state_dict[old_key.replace(replacement_old, replacement_new)] = state_dict.pop(old_key)
return state_dict
def hook_on_loading_state_dict_checkpoint(state_dict):
return map_old_state_dict_weights(state_dict, KEYS_MAPPING)
def torch_patched_state_dict_load(path, device="cpu"):
return hook_on_loading_state_dict_checkpoint(torch.load(path, map_location=device))
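# Default hooks: torch_save writes obj.state_dict() (on the main process only),
# torch_recovery restores a full training state, and torch_parameter_transfer copies
# matching parameters non-strictly, e.g. for pretrained-weight transfer.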
@main_process_only
def torch_save(obj, path):
state_dict = obj.state_dict()
torch.save(state_dict, path)
def torch_recovery(obj, path, end_of_epoch):
del end_of_epoch
state_dict = torch_patched_state_dict_load(path, "cpu")
    try:
        obj.load_state_dict(state_dict, strict=True)
    except TypeError:
        # some objects (e.g. optimizers) do not accept the strict keyword
        obj.load_state_dict(state_dict)
def torch_parameter_transfer(obj, path):
    # strict=False tolerates missing/unexpected keys when transferring parameters;
    # the incompatible keys are deliberately ignored here (upstream SpeechBrain logs them)
    obj.load_state_dict(torch_patched_state_dict_load(path, "cpu"), strict=False)
WEAKREF_MARKER = "WEAKREF"
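# CyclicLR can hold a weakref to its scale function in "_scale_fn_ref"; weakrefs
# cannot be pickled, so the saver swaps in a placeholder string before torch.save
# and the loader restores the rest of the state as usual.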
def _cycliclrsaver(obj, path):
state_dict = obj.state_dict()
if state_dict.get("_scale_fn_ref") is not None: state_dict["_scale_fn_ref"] = WEAKREF_MARKER
torch.save(state_dict, path)
def _cycliclrloader(obj, path, end_of_epoch):
del end_of_epoch
try:
obj.load_state_dict(torch.load(path, map_location="cpu"), strict=True)
except TypeError:
obj.load_state_dict(torch.load(path, map_location="cpu"))
DEFAULT_LOAD_HOOKS = {torch.nn.Module: torch_recovery, torch.optim.Optimizer: torch_recovery, torch.optim.lr_scheduler.ReduceLROnPlateau: torch_recovery, torch.cuda.amp.grad_scaler.GradScaler: torch_recovery}
DEFAULT_SAVE_HOOKS = {torch.nn.Module: torch_save, torch.optim.Optimizer: torch_save, torch.optim.lr_scheduler.ReduceLROnPlateau: torch_save, torch.cuda.amp.grad_scaler.GradScaler: torch_save}
# the LRScheduler base class only exists on torch >= 2.0
if hasattr(torch.optim.lr_scheduler, "LRScheduler"):
    DEFAULT_LOAD_HOOKS[torch.optim.lr_scheduler.LRScheduler] = torch_recovery
    DEFAULT_SAVE_HOOKS[torch.optim.lr_scheduler.LRScheduler] = torch_save
DEFAULT_TRANSFER_HOOKS = {torch.nn.Module: torch_parameter_transfer}
DEFAULT_SAVE_HOOKS[torch.optim.lr_scheduler.CyclicLR] = _cycliclrsaver
DEFAULT_LOAD_HOOKS[torch.optim.lr_scheduler.CyclicLR] = _cycliclrloader
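# Custom classes register hooks by decorating the class with register_checkpoint_hooks
# and marking methods with mark_as_saver / mark_as_loader / mark_as_transfer; the
# mark_as_* helpers validate the expected call signatures. A hypothetical example:
#   @register_checkpoint_hooks
#   class Counter:
#       @mark_as_saver
#       def _save(self, path): torch.save(self.value, path)
#       @mark_as_loader
#       def _load(self, path, end_of_epoch): self.value = torch.load(path)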
def register_checkpoint_hooks(cls, save_on_main_only=True):
global DEFAULT_LOAD_HOOKS, DEFAULT_SAVE_HOOKS, DEFAULT_TRANSFER_HOOKS
for name, method in cls.__dict__.items():
if hasattr(method, "_speechbrain_saver"): DEFAULT_SAVE_HOOKS[cls] = main_process_only(method) if save_on_main_only else method
if hasattr(method, "_speechbrain_loader"): DEFAULT_LOAD_HOOKS[cls] = method
if hasattr(method, "_speechbrain_transfer"): DEFAULT_TRANSFER_HOOKS[cls] = method
return cls
def mark_as_saver(method):
    sig = inspect.signature(method)
    try:
        sig.bind(object(), "testpath")
    except TypeError:
        raise TypeError("Checkpoint saver must have signature (self, path)")
    method._speechbrain_saver = True
    return method
def mark_as_transfer(method):
    sig = inspect.signature(method)
    try:
        sig.bind(object(), "testpath")
    except TypeError:
        raise TypeError("Parameter transfer method must have signature (self, path)")
    method._speechbrain_transfer = True
    return method
def mark_as_loader(method):
    sig = inspect.signature(method)
    try:
        sig.bind(object(), "testpath", True)
    except TypeError:
        raise TypeError("Checkpoint loader must have signature (self, path, end_of_epoch)")
    method._speechbrain_loader = True
    return method
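# Reduces a tensor across DDP processes with the given op; a no-op when distributed
# training is not initialized or when running in main-process-only mode.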
def ddp_all_reduce(communication_object, reduce_op):
if MAIN_PROC_ONLY >= 1 or not is_distributed_initialized(): return communication_object
torch.distributed.all_reduce(communication_object, op=reduce_op)
return communication_object
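# Wraps a forward method with torch.cuda.amp.custom_fwd so its inputs are cast to a
# default precision (float32) under autocast; pass force_allow_autocast=True at call
# time to bypass the cast. Usable bare or as @fwd_default_precision(cast_inputs=...).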
def fwd_default_precision(fwd = None, cast_inputs = torch.float32):
if fwd is None: return functools.partial(fwd_default_precision, cast_inputs=cast_inputs)
wrapped_fwd = torch.cuda.amp.custom_fwd(fwd, cast_inputs=cast_inputs)
@functools.wraps(fwd)
def wrapper(*args, force_allow_autocast = False, **kwargs):
return fwd(*args, **kwargs) if force_allow_autocast else wrapped_fwd(*args, **kwargs)
return wrapper
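# Magnitude of a real-view STFT of shape (..., 2): real^2 + imag^2 summed over the
# last axis gives the power spectrum; power=0.5 yields the magnitude spectrum and
# log=True applies log-compression with an eps guard.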
def spectral_magnitude(stft, power = 1, log = False, eps = 1e-14):
spectr = stft.pow(2).sum(-1)
if power < 1: spectr = spectr + eps
spectr = spectr.pow(power)
if log: return torch.log(spectr + eps)
return spectr
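# Mel filterbank applied on top of a power spectrogram. Filter center frequencies and
# bands are linearly spaced on the mel scale; with freeze=False they become learnable
# parameters, and param_rand_factor applies a random perturbation during training.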
class Filterbank(torch.nn.Module):
def __init__(self, n_mels=40, log_mel=True, filter_shape="triangular", f_min=0, f_max=8000, n_fft=400, sample_rate=16000, power_spectrogram=2, amin=1e-10, ref_value=1.0, top_db=80.0, param_change_factor=1.0, param_rand_factor=0.0, freeze=True):
super().__init__()
self.n_mels = n_mels
self.log_mel = log_mel
self.filter_shape = filter_shape
self.f_min = f_min
self.f_max = f_max
self.n_fft = n_fft
self.sample_rate = sample_rate
self.power_spectrogram = power_spectrogram
self.amin = amin
self.ref_value = ref_value
self.top_db = top_db
self.freeze = freeze
self.n_stft = self.n_fft // 2 + 1
self.db_multiplier = math.log10(max(self.amin, self.ref_value))
self.device_inp = torch.device("cpu")
self.param_change_factor = param_change_factor
self.param_rand_factor = param_rand_factor
self.multiplier = 10 if self.power_spectrogram == 2 else 20
hz = self._to_hz(torch.linspace(self._to_mel(self.f_min), self._to_mel(self.f_max), self.n_mels + 2))
band = hz[1:] - hz[:-1]
self.band = band[:-1]
self.f_central = hz[1:-1]
if not self.freeze:
self.f_central = torch.nn.Parameter(self.f_central / (self.sample_rate * self.param_change_factor))
self.band = torch.nn.Parameter(self.band / (self.sample_rate * self.param_change_factor))
self.all_freqs_mat = torch.linspace(0, self.sample_rate // 2, self.n_stft).repeat(self.f_central.shape[0], 1)
def forward(self, spectrogram):
f_central_mat = self.f_central.repeat(self.all_freqs_mat.shape[1], 1).transpose(0, 1)
band_mat = self.band.repeat(self.all_freqs_mat.shape[1], 1).transpose(0, 1)
if not self.freeze:
f_central_mat = f_central_mat * (self.sample_rate * self.param_change_factor * self.param_change_factor)
band_mat = band_mat * (self.sample_rate * self.param_change_factor * self.param_change_factor)
elif self.param_rand_factor != 0 and self.training:
rand_change = (1.0 + torch.rand(2) * 2 * self.param_rand_factor - self.param_rand_factor)
f_central_mat = f_central_mat * rand_change[0]
band_mat = band_mat * rand_change[1]
fbank_matrix = self._create_fbank_matrix(f_central_mat, band_mat).to(spectrogram.device)
sp_shape = spectrogram.shape
if len(sp_shape) == 4: spectrogram = spectrogram.permute(0, 3, 1, 2).reshape(sp_shape[0] * sp_shape[3], sp_shape[1], sp_shape[2])
fbanks = torch.matmul(spectrogram, fbank_matrix)
if self.log_mel: fbanks = self._amplitude_to_DB(fbanks)
if len(sp_shape) == 4:
fb_shape = fbanks.shape
fbanks = fbanks.reshape(sp_shape[0], sp_shape[3], fb_shape[1], fb_shape[2]).permute(0, 2, 3, 1)
return fbanks
@staticmethod
def _to_mel(hz):
return 2595 * math.log10(1 + hz / 700)
@staticmethod
def _to_hz(mel):
return 700 * (10 ** (mel / 2595) - 1)
def _triangular_filters(self, all_freqs, f_central, band):
slope = (all_freqs - f_central) / band
return torch.max(torch.zeros(1, device=self.device_inp), torch.min(slope + 1.0, -slope + 1.0)).transpose(0, 1)
    def _rectangular_filters(self, all_freqs, f_central, band):
        left_side = all_freqs.ge(f_central - band)
        right_side = all_freqs.le(f_central + band)
        return (left_side * right_side).float().transpose(0, 1)
def _gaussian_filters(self, all_freqs, f_central, band, smooth_factor=torch.tensor(2)):
return torch.exp(-0.5 * ((all_freqs - f_central) / (band / smooth_factor)) ** 2).transpose(0, 1)
def _create_fbank_matrix(self, f_central_mat, band_mat):
if self.filter_shape == "triangular": fbank_matrix = self._triangular_filters(self.all_freqs_mat, f_central_mat, band_mat)
elif self.filter_shape == "rectangular": fbank_matrix = self._rectangular_filters(self.all_freqs_mat, f_central_mat, band_mat)
else: fbank_matrix = self._gaussian_filters(self.all_freqs_mat, f_central_mat, band_mat)
return fbank_matrix
def _amplitude_to_DB(self, x):
x_db = self.multiplier * torch.log10(torch.clamp(x, min=self.amin))
x_db -= self.multiplier * self.db_multiplier
return torch.max(x_db, (x_db.amax(dim=(-2, -1)) - self.top_db).view(x_db.shape[0], 1, 1))
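# Gathers left/right context frames around each time step with a grouped 1D
# convolution whose kernel is a (rolled) identity matrix, expanded lazily on the
# first forward call once the channel count is known.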
class ContextWindow(torch.nn.Module):
def __init__(self, left_frames=0, right_frames=0):
super().__init__()
self.left_frames = left_frames
self.right_frames = right_frames
self.context_len = self.left_frames + self.right_frames + 1
self.kernel_len = 2 * max(self.left_frames, self.right_frames) + 1
self.kernel = torch.eye(self.context_len, self.kernel_len)
if self.right_frames > self.left_frames: self.kernel = torch.roll(self.kernel, self.right_frames - self.left_frames, 1)
self.first_call = True
def forward(self, x):
x = x.transpose(1, 2)
if self.first_call:
self.first_call = False
self.kernel = (self.kernel.repeat(x.shape[1], 1, 1).view(x.shape[1] * self.context_len, self.kernel_len).unsqueeze(1))
or_shape = x.shape
if len(or_shape) == 4: x = x.reshape(or_shape[0] * or_shape[2], or_shape[1], or_shape[3])
cw_x = torch.nn.functional.conv1d(x, self.kernel.to(x.device), groups=x.shape[1], padding=max(self.left_frames, self.right_frames))
if len(or_shape) == 4: cw_x = cw_x.reshape(or_shape[0], cw_x.shape[1], or_shape[2], cw_x.shape[-1])
return cw_x.transpose(1, 2)
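# Describes a layer's receptive field (window size, stride, dilation, causality) so
# that stacked layers can be composed analytically, e.g.:
#   FilterProperties(window_size=3, stride=2).with_on_top(FilterProperties(window_size=3, stride=2))
# yields a combined window of 7 frames and a stride of 4.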
class FilterProperties:
    def __init__(self, window_size=0, stride=1, dilation=1, causal=False):
        self.window_size = window_size
        self.stride = stride
        self.dilation = dilation
        self.causal = causal
        # validate here: __post_init__ is only invoked automatically for dataclasses,
        # so it would never run on this plain class
        assert self.window_size > 0
        assert self.stride > 0
        assert self.dilation > 0
@staticmethod
def pointwise_filter():
return FilterProperties(window_size=1, stride=1)
def get_effective_size(self):
return 1 + ((self.window_size - 1) * self.dilation)
def get_convolution_padding(self):
        if self.window_size % 2 == 0: raise ValueError("cannot determine symmetric padding for an even window size")
if self.causal: return self.get_effective_size() - 1
return (self.get_effective_size() - 1) // 2
def get_noncausal_equivalent(self):
if not self.causal: return self
return FilterProperties(window_size=(self.window_size - 1) * 2 + 1, stride=self.stride, dilation=self.dilation, causal=False)
    def with_on_top(self, other, allow_approximate=True):
        self_size = self.window_size
        if other.window_size % 2 == 0:
            if allow_approximate: other_size = other.window_size + 1
            else: raise ValueError("cannot compose with an even window size unless allow_approximate=True")
        else: other_size = other.window_size
        if (self.causal or other.causal) and not (self.causal and other.causal):
            if allow_approximate: return self.get_noncausal_equivalent().with_on_top(other.get_noncausal_equivalent())
            else: raise ValueError("cannot compose a causal and a non-causal filter unless allow_approximate=True")
        return FilterProperties(self_size + (self.stride * (other_size - 1)), self.stride * other.stride, self.dilation * other.dilation, self.causal)
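# STFT with millisecond win_length/hop_length converted to samples from sample_rate;
# returns a real view of the complex spectrum with a trailing (real, imag) axis.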
class STFT(torch.nn.Module):
def __init__(self, sample_rate, win_length=25, hop_length=10, n_fft=400, window_fn=torch.hamming_window, normalized_stft=False, center=True, pad_mode="constant", onesided=True):
super().__init__()
self.sample_rate = sample_rate
self.win_length = win_length
self.hop_length = hop_length
self.n_fft = n_fft
self.normalized_stft = normalized_stft
self.center = center
self.pad_mode = pad_mode
self.onesided = onesided
self.win_length = int(round((self.sample_rate / 1000.0) * self.win_length))
self.hop_length = int(round((self.sample_rate / 1000.0) * self.hop_length))
self.window = window_fn(self.win_length)
def forward(self, x):
or_shape = x.shape
if len(or_shape) == 3: x = x.transpose(1, 2).reshape(or_shape[0] * or_shape[2], or_shape[1])
stft = torch.view_as_real(torch.stft(x, self.n_fft, self.hop_length, self.win_length, self.window.to(x.device), self.center, self.pad_mode, self.normalized_stft, self.onesided, return_complex=True))
stft = stft.reshape(or_shape[0], or_shape[2], stft.shape[1], stft.shape[2], stft.shape[3]).permute(0, 3, 2, 4, 1) if len(or_shape) == 3 else stft.transpose(2, 1)
return stft
def get_filter_properties(self):
        if not self.center: raise ValueError("filter properties are only defined for a centered STFT")
return FilterProperties(window_size=self.win_length, stride=self.hop_length)
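# Delta (differential) coefficients via the standard regression formula over a
# sliding window, implemented as a grouped, replicate-padded 1D convolution.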
class Deltas(torch.nn.Module):
def __init__(self, input_size, window_length=5):
super().__init__()
self.n = (window_length - 1) // 2
self.denom = self.n * (self.n + 1) * (2 * self.n + 1) / 3
self.register_buffer("kernel", torch.arange(-self.n, self.n + 1, dtype=torch.float32).repeat(input_size, 1, 1),)
def forward(self, x):
x = x.transpose(1, 2).transpose(2, -1)
or_shape = x.shape
if len(or_shape) == 4: x = x.reshape(or_shape[0] * or_shape[2], or_shape[1], or_shape[3])
x = torch.nn.functional.pad(x, (self.n, self.n), mode="replicate")
delta_coeff = (torch.nn.functional.conv1d(x, self.kernel.to(x.device), groups=x.shape[1]) / self.denom)
if len(or_shape) == 4: delta_coeff = delta_coeff.reshape(or_shape[0], or_shape[1], or_shape[2], or_shape[3])
return delta_coeff.transpose(1, -1).transpose(2, -1)
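# Complete filterbank front-end: STFT -> power spectrogram -> mel filterbank,
# optionally followed by delta / delta-delta coefficients and a context window.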
class Fbank(torch.nn.Module):
def __init__(self, deltas=False, context=False, requires_grad=False, sample_rate=16000, f_min=0, f_max=None, n_fft=400, n_mels=40, filter_shape="triangular", param_change_factor=1.0, param_rand_factor=0.0, left_frames=5, right_frames=5, win_length=25, hop_length=10):
super().__init__()
self.deltas = deltas
self.context = context
self.requires_grad = requires_grad
if f_max is None: f_max = sample_rate / 2
        self.compute_STFT = STFT(sample_rate=sample_rate, n_fft=n_fft, win_length=win_length, hop_length=hop_length)
        self.compute_fbanks = Filterbank(sample_rate=sample_rate, n_fft=n_fft, n_mels=n_mels, f_min=f_min, f_max=f_max, freeze=not requires_grad, filter_shape=filter_shape, param_change_factor=param_change_factor, param_rand_factor=param_rand_factor)
self.compute_deltas = Deltas(input_size=n_mels)
self.context_window = ContextWindow(left_frames=left_frames, right_frames=right_frames)
@fwd_default_precision(cast_inputs=torch.float32)
def forward(self, wav):
fbanks = self.compute_fbanks(spectral_magnitude(self.compute_STFT(wav)))
if self.deltas:
delta1 = self.compute_deltas(fbanks)
fbanks = torch.cat([fbanks, delta1, self.compute_deltas(delta1)], dim=2)
if self.context: fbanks = self.context_window(fbanks)
return fbanks
def get_filter_properties(self):
return self.compute_STFT.get_filter_properties()
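# Minimal usage sketch (hypothetical values, assuming 16 kHz mono input):
#   fbank = Fbank(n_mels=40)
#   wav = torch.randn(4, 16000)   # (batch, time)
#   feats = fbank(wav)            # (batch, frames, n_mels)
# InputNormalization below standardizes such features per sentence, speaker, batch,
# or with running global statistics; its checkpoint hooks persist those statistics.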
@register_checkpoint_hooks
class InputNormalization(torch.nn.Module):
def __init__(self, mean_norm=True, std_norm=True, norm_type="global", avg_factor=None, requires_grad=False, update_until_epoch=3):
super().__init__()
self.mean_norm = mean_norm
self.std_norm = std_norm
self.norm_type = norm_type
self.avg_factor = avg_factor
self.requires_grad = requires_grad
self.glob_mean = torch.tensor([0])
self.glob_std = torch.tensor([0])
self.spk_dict_mean = {}
self.spk_dict_std = {}
self.spk_dict_count = {}
self.weight = 1.0
self.count = 0
self.eps = 1e-10
self.update_until_epoch = update_until_epoch
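    # norm_type selects the statistics: "sentence" (per utterance), "speaker" (running
    # per-speaker statistics keyed by spk_ids), "batch" (current batch), or "global"
    # (running statistics, updated during training until update_until_epoch)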
def forward(self, x, lengths, spk_ids = torch.tensor([]), epoch=0):
N_batches = x.shape[0]
current_means, current_stds = [], []
if self.norm_type == "sentence" or self.norm_type == "speaker": out = torch.empty_like(x)
for snt_id in range(N_batches):
actual_size = torch.round(lengths[snt_id] * x.shape[1]).int()
current_mean, current_std = self._compute_current_stats(x[snt_id, 0:actual_size, ...])
current_means.append(current_mean)
current_stds.append(current_std)
if self.norm_type == "sentence": out[snt_id] = (x[snt_id] - current_mean.data) / current_std.data
if self.norm_type == "speaker":
spk_id = int(spk_ids[snt_id][0])
if self.training:
if spk_id not in self.spk_dict_mean:
self.spk_dict_mean[spk_id] = current_mean
self.spk_dict_std[spk_id] = current_std
self.spk_dict_count[spk_id] = 1
else:
self.spk_dict_count[spk_id] = (self.spk_dict_count[spk_id] + 1)
self.weight = (1 / self.spk_dict_count[spk_id]) if self.avg_factor is None else self.avg_factor
self.spk_dict_mean[spk_id] = (1 - self.weight) * self.spk_dict_mean[spk_id].to(current_mean) + self.weight * current_mean
self.spk_dict_std[spk_id] = (1 - self.weight) * self.spk_dict_std[spk_id].to(current_std) + self.weight * current_std
                    # detach() is not in-place; reassign so the stored stats leave the graph
                    self.spk_dict_mean[spk_id] = self.spk_dict_mean[spk_id].detach()
                    self.spk_dict_std[spk_id] = self.spk_dict_std[spk_id].detach()
speaker_mean = self.spk_dict_mean[spk_id].data
speaker_std = self.spk_dict_std[spk_id].data
else:
if spk_id in self.spk_dict_mean:
speaker_mean = self.spk_dict_mean[spk_id].data
speaker_std = self.spk_dict_std[spk_id].data
else:
speaker_mean = current_mean.data
speaker_std = current_std.data
out[snt_id] = (x[snt_id] - speaker_mean) / speaker_std
if self.norm_type == "batch" or self.norm_type == "global":
current_mean = ddp_all_reduce(torch.mean(torch.stack(current_means), dim=0), torch.distributed.ReduceOp.AVG)
current_std = ddp_all_reduce(torch.mean(torch.stack(current_stds), dim=0), torch.distributed.ReduceOp.AVG)
if self.norm_type == "batch": out = (x - current_mean.data) / (current_std.data)
if self.norm_type == "global":
if self.training:
if self.count == 0:
self.glob_mean = current_mean
self.glob_std = current_std
elif epoch is None or epoch < self.update_until_epoch:
self.weight = (1 / (self.count + 1)) if self.avg_factor is None else self.avg_factor
self.glob_mean = (1 - self.weight) * self.glob_mean.to(current_mean) + self.weight * current_mean
self.glob_std = (1 - self.weight) * self.glob_std.to(current_std) + self.weight * current_std
                # detach() is not in-place; reassign so the running stats leave the graph
                self.glob_mean = self.glob_mean.detach()
                self.glob_std = self.glob_std.detach()
self.count = self.count + 1
out = (x - self.glob_mean.data.to(x)) / (self.glob_std.data.to(x))
return out
    def _compute_current_stats(self, x):
        current_mean = torch.mean(x, dim=0).detach().data if self.mean_norm else torch.tensor([0.0], device=x.device)
        current_std = torch.std(x, dim=0).detach().data if self.std_norm else torch.tensor([1.0], device=x.device)
        return current_mean, torch.max(current_std, self.eps * torch.ones_like(current_std))
def _statistics_dict(self):
state = {}
state["count"] = self.count
state["glob_mean"] = self.glob_mean
state["glob_std"] = self.glob_std
state["spk_dict_mean"] = self.spk_dict_mean
state["spk_dict_std"] = self.spk_dict_std
state["spk_dict_count"] = self.spk_dict_count
return state
    def _load_statistics_dict(self, state):
        self.count = state["count"]
        self.glob_mean = state["glob_mean"]
        self.glob_std = state["glob_std"]
        self.spk_dict_mean = dict(state["spk_dict_mean"])
        self.spk_dict_std = dict(state["spk_dict_std"])
        self.spk_dict_count = state["spk_dict_count"]
        return state
def to(self, device):
self = super(InputNormalization, self).to(device)
self.glob_mean = self.glob_mean.to(device)
self.glob_std = self.glob_std.to(device)
for spk in self.spk_dict_mean:
self.spk_dict_mean[spk] = self.spk_dict_mean[spk].to(device)
self.spk_dict_std[spk] = self.spk_dict_std[spk].to(device)
return self
@mark_as_saver
def _save(self, path):
torch.save(self._statistics_dict(), path)
@mark_as_transfer
@mark_as_loader
def _load(self, path, end_of_epoch=False):
del end_of_epoch
stats = torch.load(path, map_location="cpu")
self._load_statistics_dict(stats)