# Copyright 2024 The YourMT3 Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Please see the details in the LICENSE file. """pitchshift.py""" # import math import numpy as np # from scipy import special from einops import rearrange from typing import Optional, Literal, Dict, List, Tuple, Callable import torch from torch import nn import torchaudio from torchaudio import transforms # from torchaudio import functional as F # from torchaudio.functional.functional import ( # _fix_waveform_shape, # _stretch_waveform, # ) # from model.ops import adjust_b_to_gcd, check_all_elements_equal class PitchShiftLayer(nn.Module): """Applying batch-wise pitch-shift to time-domain audio signals. Args: pshift_range (List[int]): Range of pitch shift in semitones. Default: ``[-2, 2]``. resample_source_fs (int): Default is 4000. stretch_n_fft (int): Default is 2048. window: (Optional[Literal['kaiser']]) Default is None. beta: (Optional[float]): Parameter for 'kaiser' filter. Default: None. """ def __init__( self, pshift_range: List[int] = [-2, 2], resample_source_fs: int = 4000, strecth_n_fft: int = 512, win_length: Optional[int] = None, hop_length: Optional[int] = None, window: Optional[Literal['kaiser']] = None, beta: Optional[float] = None, expected_input_shape: Optional[Tuple[int]] = None, device: Optional[torch.device] = None, **kwargs, ) -> None: super().__init__() self.pshift_range = pshift_range self.resample_source_fs = resample_source_fs self.strecth_n_fft = strecth_n_fft self.win_length = win_length self.hop_length = hop_length if window is None: self.window_fn = torch.hann_window self.window_kwargs = None elif 'kaiser' in window: def custom_kaiser_window(window_length, beta, **kwargs): return torch.kaiser_window(window_length, periodic=True, beta=beta, **kwargs) self.window_fn = custom_kaiser_window self.window_kwargs = {'beta': beta} # Initialize pitch shifters for every semitone self.pshifters = None self.frame_gaps = None self._initialize_pshifters(expected_input_shape, device=device) self.requires_grad_(False) def _initialize_pshifters(self, expected_input_shape: Optional[Tuple[int]] = None, device: Optional[torch.device] = None) -> None: # DDP requires initializing parameters with a dummy input if expected_input_shape is not None: if device is not None: dummy_input = torch.randn(expected_input_shape, requires_grad=False).to(device) else: dummy_input = torch.randn(expected_input_shape, requires_grad=False) else: dummy_input = None pshifters = nn.ModuleDict() for semitone in range(self.pshift_range[0], self.pshift_range[1] + 1): if semitone == 0: # No need to shift and resample pshifters[str(semitone)] = None else: pshifter = transforms.PitchShift(self.resample_source_fs, n_steps=semitone, n_fft=self.strecth_n_fft, win_length=self.win_length, hop_length=self.hop_length, window_fn=self.window_fn, wkwargs=self.window_kwargs) pshifters[str(semitone)] = pshifter # Pass dummy input to initialize parameters with torch.no_grad(): if dummy_input is not None: _ = pshifter.initialize_parameters(dummy_input) self.pshifters = pshifters def calculate_frame_gaps(self) -> Dict[int, float]: """Calculate the expected gap between the original and the stretched audio.""" frame_gaps = {} # for debugging for semitone in range(self.pshift_range[0], self.pshift_range[1] + 1): if semitone == 0: # No need to shift and resample frame_gaps[semitone] = 0. else: pshifter = self.pshifters[str(semitone)] gap_in_ms = 1000. * (pshifter.kernel.shape[2] - pshifter.kernel.shape[0] / 2.0**(-float(semitone) / 12)) / self.resample_source_fs frame_gaps[semitone] = gap_in_ms return frame_gaps @torch.no_grad() def forward(self, x: torch.Tensor, semitone: int) -> torch.Tensor: """ Args: x (torch.Tensor): (B, 1, T) or (B, T) Returns: torch.Tensor: (B, 1, T) or (B, T) """ if semitone == 0: return x elif semitone >= min(self.pshift_range) and semitone <= max(self.pshift_range): return self.pshifters[str(semitone)](x) else: raise ValueError(f"semitone must be in range {self.pshift_range}") def test_resampler_sinewave(): # x: {440Hz, 220Hz} sine wave at 16kHz t = torch.arange(0, 2, 1 / 16000) # 2 seconds at 16kHz x0 = torch.sin(2 * torch.pi * 440 * t) * 0.5 x1 = torch.sin(2 * torch.pi * 220 * t) * 0.5 x = torch.stack((x0, x1), dim=0) # (2, 32000) # Resample psl = PitchShiftLayer(pshift_range=[-2, 2], resample_source_fs=4000) y = psl(x, 2) # (2, 24000) # Export to wav torchaudio.save("x.wav", x, 16000, bits_per_sample=16) torchaudio.save("y.wav", y, 12000, bits_per_sample=16) # class Resampler(nn.Module): # """ # Resampling using conv1d operations, more memory-efficient than torchaudio's resampler. # Based on Dan Povey's resampler.py: # https://github.com/danpovey/filtering/blob/master/lilfilter/resampler.py # """ # def __init__(self, # input_sr: int, # output_sr: int, # dtype: torch.dtype = torch.float32, # filter_width: int = 16, # cutoff_ratio: float = 0.85, # filter: Literal['kaiser', 'kaiser_best', 'kaiser_fast', 'hann'] = 'kaiser_fast', # beta: float = 8.555504641634386) -> None: # super().__init__() # init the base class # """ # Initialize the Resampler. # Args: # - input_sr (int): Input sampling rate. # - output_sr (int): Output sampling rate. # - dtype (torch.dtype): Computation data type. Default: torch.float32. # - filter_width (int): Number of zeros per side in the sinc function. Default: 16. # - cutoff_ratio (float): Filter rolloff point as a fraction of Nyquist freq. Default: 0.95. # - filter (str): Filter type. One of ['kaiser', 'kaiser_best', 'kaiser_fast', 'hann']. Default: 'kaiser_fast'. # - beta (float): Parameter for 'kaiser' filter. Default: 8.555504641634386. # Note: Ratio between input_sr and output_sr should be reduced to simplest form. # """ # assert isinstance(input_sr, int) and isinstance(output_sr, int) # if input_sr == output_sr: # self.resample_type = 'trivial' # return # d = math.gcd(input_sr, output_sr) # input_sr, output_sr = input_sr // d, output_sr // d # assert dtype in [torch.float32, torch.float64] # assert filter_width > 3 # a reasonable bare minimum # np_dtype = np.float32 if dtype == torch.float32 else np.float64 # assert filter in ['hann', 'kaiser', 'kaiser_best', 'kaiser_fast'] # if filter == 'kaiser_best': # filter_width = 64 # beta = 14.769656459379492 # cutoff_ratio = 0.9475937167399596 # filter = 'kaiser' # elif filter == 'kaiser_fast': # filter_width = 16 # beta = 8.555504641634386 # cutoff_ratio = 0.85 # filter = 'kaiser' # """ # - Define a sample 'block' correlating `input_sr` input samples to `output_sr` output samples. # - Dividing samples into these blocks allows corresponding block alignment. # - On average, `zeros_per_block` zeros per block are present in the sinc function. # """ # zeros_per_block = min(input_sr, output_sr) * cutoff_ratio # """ # - Define conv kernel size n = (blocks_per_side*2 + 1), adding blocks to each side of the center. # - `blocks_per_side` blocks as window radius ensures each central block sample accesses its window. # - `blocks_per_side` is determined, rounding up if needed, as 1 + int(filter_width / zeros_per_block). # """ # blocks_per_side = int(np.ceil(filter_width / zeros_per_block)) # kernel_width = 2 * blocks_per_side + 1 # # Shape of conv1d weights: (out_channels, in_channels, kernel_width) # """ Time computations are in units of 1 block, aligning with the `canonical` time axis, # since each block has input_sr input samples, adhering to our time unit.""" # window_radius_in_blocks = blocks_per_side # """`times` will be sinc function arguments, expanding to shape (output_sr, input_sr, kernel_width) # via broadcasting. Ensuring t == 0 along the central block diagonal (when input_sr == output_sr)""" # times = ( # np.arange(output_sr, dtype=np_dtype).reshape( # (output_sr, 1, 1)) / output_sr - np.arange(input_sr, dtype=np_dtype).reshape( # (1, input_sr, 1)) / input_sr - (np.arange(kernel_width, dtype=np_dtype).reshape( # (1, 1, kernel_width)) - blocks_per_side)) # def hann_window(a): # """ # returning 0.5 + 0.5 cos(a*pi) on [-1,1] and 0 outside. # """ # return np.heaviside(1 - np.abs(a), 0.0) * (0.5 + 0.5 * np.cos(a * np.pi)) # def kaiser_window(a, beta): # w = special.i0(beta * np.sqrt(np.clip(1 - ( # (a - 0.0) / 1.0)**2.0, 0.0, 1.0))) / special.i0(beta) # return np.heaviside(1 - np.abs(a), 0.0) * w # """The weights are computed as a sinc function times a Hann-window function, normalized by # `zeros_per_block` (sinc) and `input_sr` (input function) to maintain integral and magnitude.""" # if filter == 'hann': # weights = ( # np.sinc(times * zeros_per_block) * hann_window(times / window_radius_in_blocks) * # zeros_per_block / input_sr) # else: # weights = ( # np.sinc(times * zeros_per_block) * # kaiser_window(times / window_radius_in_blocks, beta) * zeros_per_block / input_sr) # self.input_sr = input_sr # self.output_sr = output_sr # """If output_sr == 1, merge input_sr into kernel_width for weights (shape: output_sr, input_sr, # kernel_width) to optimize convolution speed and avoid extra reshaping.""" # assert weights.shape == (output_sr, input_sr, kernel_width) # if output_sr == 1: # self.resample_type = 'integer_downsample' # self.padding = input_sr * blocks_per_side # weights = torch.tensor(weights, dtype=dtype, requires_grad=False) # weights = weights.transpose(1, 2).contiguous().view(1, 1, input_sr * kernel_width) # elif input_sr == 1: # # For conv_transpose, use weights as if input_sr and output_sr were swapped, simulating downsampling. # self.resample_type = 'integer_upsample' # self.padding = output_sr * blocks_per_side # weights = torch.tensor(weights, dtype=dtype, requires_grad=False) # weights = weights.flip(2).transpose(0, # 2).contiguous().view(1, 1, output_sr * kernel_width) # else: # self.resample_type = 'general' # self.reshaped = False # self.padding = blocks_per_side # weights = torch.tensor(weights, dtype=dtype, requires_grad=False) # self.weights = torch.nn.Parameter(weights, requires_grad=False) # @torch.no_grad() # def forward(self, x: torch.Tensor) -> torch.Tensor: # """ # Parameters: # - x: torch.Tensor, with shape (minibatch_size, sequence_length), dtype should match the instance's dtype. # Returns: # - A torch.Tensor with shape (minibatch_size, (sequence_length//input_sr)*output_sr), dtype matching the input, # and content resampled. # """ # if self.resample_type == 'trivial': # return x # elif self.resample_type == 'integer_downsample': # (minibatch_size, seq_len) = x.shape # (B, in_C, L) with in_C == 1 # x = x.unsqueeze(1) # x = torch.nn.functional.conv1d( # x, self.weights, stride=self.input_sr, padding=self.padding) # (B, out_C, L) # return x.squeeze(1) # (B, L) # elif self.resample_type == 'integer_upsample': # x = x.unsqueeze(1) # x = torch.nn.functional.conv_transpose1d( # x, self.weights, stride=self.output_sr, padding=self.padding) # return x.squeeze(1) # else: # assert self.resample_type == 'general' # (minibatch_size, seq_len) = x.shape # num_blocks = seq_len // self.input_sr # if num_blocks == 0: # # TODO: pad with zeros. # raise RuntimeError("Signal is too short to resample") # # Truncate input # x = x[:, 0:(num_blocks * self.input_sr)].view(minibatch_size, num_blocks, self.input_sr) # x = x.transpose(1, 2) # (B, in_C, L) # x = torch.nn.functional.conv1d( # x, self.weights, padding=self.padding) # (B, out_C, num_blocks) # return x.transpose(1, 2).contiguous().view(minibatch_size, num_blocks * self.output_sr) # def test_resampler_sinewave(): # import torchaudio # # x: {440Hz, 220Hz} sine wave at 16kHz # t = torch.arange(0, 2, 1 / 16000) # 2 seconds at 16kHz # x0 = torch.sin(2 * torch.pi * 440 * t) * 0.5 # x1 = torch.sin(2 * torch.pi * 220 * t) * 0.5 # x = torch.stack((x0, x1), dim=0) # (2, 32000) # # Resample # resampler = Resampler(input_sr=16000, output_sr=12000) # y = resampler(x) # (2, 24000) # # Export to wav # torchaudio.save("x.wav", x, 16000, bits_per_sample=16) # torchaudio.save("y.wav", y, 12000, bits_per_sample=16) # def test_resampler_music(): # import torchaudio # # x: music at 16kHz # x, _ = torchaudio.load("music.wav") # slice_length = 32000 # n_slices = 80 # slices = [x[0, i * slice_length:(i + 1) * slice_length] for i in range(n_slices)] # x = torch.stack(slices) # (80, 32000) # # Resample # filter_width = 32 # resampler = Resampler(16000, 12000, filter_width=filter_width) # y = resampler(x) # (80, 24000) # y = y.reshape(1, -1) # (1, 1920000) # torchaudio.save(f"y_filter_width{filter_width}.wav", y, 12000, bits_per_sample=16) # class PitchShiftLayer(nn.Module): # """Applying batch-wise pitch-shift to time-domain audio signals. # Args: # expected_input_length (int): Expected input length. Default: ``32767``. # pshift_range (List[int]): Range of pitch shift in semitones. Default: ``[-2, 2]``. # min_gcd (int): Minimum GCD of input and output sampling rates for resampling. Setting high value can save GPU memory. Default: ``16``. # max_timing_error (float): Maximum allowed timing error in seconds. Default: ``0.002``. # fs (int): Sample rate of input waveform, x. Default: 16000. # bins_per_octave (int, optional): The number of steps per octave (Default : ``12``). # n_fft (int, optional): Size of FFT, creates ``n_fft // 2 + 1`` bins (Default: ``512``). # win_length (int or None, optional): Window size. If None, then ``n_fft`` is used. (Default: ``None``). # hop_length (int or None, optional): Length of hop between STFT windows. If None, then ``win_length // 4`` # is used (Default: ``None``). # window (Tensor or None, optional): Window tensor that is applied/multiplied to each frame/window. # If None, then ``torch.hann_window(win_length)`` is used (Default: ``None``). # """ # def __init__( # self, # expected_input_length: int = 32767, # pshift_range: List[int] = [-2, 2], # min_gcd: int = 16, # max_timing_error: float = 0.002, # fs: int = 16000, # bins_per_octave: int = 12, # n_fft: int = 2048, # win_length: Optional[int] = None, # hop_length: Optional[int] = None, # window: Optional[torch.Tensor] = None, # filter_width: int = 16, # filter: Literal['kaiser', 'kaiser_best', 'kaiser_fast', 'hann'] = 'kaiser_fast', # cutoff_ratio: float = 0.85, # beta: float = 8.555504641634386, # **kwargs, # ): # super().__init__() # self.expected_input_length = expected_input_length # self.pshift_range = pshift_range # self.min_gcd = min_gcd # self.max_timing_error = max_timing_error # self.fs = fs # self.bins_per_octave = bins_per_octave # self.n_fft = n_fft # self.win_length = win_length # self.hop_length = hop_length # self.window = window # self.resample_args = { # "filter_width": filter_width, # "filter": filter, # "cutoff_ratio": cutoff_ratio, # "beta": beta, # } # # Initialize Resamplers # self._initialize_resamplers() # def _initialize_resamplers(self): # resamplers = nn.ModuleDict() # self.frame_gaps = {} # for debugging # for i in range(self.pshift_range[0], self.pshift_range[1] + 1): # if i == 0: # # No need to shift and resample # resamplers[str(i)] = None # else: # # Find optimal reconversion frames meeting the min_gcd # stretched_frames, recon_frames, gap = self._find_optimal_reconversion_frames(i) # self.frame_gaps[i] = gap # resamplers[str(i)] = Resampler(stretched_frames, recon_frames, **self.resample_args) # self.resamplers = resamplers # def _find_optimal_reconversion_frames(self, semitone: int): # """ # Find the optimal reconversion frames for a given source sample rate, input length, and semitone for strech. # Parameters: # - sr (int): Input audio sample rate, which should be power of 2 # - n_step (int): The number of pitch-shift steps in semi-tone. # - min_gcd (int): The minimum desired GCD, power of 2. Defaults to 16. 16 or 32 are good choices. # - max_timing_error (float): The maximum allowed timing error, in seconds. Defaults to 5 ms # Returns: # - int: The optimal target sample rate # """ # stretch_rate = 1 / 2.0**(-float(semitone) / self.bins_per_octave) # stretched_frames = round(self.expected_input_length * stretch_rate) # gcd = math.gcd(self.expected_input_length, stretched_frames) # if gcd >= self.min_gcd: # return stretched_frames, self.expected_input_length, 0 # else: # reconversion_frames = adjust_b_to_gcd(stretched_frames, self.expected_input_length, # self.min_gcd) # gap = reconversion_frames - self.expected_input_length # gap_sec = gap / self.fs # if gap_sec > self.max_timing_error: # # TODO: modifying vocoder of stretch_waveform to adjust pitch-shift rate in cents # raise ValueError( # gap_sec < self.max_timing_error, # f"gap_sec={gap_sec} > max_timing_error={self.max_timing_error} with semitone={semitone}, stretched_frames={stretched_frames}, recon_frames={reconversion_frames}. Try adjusting input lenght or decreasing min_gcd." # ) # else: # return stretched_frames, reconversion_frames, gap_sec # @torch.no_grad() # def forward(self, # x: torch.Tensor, # semitone: int, # resample: bool = True, # fix_shape: bool = True) -> torch.Tensor: # """ # Args: # x (torch.Tensor): (B, 1, T) # Returns: # torch.Tensor: (B, 1, T) # """ # if semitone == 0: # return x # elif semitone >= min(self.pshift_range) and semitone <= max(self.pshift_range): # x = x.squeeze(1) # (B, T) # original_x_size = x.size() # x = _stretch_waveform( # x, # semitone, # self.bins_per_octave, # self.n_fft, # self.win_length, # self.hop_length, # self.window, # ) # if resample: # x = self.resamplers[str(semitone)].forward(x) # # Fix waveform shape # if fix_shape: # if x.size(1) != original_x_size[1]: # # print(f"Warning: {x.size(1)} != {original_x_length}") # x = _fix_waveform_shape(x, original_x_size) # return x.unsqueeze(1) # (B, 1, T) # else: # raise ValueError(f"semitone must be in range {self.pshift_range}") # def test_pitchshift_layer(): # import torchaudio # # music # # x, _ = torchaudio.load("music.wav") # # slice_length = 32767 # # n_slices = 80 # # slices = [x[0, i * slice_length:(i + 1) * slice_length] for i in range(n_slices)] # # x = torch.stack(slices).unsqueeze(1) # (80, 1, 32767) # # sine wave # t = torch.arange(0, 2.0479, 1 / 16000) # 2.05 seconds at 16kHz # x = torch.sin(2 * torch.pi * 440 * t) * 0.5 # x = x.reshape(1, 1, 32767).tile(80, 1, 1) # # Resample # pos = 0 # ps = PitchShiftLayer( # pshift_range=[-3, 4], # expected_input_length=32767, # fs=16000, # min_gcd=16, # max_timing_error=0.002, # # filter_width=64, # filter='kaiser_fast', # n_fft=2048) # y = [] # for i in range(-3, 4): # y.append(ps(x[[pos], :, :], i, resample=False, fix_shape=False)[0, 0, :]) # y = torch.cat(y).unsqueeze(0) # (1, 32767 * 7) # torchaudio.save("y_2048_kaiser_fast.wav", y, 16000, bits_per_sample=16) # # TorchAudio PitchShifter fopr comparision # y_ta = [] # for i in range(-3, 4): # ta_transform = torchaudio.transforms.PitchShift(16000, n_steps=i) # y_ta.append(ta_transform(x[[pos], :, :])[0, 0, :]) # y_ta = torch.cat(y_ta).unsqueeze(0) # (1, 32767 * 7) # torchaudio.save("y_ta.wav", y_ta, 16000, bits_per_sample=16) # def test_min_gcd_mem_usage(): # min_gcd = 16 # for i in range(-3, 4): # stretched_frames = _stretch_waveform(x, i).shape[1] # adjusted = adjust_b_to_gcd(stretched_frames, 32767, min_gcd) # gcd_val = math.gcd(adjusted, stretched_frames) # gap = adjusted - 32767 # gap_ms = (gap / 16000) * 1000 # mem_mb = (stretched_frames / gcd_val) * (adjusted / gcd_val) * 3 * 4 / 1000 / 1000 # print(f'\033[92mmin_gcd={min_gcd}\033[0m', f'ps={i}', f'frames={stretched_frames}', # f'adjusted_frames={adjusted}', f'gap={gap}', f'\033[91mgap_ms={gap_ms}\033[0m', # f'gcd={gcd_val}', f'mem_MB={mem_mb}')