import re

import torch.nn as nn


class BaseObject(nn.Module):
    """Base ``nn.Module`` that carries a human-readable name.

    Args:
        name: Optional explicit name. When ``None``, the name is derived
            from the class name converted to snake_case.
    """

    def __init__(self, name=None):
        super().__init__()
        self._name = name

    @property
    def __name__(self):
        """Return the explicit name, or the snake_case form of the class name."""
        if self._name is not None:
            return self._name
        name = self.__class__.__name__
        # Two-pass CamelCase -> snake_case conversion: first split before
        # capitalized words, then between a lowercase/digit and an uppercase.
        s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name)
        return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower()


class Metric(BaseObject):
    """Marker base class for metrics; adds nothing to ``BaseObject``."""


class Loss(BaseObject):
    """Base class for losses that can be combined with ``+`` and ``*``.

    ``loss_a + loss_b`` builds a ``SumOfLosses`` and ``loss * number`` (or
    ``number * loss``) builds a ``MultipliedLoss``.
    """

    def __add__(self, other):
        """Return ``SumOfLosses(self, other)``.

        Raises:
            ValueError: If ``other`` is not a ``Loss`` instance.
        """
        if isinstance(other, Loss):
            return SumOfLosses(self, other)
        raise ValueError("Loss should be inherited from `Loss` class")

    def __radd__(self, other):
        """Reflected addition; delegates to ``__add__``."""
        return self.__add__(other)

    def __mul__(self, value):
        """Return ``MultipliedLoss(self, value)``.

        Raises:
            ValueError: If ``value`` is not an ``int`` or ``float``.
        """
        if isinstance(value, (int, float)):
            return MultipliedLoss(self, value)
        # The previous message was copied from ``__add__`` and talked about
        # loss inheritance, which is not what this check is about.
        raise ValueError(
            "Loss multiplier should be an int or float, got `{}`".format(
                type(value).__name__
            )
        )

    def __rmul__(self, other):
        """Reflected multiplication; delegates to ``__mul__``."""
        return self.__mul__(other)


class SumOfLosses(Loss):
    """Loss whose value is the sum of two wrapped losses on the same inputs.

    Args:
        l1: First loss.
        l2: Second loss.
    """

    def __init__(self, l1, l2):
        name = "{} + {}".format(l1.__name__, l2.__name__)
        super().__init__(name=name)
        self.l1 = l1
        self.l2 = l2

    def forward(self, *inputs):
        """Evaluate both losses on ``inputs`` and return their sum."""
        # Call the operand modules rather than their ``.forward``: composite
        # losses can then be nested (e.g. ``(a + b) + c``), because every
        # operand is reached through the regular ``nn.Module`` call path.
        return self.l1(*inputs) + self.l2(*inputs)


class MultipliedLoss(Loss):
    """Loss whose value is a wrapped loss scaled by a constant multiplier.

    Args:
        loss: Loss to scale.
        multiplier: Factor applied to the wrapped loss value.
    """

    def __init__(self, loss, multiplier):
        # Parenthesize the wrapped name when it is itself a sum, so the
        # composed name stays unambiguous.
        if len(loss.__name__.split("+")) > 1:
            name = "{} * ({})".format(multiplier, loss.__name__)
        else:
            name = "{} * {}".format(multiplier, loss.__name__)
        super().__init__(name=name)
        self.loss = loss
        self.multiplier = multiplier

    def forward(self, *inputs):
        """Evaluate the wrapped loss on ``inputs`` and scale the result."""
        # Call the module, not ``.forward``, so a wrapped composite loss
        # (e.g. ``2 * (a + b)``) is evaluated correctly.
        return self.multiplier * self.loss(*inputs)