Spaces:
Running
Running
import logging | |
from typing import Any, KeysView | |
from torch import Tensor, nn | |
logger = logging.getLogger(__name__) | |
class MSELoss(nn.Module):
    """Mean-squared-error loss module configured from a ``cfg`` object.

    Thin wrapper around ``torch.nn.MSELoss`` constructed with its default
    arguments. The configuration is kept on the instance but is not read
    by this class itself.
    """

    def __init__(self, cfg: Any):
        """Build the loss.

        Args:
            cfg: configuration object, stored as ``self.cfg``.
        """
        super().__init__()
        # Underlying criterion; created with torch's default settings.
        self.loss_fn = nn.MSELoss()
        self.cfg = cfg

    def forward(self, logits: Tensor, labels: Tensor) -> Tensor:
        """Compute the loss between predictions and targets.

        Args:
            logits: model predictions.
            labels: target values.

        Returns:
            The result of applying ``self.loss_fn`` to ``(logits, labels)``.
        """
        criterion = self.loss_fn
        loss = criterion(logits, labels)
        return loss
class MAELoss(nn.Module):
    """Mean-absolute-error loss module configured from a ``cfg`` object.

    Thin wrapper around ``torch.nn.L1Loss`` constructed with its default
    arguments. The configuration is kept on the instance but is not read
    by this class itself.
    """

    def __init__(self, cfg: Any):
        """Build the loss.

        Args:
            cfg: configuration object, stored as ``self.cfg``.
        """
        super().__init__()
        # Underlying criterion; created with torch's default settings.
        self.loss_fn = nn.L1Loss()
        self.cfg = cfg

    def forward(self, logits: Tensor, labels: Tensor) -> Tensor:
        """Compute the loss between predictions and targets.

        Args:
            logits: model predictions.
            labels: target values.

        Returns:
            The result of applying ``self.loss_fn`` to ``(logits, labels)``.
        """
        criterion = self.loss_fn
        loss = criterion(logits, labels)
        return loss
class Losses:
    """Losses factory.

    Maps loss names to the loss classes defined in this module. Both
    accessors are classmethods, so the factory is used directly on the
    class (``Losses.names()``, ``Losses.get(name)``) without instantiation.
    """

    # Registry of available losses, keyed by the name used to select them.
    _losses = {
        "MSELoss": MSELoss,
        "MAELoss": MAELoss,
    }

    @classmethod
    def names(cls) -> KeysView:
        """Return a view of the names of all registered losses."""
        return cls._losses.keys()

    @classmethod
    def get(cls, name: str) -> Any:
        """Access to Losses.

        Args:
            name: losses name

        Returns:
            A class to build the Losses. If ``name`` is not registered,
            ``MSELoss`` is returned as the fallback.
        """
        return cls._losses.get(name, MSELoss)