Erasing-Invisible-Demo / kit /metrics /clean_fid /inception_torchscript.py
mcding
published version
ad552d8
import os
import torch
import torch.nn as nn
import contextlib
from .downloads_helper import *
@contextlib.contextmanager
def disable_gpu_fuser_on_pt19():
# On PyTorch 1.9 a CUDA fuser bug prevents the Inception JIT model to run. See
# https://github.com/GaParmar/clean-fid/issues/5
# https://github.com/pytorch/pytorch/issues/64062
if torch.__version__.startswith("1.9."):
old_val = torch._C._jit_can_fuse_on_gpu()
torch._C._jit_override_can_fuse_on_gpu(False)
yield
if torch.__version__.startswith("1.9."):
torch._C._jit_override_can_fuse_on_gpu(old_val)
class InceptionV3W(nn.Module):
"""
Wrapper around Inception V3 torchscript model provided here
https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt
path: locally saved inception weights
"""
def __init__(self, path, download=True, resize_inside=False):
super(InceptionV3W, self).__init__()
# download the network if it is not present at the given directory
# use the current directory by default
if download:
check_download_inception(fpath=path)
path = os.path.join(path, "inception-2015-12-05.pt")
self.base = torch.jit.load(path).eval()
self.layers = self.base.layers
self.resize_inside = resize_inside
"""
Get the inception features without resizing
x: Image with values in range [0,255]
"""
def forward(self, x):
with disable_gpu_fuser_on_pt19():
bs = x.shape[0]
if self.resize_inside:
features = self.base(x, return_features=True).view((bs, 2048))
else:
# make sure it is resized already
assert (x.shape[2] == 299) and (x.shape[3] == 299)
# apply normalization
x1 = x - 128
x2 = x1 / 128
features = self.layers.forward(
x2,
).view((bs, 2048))
return features