Spaces:
Runtime error
Runtime error
File size: 429 Bytes
2e34814 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 |
import torch
import torchvision
# Shared criterion: element-wise squared error averaged over all elements.
l2_criterion = torch.nn.MSELoss(reduction="mean")


def l2_loss(real_images, generated_images, gray=False):
    """Return the mean squared error between two image tensors.

    Args:
        real_images: Reference images. Presumably a float tensor shaped
            (..., C, H, W) — TODO confirm against callers.
        generated_images: Images to compare against ``real_images``;
            expected to match its shape.
        gray: When True, both inputs are first passed through
            ``torchvision.transforms.functional.rgb_to_grayscale`` (which
            expects 1 or 3 channels), so the error is measured on luminance
            only.

    Returns:
        A scalar tensor holding the mean squared error.
    """
    pair = (real_images, generated_images)
    if gray:
        to_gray = torchvision.transforms.functional.rgb_to_grayscale
        # Converted left-to-right: real images first, then generated ones.
        pair = tuple(to_gray(img) for img in pair)
    return l2_criterion(*pair)
|