"""Calculate Inception v3 features for a dataset and save their mean and
covariance (the statistics used, e.g., for FID evaluation)."""

import argparse
import pickle
import os

import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torchvision.models import inception_v3, Inception3
import numpy as np
from tqdm import tqdm

from inception import InceptionV3


class Inception3Feature(Inception3):
    """torchvision Inception3 rewired to return pooled features instead of logits."""

    def forward(self, x):
        # Inception v3 expects 299 x 299 inputs; resize anything else.
        if x.shape[2] != 299 or x.shape[3] != 299:
            x = F.interpolate(x, size=(299, 299), mode='bilinear', align_corners=True)

        # Stem convolutions
        x = self.Conv2d_1a_3x3(x)
        x = self.Conv2d_2a_3x3(x)
        x = self.Conv2d_2b_3x3(x)
        x = F.max_pool2d(x, kernel_size=3, stride=2)

        x = self.Conv2d_3b_1x1(x)
        x = self.Conv2d_4a_3x3(x)
        x = F.max_pool2d(x, kernel_size=3, stride=2)

        # Inception blocks
        x = self.Mixed_5b(x)
        x = self.Mixed_5c(x)
        x = self.Mixed_5d(x)

        x = self.Mixed_6a(x)
        x = self.Mixed_6b(x)
        x = self.Mixed_6c(x)
        x = self.Mixed_6d(x)
        x = self.Mixed_6e(x)

        x = self.Mixed_7a(x)
        x = self.Mixed_7b(x)
        x = self.Mixed_7c(x)

        # Global average pooling, then flatten to (batch, channels)
        x = F.avg_pool2d(x, kernel_size=8)

        return x.view(x.shape[0], x.shape[1])


def load_patched_inception_v3():
    # Feature extractor from the bundled inception module; block index 3 selects
    # the final average-pooled features (pytorch-fid convention), and
    # normalize_input=False because the transform below already maps images to [-1, 1].
    inception_feat = InceptionV3([3], normalize_input=False)

    return inception_feat


@torch.no_grad()
def extract_features(loader, inception, device):
    """Run the feature extractor over a DataLoader and collect all features on the CPU."""
    pbar = tqdm(loader)

    feature_list = []

    for img, _ in pbar:
        img = img.to(device)
        feature = inception(img)[0].view(img.shape[0], -1)
        feature_list.append(feature.to('cpu'))

    features = torch.cat(feature_list, 0)

    return features


if __name__ == '__main__':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    parser = argparse.ArgumentParser(
        description='Calculate Inception v3 features for datasets'
    )
    parser.add_argument('--size', type=int, default=256)
    parser.add_argument('--batch', default=64, type=int, help='batch size')
    parser.add_argument('--n_sample', type=int, default=50000)
    parser.add_argument('--flip', action='store_true')
    parser.add_argument('path', metavar='PATH', help='path to dataset image folder')

    args = parser.parse_args()

    inception = load_patched_inception_v3().eval().to(device)

    # Scale images to [-1, 1], matching the feature extractor (normalize_input=False).
    transform = transforms.Compose(
        [
            transforms.Resize((args.size, args.size)),
            transforms.RandomHorizontalFlip(p=0.5 if args.flip else 0),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ]
    )

    dset = ImageFolder(args.path, transform)
    loader = DataLoader(dset, batch_size=args.batch, num_workers=4)

    features = extract_features(loader, inception, device).numpy()
    features = features[: args.n_sample]

    print(f'extracted {features.shape[0]} features')

    mean = np.mean(features, 0)
    cov = np.cov(features, rowvar=False)

    name = os.path.splitext(os.path.basename(args.path))[0]

    print({'mean': mean.mean(), 'cov': cov.mean()})
    with open(f'inception_{name}.pkl', 'wb') as f:
        pickle.dump({'mean': mean, 'cov': cov, 'size': args.size, 'path': args.path}, f)
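
# Example invocation (script filename assumed; adjust to the actual name):
#   python calc_inception.py --size 256 --batch 64 /path/to/dataset
# The saved pickle can later be reloaded to recover the statistics:
#   with open(f'inception_{name}.pkl', 'rb') as f:
#       stats = pickle.load(f)
#   mean, cov = stats['mean'], stats['cov']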