File size: 1,820 Bytes
2a13495 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 |
import re
import torch.nn as nn
class BaseObject(nn.Module):
    """``nn.Module`` base that exposes a readable ``__name__`` on instances.

    Args:
        name: explicit name to report. When ``None``, the name is derived
            from the concrete class name converted to snake_case.
    """

    def __init__(self, name=None):
        super().__init__()
        self._name = name

    @property
    def __name__(self):
        """Return the explicit name, or the snake_case form of the class name."""
        if self._name is not None:
            return self._name
        # Pass 1 inserts "_" before each capitalised word ("FooBar" -> "Foo_Bar");
        # pass 2 splits remaining lower/digit -> upper boundaries ("IoU" -> "Io_U").
        class_name = type(self).__name__
        words_split = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", class_name)
        fully_split = re.sub("([a-z0-9])([A-Z])", r"\1_\2", words_split)
        return fully_split.lower()
class Metric(BaseObject):
    """Base class for metric modules; adds nothing on top of ``BaseObject``."""
class Loss(BaseObject):
    """Base class for losses that compose with ``+`` and scale with ``*``.

    ``loss_a + loss_b`` builds a ``SumOfLosses``; ``loss * k`` or ``k * loss``
    (``k`` an ``int`` or ``float``) builds a ``MultipliedLoss``. Adding ``0``
    on the left returns the loss unchanged, so ``sum(losses)`` works.
    """

    def __add__(self, other):
        """Return ``SumOfLosses(self, other)``.

        Raises:
            ValueError: if ``other`` is not a ``Loss``.
        """
        if isinstance(other, Loss):
            return SumOfLosses(self, other)
        raise ValueError("Loss should be inherited from `Loss` class")

    def __radd__(self, other):
        """Support ``other + self``.

        ``0 + self`` returns ``self`` so the builtin ``sum`` (which starts
        from ``0``) can fold a sequence of losses. Anything else is delegated
        to ``__add__``.

        Raises:
            ValueError: if ``other`` is neither ``0`` nor a ``Loss``.
        """
        if isinstance(other, (int, float)) and other == 0:
            return self
        return self.__add__(other)

    def __mul__(self, value):
        """Return ``MultipliedLoss(self, value)`` for a numeric ``value``.

        Raises:
            ValueError: if ``value`` is not an ``int`` or ``float``.
        """
        if isinstance(value, (int, float)):
            return MultipliedLoss(self, value)
        raise ValueError("Loss should be multiplied by int or float value")

    def __rmul__(self, other):
        """Support ``other * self``; scaling by a constant is commutative."""
        return self.__mul__(other)
class SumOfLosses(Loss):
    """Loss module returning the sum of two losses evaluated on the same inputs.

    Args:
        l1: first loss; evaluated as ``l1(*inputs)``.
        l2: second loss; evaluated as ``l2(*inputs)``.

    The module name is ``"<l1 name> + <l2 name>"``.
    """

    def __init__(self, l1, l2):
        name = "{} + {}".format(l1.__name__, l2.__name__)
        super().__init__(name=name)
        self.l1 = l1
        self.l2 = l2

    def forward(self, *inputs):
        # Implement ``forward`` (not ``__call__``) so ``nn.Module.__call__``
        # dispatches here, and call the sub-losses through ``__call__`` rather
        # than ``.forward``: a nested composite loss may not define ``forward``,
        # which previously made e.g. ``(a + b) + c`` fail when evaluated.
        return self.l1(*inputs) + self.l2(*inputs)
class MultipliedLoss(Loss):
    """Loss module that scales another loss by a constant factor.

    Args:
        loss: loss to scale; evaluated as ``loss(*inputs)``.
        multiplier: scalar factor (``Loss.__mul__`` only passes ``int`` or
            ``float`` values) applied to the loss output.

    The module name is ``"<multiplier> * <loss name>"``. The loss name is
    parenthesised when it contains ``+`` (i.e. it names a sum), so the
    displayed expression keeps the right precedence.
    """

    def __init__(self, loss, multiplier):
        if "+" in loss.__name__:
            name = "{} * ({})".format(multiplier, loss.__name__)
        else:
            name = "{} * {}".format(multiplier, loss.__name__)
        super().__init__(name=name)
        self.loss = loss
        self.multiplier = multiplier

    def forward(self, *inputs):
        # Implement ``forward`` so ``nn.Module.__call__`` dispatches here, and
        # call the wrapped loss via ``__call__``, not ``.forward``: a wrapped
        # composite loss may not define ``forward``, which previously made
        # e.g. ``0.5 * (a + b)`` fail when evaluated.
        return self.multiplier * self.loss(*inputs)
|