import torch
import torchvision

# Shared criterion: mean of the element-wise squared differences.
l2_criterion = torch.nn.MSELoss(reduction="mean")


def l2_loss(real_images, generated_images, gray=False):
    """Compute the mean squared error between two image tensors.

    Args:
        real_images: Reference images. Presumably laid out as (..., C, H, W),
            which is what ``rgb_to_grayscale`` expects — confirm with callers.
        generated_images: Images compared against ``real_images``; must be
            broadcast-compatible with them for ``MSELoss``.
        gray: If true, both tensors are first passed through
            ``torchvision.transforms.functional.rgb_to_grayscale`` (real
            images first, then generated ones), so only luminance is compared.

    Returns:
        The value of ``l2_criterion(real, generated)``, i.e. the mean-reduced
        ``MSELoss`` of the (optionally grayscaled) inputs.
    """
    real, fake = real_images, generated_images
    if gray:
        grayscale = torchvision.transforms.functional.rgb_to_grayscale
        real = grayscale(real)
        fake = grayscale(fake)
    return l2_criterion(real, fake)