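# ZLUDA compatibility patches for PyTorch.
# Importing this module has side effects only when the active CUDA device reports
# itself as a ZLUDA device (its name ends with "[ZLUDA]"): torch.stft and
# torch.jit.script are replaced with ZLUDA-safe versions, cuDNN is disabled, and
# scaled-dot-product attention is forced onto the math backend.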
import torch

if torch.cuda.is_available() and torch.cuda.get_device_name().endswith("[ZLUDA]"):

    class STFT:
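        """Minimal batched STFT used in place of torch.stft under ZLUDA.

        The DFT basis for each n_fft is built once on the CPU (torch.fft.fft of an
        identity matrix), cached, and the per-frame transform is then performed as
        a plain matrix multiplication on the GPU.
        """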
        def __init__(self):
            self.device = "cuda"
            self.fourier_bases = {}  # Cache for Fourier bases

        def _get_fourier_basis(self, n_fft):
            # Check if the basis for this n_fft is already cached
            if n_fft in self.fourier_bases:
                return self.fourier_bases[n_fft]
            fourier_basis = torch.fft.fft(torch.eye(n_fft, device="cpu")).to(
                self.device
            )
            # keep the non-redundant bins and concatenate their real and imaginary parts
            cutoff = n_fft // 2 + 1
            fourier_basis = torch.cat(
                [fourier_basis.real[:cutoff], fourier_basis.imag[:cutoff]], dim=0
            )
            # cache the tensor and return
            self.fourier_bases[n_fft] = fourier_basis
            return fourier_basis

        def transform(self, input, n_fft, hop_length, window):
            # fetch cached Fourier basis
            fourier_basis = self._get_fourier_basis(n_fft)
            # apply the window to the Fourier basis
            fourier_basis = fourier_basis * window
            # pad input to center with reflect
            pad_amount = n_fft // 2
            input = torch.nn.functional.pad(
                input, (pad_amount, pad_amount), mode="reflect"
            )
            # separate input into n_fft-sized frames
            input_frames = input.unfold(1, n_fft, hop_length).permute(0, 2, 1)
            # apply the windowed DFT basis to all frames in one matrix multiplication
            fourier_transform = torch.matmul(fourier_basis, input_frames)
            cutoff = n_fft // 2 + 1
            return torch.complex(
                fourier_transform[:, :cutoff, :], fourier_transform[:, cutoff:, :]
            )

    stft = STFT()
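    # keep a reference to the original torch.stft so the fallback path can still call it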
    _torch_stft = torch.stft

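    # Drop-in replacement for torch.stft: calls that match the keyword pattern
    # below (no win_length or center, return_complex=True) are routed to the
    # matmul STFT above; anything else falls back to the original torch.stft on
    # the CPU.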
    def z_stft(input: torch.Tensor, window: torch.Tensor, *args, **kwargs):
        # only optimize the specific call made by rvc.train.mel_processing.MultiScaleMelSpectrogramLoss
        if (
            kwargs.get("win_length") is None
            and kwargs.get("center") is None
            and kwargs.get("return_complex") is True
        ):
            # use the GPU-accelerated calculation
            return stft.transform(
                input, kwargs.get("n_fft"), kwargs.get("hop_length"), window
            )
        else:
            # fall back to the original torch.stft on the CPU and move the result
            # back to the input's device
            return _torch_stft(
                input=input.cpu(), window=window.cpu(), *args, **kwargs
            ).to(input.device)

    def z_jit(f, *_, **__):
        # no-op replacement for torch.jit.script: skip scripting and return the
        # function unchanged, with an empty graph standing in for the compiled one
        f.graph = torch._C.Graph()
        return f

    # hijack torch with the ZLUDA-safe replacements
    torch.stft = z_stft
    torch.jit.script = z_jit
    # cuDNN is not supported under ZLUDA, so disable it
    torch.backends.cudnn.enabled = False
    # route scaled-dot-product attention through the math backend only
    torch.backends.cuda.enable_flash_sdp(False)
    torch.backends.cuda.enable_math_sdp(True)
    torch.backends.cuda.enable_mem_efficient_sdp(False)