Spaces:
Build error
Build error
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
class ContentLoss(nn.Module):
    """Pass-through layer that records a content loss for whatever flows through it.

    The input is returned unchanged. As a side effect, each ``forward`` call
    stores in ``self.loss`` the mean-squared error between that input and the
    target captured at construction time.
    """

    def __init__(self, target):
        super().__init__()
        # Detached copy: the target acts as a constant, outside the autograd graph.
        self.target = target.detach()

    def forward(self, input):
        """Record ``mse_loss(input, target)`` on ``self.loss``; return ``input`` as-is."""
        loss_value = F.mse_loss(input, self.target)
        self.loss = loss_value
        return input
class StyleLoss(nn.Module):
    """Pass-through layer that records a Gram-matrix style loss for its input.

    ``forward`` returns its input unchanged. As a side effect, it stores in
    ``self.loss`` the mean-squared error between the Gram matrix of the input
    and the Gram matrix of ``target_feature``, computed once at construction.
    """

    def __init__(self, target_feature):
        super().__init__()
        # Detach so the target Gram matrix is a constant, not part of the graph.
        self.target = self.gram_matrix(target_feature).detach()

    def gram_matrix(self, input):
        """Return the normalized Gram matrix of a 4-D tensor.

        ``input`` is unpacked as ``(a, b, c, d)``. The first two dims are
        flattened into rows and the last two into columns, giving an
        ``(a*b, c*d)`` matrix ``M``. The result is ``M @ M.T`` divided by the
        total element count ``a*b*c*d``.

        NOTE(review): the first two dims are merged, so a batch of more than
        one sample produces cross-sample correlations. Callers presumably
        pass a batch of size 1; confirm.
        """
        a, b, c, d = input.size()
        # reshape (not view) so non-contiguous inputs, e.g. permuted or
        # channels_last tensors, work instead of raising RuntimeError.
        features = input.reshape(a * b, c * d)
        G = torch.mm(features, features.t())
        return G.div(a * b * c * d)

    def forward(self, input):
        """Store the Gram-matrix MSE against the target in ``self.loss``; return ``input``."""
        G = self.gram_matrix(input)
        self.loss = F.mse_loss(G, self.target)
        return input
class Normalization(nn.Module):
    """Normalize an image tensor per channel as ``(img - mean) / std``.

    ``mean`` and ``std`` are reshaped to ``(-1, 1, 1)``, so they broadcast
    over the last two dimensions of ``img``.
    """

    def __init__(self, mean, std):
        super().__init__()
        # Registered as buffers so Module.to()/cuda()/double() move and cast
        # them along with the rest of the model; plain attributes would be
        # left behind and cause device/dtype mismatches in forward().
        # persistent=False keeps them out of state_dict, matching the
        # original attribute-based behavior for saved checkpoints.
        # as_tensor(...).detach().clone() copies like torch.tensor() did, but
        # without the copy-construct warning when a tensor is passed in.
        self.register_buffer(
            "mean",
            torch.as_tensor(mean).detach().clone().view(-1, 1, 1),
            persistent=False,
        )
        self.register_buffer(
            "std",
            torch.as_tensor(std).detach().clone().view(-1, 1, 1),
            persistent=False,
        )

    def forward(self, img):
        """Return ``(img - mean) / std`` with per-channel broadcasting."""
        return (img - self.mean) / self.std