""" Code adapted from https://github.com/mseitzer/pytorch-fid/ Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import torch import torch.nn as nn import torch.nn.functional as F from torchvision import models try: from torchvision.models.utils import load_state_dict_from_url except ImportError: from torch.utils.model_zoo import load_url as load_state_dict_from_url # Inception weights ported to Pytorch from # http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth' class InceptionV3FeatureExtractor(nn.Module): """Pretrained InceptionV3 network returning feature maps""" # Index of default block of inception to return, # corresponds to output of final average pooling DEFAULT_BLOCK_INDEX = 3 # Maps feature dimensionality to their output blocks indices BLOCK_INDEX_BY_DIM = { 64: 0, # First max pooling features 192: 1, # Second max pooling featurs 768: 2, # Pre-aux classifier features 2048: 3 # Final average pooling features } def __init__(self, output_block=DEFAULT_BLOCK_INDEX, pixel_min=-1, pixel_max=1): """ Build pretrained InceptionV3 Arguments: output_block (int): Index of block to return features of. Possible values are: - 0: corresponds to output of first max pooling - 1: corresponds to output of second max pooling - 2: corresponds to output which is fed to aux classifier - 3: corresponds to output of final average pooling pixel_min (float): Min value for inputs. Default value is -1. pixel_max (float): Max value for inputs. Default value is 1. """ super(InceptionV3FeatureExtractor, self).__init__() assert 0 <= output_block <= 3, '`output_block` can only be ' + \ '0 <= `output_block` <= 3.' inception = fid_inception_v3() blocks = [] # Block 0: input to maxpool1 block0 = [ inception.Conv2d_1a_3x3, inception.Conv2d_2a_3x3, inception.Conv2d_2b_3x3, nn.MaxPool2d(kernel_size=3, stride=2) ] blocks.append(nn.Sequential(*block0)) # Block 1: maxpool1 to maxpool2 if output_block >= 1: block1 = [ inception.Conv2d_3b_1x1, inception.Conv2d_4a_3x3, nn.MaxPool2d(kernel_size=3, stride=2) ] blocks.append(nn.Sequential(*block1)) # Block 2: maxpool2 to aux classifier if output_block >= 2: block2 = [ inception.Mixed_5b, inception.Mixed_5c, inception.Mixed_5d, inception.Mixed_6a, inception.Mixed_6b, inception.Mixed_6c, inception.Mixed_6d, inception.Mixed_6e, ] blocks.append(nn.Sequential(*block2)) # Block 3: aux classifier to final avgpool if output_block >= 3: block3 = [ inception.Mixed_7a, inception.Mixed_7b, inception.Mixed_7c, nn.AdaptiveAvgPool2d(output_size=(1, 1)) ] blocks.append(nn.Sequential(*block3)) self.main = nn.Sequential(*blocks) self.pixel_nin = pixel_min self.pixel_max = pixel_max self.requires_grad_(False) self.eval() def _scale(self, x): if self.pixel_min != -1 or self.pixel_max != 1: x = (2*x - self.pixel_min - self.pixel_max) \ / (self.pixel_max - self.pixel_min) return x def forward(self, input): """ Get Inception feature maps. Arguments: input (torch.Tensor) Returns: feature_maps (torch.Tensor) """ return self.main(input) def fid_inception_v3(): """Build pretrained Inception model for FID computation The Inception model for FID computation uses a different set of weights and has a slightly different structure than torchvision's Inception. This method first constructs torchvision's Inception and then patches the necessary parts that are different in the FID Inception model. """ inception = models.inception_v3(num_classes=1008, aux_logits=False, pretrained=False) inception.Mixed_5b = FIDInceptionA(192, pool_features=32) inception.Mixed_5c = FIDInceptionA(256, pool_features=64) inception.Mixed_5d = FIDInceptionA(288, pool_features=64) inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128) inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160) inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160) inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192) inception.Mixed_7b = FIDInceptionE_1(1280) inception.Mixed_7c = FIDInceptionE_2(2048) state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True) inception.load_state_dict(state_dict) return inception class FIDInceptionA(models.inception.InceptionA): """InceptionA block patched for FID computation""" def __init__(self, in_channels, pool_features): super(FIDInceptionA, self).__init__(in_channels, pool_features) def forward(self, x): branch1x1 = self.branch1x1(x) branch5x5 = self.branch5x5_1(x) branch5x5 = self.branch5x5_2(branch5x5) branch3x3dbl = self.branch3x3dbl_1(x) branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl) # Patch: Tensorflow's average pool does not use the padded zero's in # its average calculation branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, count_include_pad=False) branch_pool = self.branch_pool(branch_pool) outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool] return torch.cat(outputs, 1) class FIDInceptionC(models.inception.InceptionC): """InceptionC block patched for FID computation""" def __init__(self, in_channels, channels_7x7): super(FIDInceptionC, self).__init__(in_channels, channels_7x7) def forward(self, x): branch1x1 = self.branch1x1(x) branch7x7 = self.branch7x7_1(x) branch7x7 = self.branch7x7_2(branch7x7) branch7x7 = self.branch7x7_3(branch7x7) branch7x7dbl = self.branch7x7dbl_1(x) branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl) branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl) branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl) branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl) # Patch: Tensorflow's average pool does not use the padded zero's in # its average calculation branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, count_include_pad=False) branch_pool = self.branch_pool(branch_pool) outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool] return torch.cat(outputs, 1) class FIDInceptionE_1(models.inception.InceptionE): """First InceptionE block patched for FID computation""" def __init__(self, in_channels): super(FIDInceptionE_1, self).__init__(in_channels) def forward(self, x): branch1x1 = self.branch1x1(x) branch3x3 = self.branch3x3_1(x) branch3x3 = [ self.branch3x3_2a(branch3x3), self.branch3x3_2b(branch3x3), ] branch3x3 = torch.cat(branch3x3, 1) branch3x3dbl = self.branch3x3dbl_1(x) branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) branch3x3dbl = [ self.branch3x3dbl_3a(branch3x3dbl), self.branch3x3dbl_3b(branch3x3dbl), ] branch3x3dbl = torch.cat(branch3x3dbl, 1) # Patch: Tensorflow's average pool does not use the padded zero's in # its average calculation branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, count_include_pad=False) branch_pool = self.branch_pool(branch_pool) outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] return torch.cat(outputs, 1) class FIDInceptionE_2(models.inception.InceptionE): """Second InceptionE block patched for FID computation""" def __init__(self, in_channels): super(FIDInceptionE_2, self).__init__(in_channels) def forward(self, x): branch1x1 = self.branch1x1(x) branch3x3 = self.branch3x3_1(x) branch3x3 = [ self.branch3x3_2a(branch3x3), self.branch3x3_2b(branch3x3), ] branch3x3 = torch.cat(branch3x3, 1) branch3x3dbl = self.branch3x3dbl_1(x) branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) branch3x3dbl = [ self.branch3x3dbl_3a(branch3x3dbl), self.branch3x3dbl_3b(branch3x3dbl), ] branch3x3dbl = torch.cat(branch3x3dbl, 1) # Patch: The FID Inception model uses max pooling instead of average # pooling. This is likely an error in this specific Inception # implementation, as other Inception models use average pooling here # (which matches the description in the paper). branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1) branch_pool = self.branch_pool(branch_pool) outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] return torch.cat(outputs, 1)