File size: 429 Bytes
2e34814
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
import torch
import torchvision

# Shared mean-squared-error criterion used by l2_loss; reduction="mean"
# averages the squared differences over every element of the inputs.
l2_criterion: torch.nn.MSELoss = torch.nn.MSELoss(reduction="mean")


def l2_loss(real_images, generated_images, gray=False):
    """Return the mean squared error between two image tensors.

    Args:
        real_images: Reference image tensor.
        generated_images: Image tensor compared against ``real_images``.
        gray: If True, both tensors are first passed through
            ``torchvision.transforms.functional.rgb_to_grayscale`` (default
            single output channel) and the error is computed on the results.

    Returns:
        The value produced by the module-level ``l2_criterion``
        (``MSELoss`` with ``reduction="mean"``).
    """
    if not gray:
        return l2_criterion(real_images, generated_images)

    # torchvision expects the channel axis at dim -3 — assumes callers pass
    # (..., C, H, W) tensors with C in {1, 3}; TODO confirm at call sites.
    to_gray = torchvision.transforms.functional.rgb_to_grayscale
    real_gray = to_gray(real_images)
    generated_gray = to_gray(generated_images)
    return l2_criterion(real_gray, generated_gray)