anyantudre's picture
moved from training repo to inference
caa56d6
raw
history blame contribute delete
469 Bytes
import torch.nn as nn
from .abstract_loss_func import AbstractLossClass
from metrics.registry import LOSSFUNC
@LOSSFUNC.register_module(module_name="l1loss")
class L1Loss(AbstractLossClass):
    """L1 loss, registered in ``LOSSFUNC`` under the name ``"l1loss"``.

    Thin wrapper that delegates the computation to ``torch.nn.L1Loss``
    constructed with its default arguments.
    """

    def __init__(self):
        super().__init__()
        # The criterion is built once here and reused by every forward() call.
        self.loss_fn = nn.L1Loss()

    def forward(self, inputs, targets):
        """Return the L1 loss between ``inputs`` and ``targets``.

        Both arguments are passed unchanged to ``self.loss_fn``.
        NOTE(review): presumably tensors of matching/broadcastable shape --
        confirm against callers.
        """
        return self.loss_fn(inputs, targets)