import torch
import torch.nn as nn
from torch.nn.functional import mse_loss


class GANLoss(nn.Module):
    """GAN objective comparing discriminator output against a constant target.

    The target is a tensor shaped like the discriminator output and filled
    with either ``target_real_label`` or ``target_fake_label``. With
    ``use_lsgan=True`` the criterion is ``nn.MSELoss`` (least-squares GAN);
    otherwise it is ``nn.BCELoss``, which expects ``input`` to already hold
    probabilities in ``[0, 1]`` (e.g. sigmoid outputs).

    Args:
        use_lsgan: Use MSE (True) or binary cross-entropy (False).
        target_real_label: Fill value of the "real" target.
        target_fake_label: Fill value of the "fake" target.
        tensor: Kept for backward compatibility and stored as ``self.Tensor``.
            Targets are created with the device and dtype of ``input`` so they
            can never land on a different device than the prediction.
    """

    def __init__(self, use_lsgan=True, target_real_label=1.0,
                 target_fake_label=0.0, tensor=torch.FloatTensor):
        super().__init__()
        self.real_label = target_real_label
        self.fake_label = target_fake_label
        # Cached target tensors; rebuilt whenever the input layout changes.
        self.real_label_var = None
        self.fake_label_var = None
        self.Tensor = tensor
        self.loss = nn.MSELoss() if use_lsgan else nn.BCELoss()

    @staticmethod
    def _matches(cached, input):
        """Return True if *cached* can be reused as the target for *input*."""
        # Comparing numel() alone is not enough: (2, 3) vs (3, 2) would reuse a
        # wrongly shaped target and either fail or broadcast silently.
        return (cached is not None
                and cached.shape == input.shape
                and cached.device == input.device
                and cached.dtype == input.dtype)

    def get_target_tensor(self, input, target_is_real):
        """Return a target shaped like *input*, filled with the real/fake label.

        The tensor is cached on the module and reused while the input's shape,
        device and dtype stay the same. It does not require gradients.
        """
        if target_is_real:
            if not self._matches(self.real_label_var, input):
                self.real_label_var = torch.full_like(input, self.real_label)
            return self.real_label_var
        if not self._matches(self.fake_label_var, input):
            self.fake_label_var = torch.full_like(input, self.fake_label)
        return self.fake_label_var

    def forward(self, input, target_is_real):
        """Compute the loss of *input* against the real or fake target.

        Defined as ``forward`` (not ``__call__``) so ``nn.Module.__call__``
        still runs registered hooks; ``criterion(input, flag)`` works as before.
        """
        target_tensor = self.get_target_tensor(input, target_is_real)
        return self.loss(input, target_tensor)