Spaces:
Runtime error
Runtime error
import torch | |
import torchvision | |
# Shared mean-squared-error criterion used by ``l2_loss``; averages the
# element-wise squared differences over all elements.
l2_criterion = torch.nn.MSELoss(reduction="mean")


def l2_loss(real_images, generated_images, gray=False):
    """Compute the mean squared error between two image tensors.

    Args:
        real_images: First image tensor (passed as the criterion's input).
        generated_images: Second image tensor (passed as the criterion's
            target).
        gray: If true, both tensors are converted with
            ``torchvision.transforms.functional.rgb_to_grayscale`` before the
            error is computed. torchvision expects a ``(..., C, H, W)`` layout
            with ``C`` in {1, 3} for that conversion — callers must supply it.

    Returns:
        The scalar tensor produced by ``l2_criterion``, i.e. the mean of the
        squared element-wise differences.
    """
    if gray:
        # Only touch torchvision when grayscale comparison is requested.
        to_gray = torchvision.transforms.functional.rgb_to_grayscale
        real_images, generated_images = (
            to_gray(real_images),
            to_gray(generated_images),
        )
    return l2_criterion(real_images, generated_images)