echen01
add PTI
2e34814
raw
history blame
429 Bytes
import torch
import torchvision
# Module-level mean-squared-error criterion, reused by every ``l2_loss`` call.
l2_criterion = torch.nn.MSELoss(reduction="mean")


def l2_loss(real_images, generated_images, gray=False):
    """Return the mean squared error between two image tensors.

    Args:
        real_images: Image tensor passed to ``l2_criterion`` as its first
            argument (presumably the reference/target batch; confirm with
            callers).
        generated_images: Image tensor passed to ``l2_criterion`` as its
            second argument.
        gray: When true, both tensors are first passed through
            ``torchvision.transforms.functional.rgb_to_grayscale`` so the
            error is computed on the grayscale versions.

    Returns:
        The result of ``l2_criterion``. Because that criterion is built with
        ``reduction="mean"``, this is the squared difference averaged over
        all elements.
    """
    if not gray:
        return l2_criterion(real_images, generated_images)

    # NOTE(review): torchvision's rgb_to_grayscale expects a channel dimension
    # at position -3 holding 1 or 3 channels; confirm callers pass such tensors.
    to_gray = torchvision.transforms.functional.rgb_to_grayscale
    # Arguments are evaluated left to right, so real_images is still
    # converted before generated_images.
    return l2_criterion(to_gray(real_images), to_gray(generated_images))