# Source: llm-studio/llm_studio/src/losses/text_causal_regression_modeling_losses.py
# Uploaded by qinfeng722 ("Upload 322 files", commit 5caedb4, verified).
import logging
from typing import Any, KeysView
from torch import Tensor, nn
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
class MSELoss(nn.Module):
    """Mean squared error between predictions and regression targets.

    Thin wrapper around ``torch.nn.MSELoss`` built with its default
    arguments (``reduction="mean"``). The constructor takes ``cfg`` so that
    every loss in this module shares the same constructor signature.
    """

    def __init__(self, cfg: Any):
        super().__init__()
        # Kept for reference only; this loss does not read any config value.
        self.cfg = cfg
        self.loss_fn = nn.MSELoss()

    def forward(self, logits: Tensor, labels: Tensor) -> Tensor:
        """Compute the mean squared error of ``logits`` against ``labels``.

        Args:
            logits: model predictions.
            labels: regression targets; presumably the same shape as
                ``logits`` (``nn.MSELoss`` broadcasts and warns otherwise)
                — TODO confirm against the model head.

        Returns:
            Scalar loss tensor (mean over all elements).
        """
        squared_error_loss = self.loss_fn(logits, labels)
        return squared_error_loss
class MAELoss(nn.Module):
    """Mean absolute (L1) error between predictions and regression targets.

    Thin wrapper around ``torch.nn.L1Loss`` built with its default
    arguments (``reduction="mean"``). The constructor takes ``cfg`` so that
    every loss in this module shares the same constructor signature.
    """

    def __init__(self, cfg: Any):
        super().__init__()
        # Kept for reference only; this loss does not read any config value.
        self.cfg = cfg
        self.loss_fn = nn.L1Loss()

    def forward(self, logits: Tensor, labels: Tensor) -> Tensor:
        """Compute the mean absolute error of ``logits`` against ``labels``.

        Args:
            logits: model predictions.
            labels: regression targets; presumably the same shape as
                ``logits`` (``nn.L1Loss`` broadcasts and warns otherwise)
                — TODO confirm against the model head.

        Returns:
            Scalar loss tensor (mean over all elements).
        """
        absolute_error_loss = self.loss_fn(logits, labels)
        return absolute_error_loss
class Losses:
    """Losses factory.

    Maps loss names to the ``nn.Module`` classes implementing them.
    """

    _losses = {
        "MSELoss": MSELoss,
        "MAELoss": MAELoss,
    }

    @classmethod
    def names(cls) -> KeysView:
        """Return a view of all registered loss names."""
        return cls._losses.keys()

    @classmethod
    def get(cls, name: str) -> Any:
        """Access to Losses.

        Args:
            name: losses name

        Returns:
            A class to build the Losses. Unknown names fall back to
            ``MSELoss`` (kept for backward compatibility); a warning is logged
            in that case so a misspelled loss name does not go unnoticed.
        """
        try:
            return cls._losses[name]
        except KeyError:
            logger.warning(
                "Unknown loss '%s'; falling back to MSELoss. "
                "Available losses: %s",
                name,
                ", ".join(cls._losses),
            )
            return MSELoss