Spaces:
Runtime error
Runtime error
File size: 2,337 Bytes
75c6e9a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 |
import math
from typing import Callable
import torch
import torch.nn as nn
from torchlibrosa.stft import STFT
from bytesep.models.pytorch_modules import Base
def l1(output: torch.Tensor, target: torch.Tensor, **kwargs) -> torch.Tensor:
    r"""Mean absolute error between ``output`` and ``target``.

    Args:
        output: torch.Tensor, the estimate
        target: torch.Tensor, the reference (combined with ``output`` via
            element-wise subtraction)
        **kwargs: accepted for call-site compatibility and ignored

    Returns:
        loss: torch.float, the mean of ``|output - target|`` over all elements
    """
    difference = output - target
    return difference.abs().mean()
def l1_wav(output: torch.Tensor, target: torch.Tensor, **kwargs) -> torch.Tensor:
    r"""Time-domain L1 loss: mean absolute difference of the two signals.

    Args:
        output: torch.Tensor, estimated time-domain signal
        target: torch.Tensor, reference time-domain signal
        **kwargs: accepted for call-site compatibility and ignored

    Returns:
        loss: torch.float, the mean of ``|output - target|`` over all elements
    """
    return (output - target).abs().mean()
class L1_Wav_L1_Sp(nn.Module, Base):
    r"""Sum of an L1 loss on the time-domain signals and an L1 loss on
    their spectrograms.

    Spectrograms are produced by ``self.wav_to_spectrogram`` (inherited from
    ``Base``), which is expected to use the ``self.stft`` module built in
    ``__init__`` — NOTE(review): confirm against ``Base``.
    """

    def __init__(
        self,
        window_size: int = 2048,
        hop_size: int = 441,
        center: bool = True,
        pad_mode: str = "reflect",
        window: str = "hann",
    ):
        r"""Build the STFT used for the spectrogram term.

        The defaults reproduce the original hard-coded configuration, so
        ``L1_Wav_L1_Sp()`` behaves exactly as before.

        Args:
            window_size: int, FFT size and window length passed to ``STFT``
            hop_size: int, hop length passed to ``STFT``
            center: bool, ``center`` flag passed to ``STFT``
            pad_mode: str, padding mode passed to ``STFT``
            window: str, window name passed to ``STFT``
        """
        super().__init__()

        # Kept as a public attribute in case other code reads it.
        self.window_size = window_size

        self.stft = STFT(
            n_fft=window_size,
            hop_length=hop_size,
            win_length=window_size,
            window=window,
            center=center,
            pad_mode=pad_mode,
            # The STFT is a fixed transform here; its parameters are not trained.
            freeze_parameters=True,
        )

    def forward(
        self, output: torch.Tensor, target: torch.Tensor, **kwargs
    ) -> torch.Tensor:
        r"""L1 loss in the time-domain plus L1 loss on the spectrogram.

        Implemented as ``forward`` (not ``__call__``) so that
        ``nn.Module.__call__`` dispatches here and module hooks still run;
        callers keep invoking ``loss_fn(output, target)``.

        Args:
            output: torch.Tensor, estimated time-domain signal
            target: torch.Tensor, reference time-domain signal
            **kwargs: accepted for call-site compatibility and ignored

        Returns:
            loss: torch.float, ``wav_loss + sp_loss``
        """
        # L1 loss in the time-domain.
        wav_loss = l1_wav(output, target)

        # L1 loss on the spectrogram. ``eps`` is forwarded to
        # Base.wav_to_spectrogram — presumably a numerical-stability floor;
        # confirm in Base.
        sp_loss = l1(
            self.wav_to_spectrogram(output, eps=1e-8),
            self.wav_to_spectrogram(target, eps=1e-8),
        )

        # Total loss.
        return wav_loss + sp_loss
def get_loss_function(loss_type: str) -> Callable:
    r"""Get loss function.

    Args:
        loss_type: str, either "l1_wav" or "l1_wav_l1_sp"

    Returns:
        loss function: Callable invoked as ``loss_fn(output, target, **kwargs)``

    Raises:
        NotImplementedError: if ``loss_type`` is not one of the names above
    """
    if loss_type == "l1_wav":
        return l1_wav

    elif loss_type == "l1_wav_l1_sp":
        # A new module instance (with its own STFT) is created on every call.
        return L1_Wav_L1_Sp()

    # Name the rejected value and the accepted ones so misconfiguration is
    # diagnosable from the traceback alone.
    raise NotImplementedError(
        f"Unsupported loss_type: {loss_type!r}; "
        "expected 'l1_wav' or 'l1_wav_l1_sp'."
    )
|