# -*- coding: UTF-8 -*-
import torch
from torch import nn
# https://github.com/Mikoto10032/AutomaticWeightedLoss/blob/master/AutomaticWeightedLoss.py
class AutomaticWeightedLoss(nn.Module):
    """Automatically weighted multi-task loss.

    Params:
        num: int, the number of loss terms
        x: the multi-task losses
    Example:
        loss1 = 1
        loss2 = 2
        awl = AutomaticWeightedLoss(2)
        loss_sum = awl(loss1, loss2)
    """
    def __init__(self, num=2, args=None):
        super(AutomaticWeightedLoss, self).__init__()
        if args is None or args.use_awl:
            # learnable weighting parameters, one per loss term
            params = torch.ones(num, requires_grad=True)
            self.params = torch.nn.Parameter(params)
        else:
            # fixed (non-trainable) weights when automatic weighting is disabled
            params = torch.ones(num, requires_grad=False)
            self.params = torch.nn.Parameter(params, requires_grad=False)
    def forward(self, *x):
        loss_sum = 0
        for i, loss in enumerate(x):
            # uncertainty-based weighting: 0.5 / sigma_i^2 * loss_i + log(1 + sigma_i^2)
            loss_sum += 0.5 / (self.params[i] ** 2) * loss + torch.log(1 + self.params[i] ** 2)
        return loss_sum
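

# Minimal usage sketch (illustrative, not part of the original module): the
# weighting parameters must be registered with the optimizer alongside the
# model parameters so they are learned jointly. The model and per-task losses
# below are hypothetical stand-ins.
if __name__ == "__main__":
    model = nn.Linear(10, 2)            # stand-in for any multi-task model
    awl = AutomaticWeightedLoss(num=2)
    optimizer = torch.optim.Adam([
        {"params": model.parameters()},
        {"params": awl.parameters(), "weight_decay": 0},
    ], lr=1e-3)

    x = torch.randn(4, 10)
    out = model(x)
    loss1 = out[:, 0].pow(2).mean()     # hypothetical task-1 loss
    loss2 = out[:, 1].abs().mean()      # hypothetical task-2 loss

    optimizer.zero_grad()
    loss_sum = awl(loss1, loss2)        # combine losses with learned weights
    loss_sum.backward()
    optimizer.step()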