# soiz1's picture
# Upload 204 files
# 2f5f13b verified
import torch
if torch.cuda.is_available() and torch.cuda.get_device_name().endswith("[ZLUDA]"):
    # Everything in this branch runs only when the CUDA device name ends with
    # "[ZLUDA]". It replaces torch.stft and torch.jit.script, disables cuDNN,
    # and leaves only the math backend enabled for scaled-dot-product attention.

    class STFT:
        """Short-time Fourier transform computed as a matmul against a cached DFT basis.

        Used by ``z_stft`` as a GPU path that does not call ``torch.stft``.
        """

        def __init__(self):
            self.device = "cuda"
            # Cache: n_fft -> stacked real/imag DFT basis built by _get_fourier_basis.
            self.fourier_bases = {}

        def _get_fourier_basis(self, n_fft):
            """Return the one-sided DFT basis for ``n_fft``, building and caching it on first use.

            The result has shape ``(2 * (n_fft // 2 + 1), n_fft)``. The first
            ``n_fft // 2 + 1`` rows are the real parts of the DFT matrix rows.
            The remaining rows are the matching imaginary parts.
            """
            if n_fft in self.fourier_bases:
                return self.fourier_bases[n_fft]
            # The FFT of the identity matrix is the DFT matrix. It is computed
            # on the CPU, and only the result is moved to self.device.
            dft_matrix = torch.fft.fft(torch.eye(n_fft, device="cpu")).to(
                self.device
            )
            cutoff = n_fft // 2 + 1
            fourier_basis = torch.cat(
                [dft_matrix.real[:cutoff], dft_matrix.imag[:cutoff]], dim=0
            )
            self.fourier_bases[n_fft] = fourier_basis
            return fourier_basis

        def transform(self, input, n_fft, hop_length, window):
            """Return the complex STFT of a 2-D ``input`` via one matmul.

            Intended to match ``torch.stft`` with ``center=True``,
            ``pad_mode="reflect"``, a one-sided spectrum, no normalization and
            ``return_complex=True``.

            Args:
                input: 2-D tensor. Frames are cut along dim 1 (presumably
                    (batch, samples); TODO confirm against callers).
                n_fft: FFT size, which is also the frame length.
                hop_length: Stride between frames. ``None`` uses ``n_fft // 4``,
                    the same default as ``torch.stft``.
                window: Tensor broadcast against the basis columns, so it is
                    expected to hold ``n_fft`` elements.

            Returns:
                Complex tensor of shape ``(input.shape[0], n_fft // 2 + 1, n_frames)``.
            """
            if hop_length is None:
                # Same default as torch.stft: floor(n_fft / 4).
                hop_length = n_fft // 4
            cutoff = n_fft // 2 + 1
            # Window the basis rows once instead of windowing every frame.
            windowed_basis = self._get_fourier_basis(n_fft) * window
            # center=True behaviour: reflect-pad n_fft // 2 samples on each side.
            pad_amount = n_fft // 2
            padded = torch.nn.functional.pad(
                input, (pad_amount, pad_amount), mode="reflect"
            )
            # unfold gives (batch, n_frames, n_fft); permute to (batch, n_fft, n_frames).
            frames = padded.unfold(1, n_fft, hop_length).permute(0, 2, 1)
            # (2 * cutoff, n_fft) @ (batch, n_fft, n_frames) -> (batch, 2 * cutoff, n_frames)
            spectrum = torch.matmul(windowed_basis, frames)
            return torch.complex(spectrum[:, :cutoff, :], spectrum[:, cutoff:, :])

    # Shared instance, so the DFT basis cache persists across torch.stft calls.
    stft = STFT()
    # Keep the real implementation before torch.stft is replaced below.
    _torch_stft = torch.stft

    # Keyword arguments the matmul fast path implements. Any other keyword
    # (normalized, onesided, pad_mode, ...) forces the CPU fallback.
    _FAST_STFT_KWARGS = frozenset(
        ("n_fft", "hop_length", "win_length", "center", "return_complex")
    )

    def z_stft(
        input: torch.Tensor, window: "torch.Tensor | None" = None, *args, **kwargs
    ):
        """Replacement for ``torch.stft`` installed on ZLUDA devices.

        The fast path was written for a specific call from
        rvc.train.mel_processing.MultiScaleMelSpectrogramLoss. It runs
        ``STFT.transform`` on the GPU only when all of the following hold:
        ``input`` is 2-D; ``window`` is a tensor; there are no extra positional
        args; ``n_fft`` is given by keyword; ``return_complex is True``;
        ``win_length`` and ``center`` are absent or ``None``; and no other
        keyword is given. Every other call runs the original ``torch.stft``
        on the CPU and moves the result back to ``input.device``.
        """
        use_fast_path = (
            not args
            and isinstance(window, torch.Tensor)
            and input.dim() == 2
            and kwargs.get("n_fft") is not None
            and kwargs.get("win_length") is None
            and kwargs.get("center") is None
            and kwargs.get("return_complex") is True
            and kwargs.keys() <= _FAST_STFT_KWARGS
        )
        if use_fast_path:
            # GPU-accelerated calculation.
            return stft.transform(
                input, kwargs["n_fft"], kwargs.get("hop_length"), window
            )
        # CPU fallback. `input` is passed positionally so that forwarded
        # positional args cannot collide with it. torch.stft accepts
        # window=None.
        cpu_window = window.cpu() if window is not None else None
        return _torch_stft(input.cpu(), *args, window=cpu_window, **kwargs).to(
            input.device
        )

    def z_jit(f, *_, **__):
        """Stand-in for ``torch.jit.script`` that returns ``f`` without scripting it.

        An empty ``torch._C.Graph`` is assigned to ``f.graph``. NOTE(review):
        presumably this keeps code that reads ``.graph`` on scripted objects
        working; confirm which callers rely on it.
        """
        f.graph = torch._C.Graph()
        return f

    # Hijacks: route torch.stft / torch.jit.script through the replacements above.
    torch.stft = z_stft
    torch.jit.script = z_jit
    # Disable cuDNN and the flash / memory-efficient SDP backends. Only the
    # math SDP backend stays enabled.
    torch.backends.cudnn.enabled = False
    torch.backends.cuda.enable_flash_sdp(False)
    torch.backends.cuda.enable_math_sdp(True)
    torch.backends.cuda.enable_mem_efficient_sdp(False)