File size: 1,990 Bytes
7e4e601
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
from numpy.lib.arraysetops import isin
import torchvision.models as models
import torch.nn as nn
import torch



class Vgg19(nn.Module):
    """Truncated pretrained VGG19 used as a fixed feature extractor.

    The wrapped network is the ``features`` part of torchvision's VGG19
    (``IMAGENET1K_V1`` weights), cut right after the convolution named by
    ``last_layer`` in :meth:`get_vgg19` (``conv4_4`` by default). Inputs are
    expected in the range [-1, 1]; they are rescaled to [0, 1] and then
    normalized with the per-channel ``mean`` / ``std`` before being fed to
    the network.
    """

    def __init__(self):
        super().__init__()
        self.vgg19 = self.get_vgg19().eval()
        vgg_mean = torch.tensor([0.485, 0.456, 0.406]).float()
        vgg_std = torch.tensor([0.229, 0.224, 0.225]).float()
        # Registered as buffers so that every nn.Module conversion
        # (.to(), .cuda(), .cpu(), .double(), DataParallel replication, ...)
        # moves them together with the weights. They are non-persistent so
        # the state_dict keys stay exactly those of the wrapped network.
        # Shape (3, 1, 1) broadcasts over (..., 3, H, W) inputs.
        self.register_buffer('mean', vgg_mean.view(-1, 1, 1), persistent=False)
        self.register_buffer('std', vgg_std.view(-1, 1, 1), persistent=False)

    def forward(self, x):
        """Return the truncated-VGG19 features of ``x``.

        Args:
            x: image tensor with values in [-1, 1] and the channel axis
                third from the end (e.g. ``(N, 3, H, W)``).

        Returns:
            The output of the last kept VGG19 layer.
        """
        return self.vgg19(self.normalize_vgg(x))

    @staticmethod
    def get_vgg19(last_layer='conv4_4'):
        """Build the pretrained VGG19 feature stack truncated at ``last_layer``.

        Layers are named ``conv{block}_{index}``: ``block`` starts at 1 and
        is incremented at every ``MaxPool2d``; ``index`` counts the
        ``Conv2d`` layers seen since the last pooling layer.

        Args:
            last_layer: name of the last layer to keep. The cut happens
                right after the matching ``Conv2d``, i.e. before its
                activation. If no layer matches, all feature layers are
                kept.

        Returns:
            An ``nn.Sequential`` holding the kept layers in order.
        """
        vgg = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1).features
        kept_layers = []

        conv_index = 0
        block_index = 1
        for layer in vgg.children():
            if isinstance(layer, nn.MaxPool2d):
                # A pooling layer closes the current block.
                conv_index = 0
                block_index += 1
            elif isinstance(layer, nn.Conv2d):
                conv_index += 1

            kept_layers.append(layer)
            if f'conv{block_index}_{conv_index}' == last_layer:
                break

        return nn.Sequential(*kept_layers)

    def normalize_vgg(self, image):
        """Map an image from [-1, 1] to the normalized space fed to VGG19.

        Args:
            image: tensor with values in [-1, 1], broadcastable against the
                ``(3, 1, 1)`` ``mean`` / ``std`` buffers.

        Returns:
            ``((image + 1) / 2 - mean) / std``.
        """
        image = (image + 1.0) / 2.0
        return (image - self.mean) / self.std


if __name__ == '__main__':
    # Smoke test: run one example image through the extractor and print the
    # shape of the resulting feature map.
    from PIL import Image
    import numpy as np
    from utils.image_processing import normalize_input

    # Context manager closes the file handle once the pixels are loaded.
    with Image.open("example/10.jpg") as source_image:
        # Force three channels so grayscale / RGBA / palette files still
        # produce an (H, W, 3) array for the permute below.
        image = source_image.convert("RGB").resize((224, 224))

    np_img = np.array(image).astype('float32')
    # NOTE(review): normalize_input presumably maps [0, 255] to [-1, 1],
    # which is what Vgg19.normalize_vgg expects; confirm in
    # utils.image_processing.
    np_img = normalize_input(np_img)

    # (H, W, C) -> (C, H, W) -> (1, C, H, W)
    img = torch.from_numpy(np_img)
    img = img.permute(2, 0, 1)
    img = img.unsqueeze(0)

    vgg = Vgg19()

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        feat = vgg(img)

    print(feat.shape)