Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -8,6 +8,95 @@ import numpy as np
|
|
8 |
import requests
|
9 |
from io import BytesIO
|
10 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
#Load VGG19 model
|
12 |
vgg = models.vgg19(pretrained=True).features
|
13 |
for param in vgg.parameters():
|
|
|
8 |
import requests
|
9 |
from io import BytesIO
|
10 |
|
11 |
+
def load_image(img_path, max_size=400, shape=None):
|
12 |
+
''' Load in and transform an image, making sure the image
|
13 |
+
is <= 400 pixels in the x-y dims.'''
|
14 |
+
if "http" in img_path:
|
15 |
+
response = requests.get(img_path)
|
16 |
+
image = Image.open(BytesIO(response.content)).convert('RGB')
|
17 |
+
else:
|
18 |
+
image = Image.open(img_path).convert('RGB')
|
19 |
+
|
20 |
+
# large images will slow down processing
|
21 |
+
if max(image.size) > max_size:
|
22 |
+
size = max_size
|
23 |
+
else:
|
24 |
+
size = max(image.size)
|
25 |
+
|
26 |
+
if shape is not None:
|
27 |
+
size = shape
|
28 |
+
|
29 |
+
in_transform = transforms.Compose([
|
30 |
+
transforms.Resize(size),
|
31 |
+
transforms.ToTensor(),
|
32 |
+
transforms.Normalize((0.485, 0.456, 0.406),
|
33 |
+
(0.229, 0.224, 0.225))])
|
34 |
+
|
35 |
+
# discard the transparent, alpha channel (that's the :3) and add the batch dimension
|
36 |
+
image = in_transform(image)[:3,:,:].unsqueeze(0)
|
37 |
+
|
38 |
+
return image
|
39 |
+
|
40 |
+
# helper function for un-normalizing an image
|
41 |
+
# and converting it from a Tensor image to a NumPy image for display
|
42 |
+
def im_convert(tensor):
|
43 |
+
""" Display a tensor as an image. """
|
44 |
+
|
45 |
+
image = tensor.to("cpu").clone().detach()
|
46 |
+
image = image.numpy().squeeze()
|
47 |
+
image = image.transpose(1,2,0)
|
48 |
+
image = image * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406))
|
49 |
+
image = image.clip(0, 1)
|
50 |
+
|
51 |
+
return image
|
52 |
+
|
53 |
+
def get_features(image, model, layers=None):
|
54 |
+
""" Run an image forward through a model and get the features for
|
55 |
+
a set of layers. Default layers are for VGGNet matching Gatys et al (2016)
|
56 |
+
"""
|
57 |
+
|
58 |
+
## TODO: Complete mapping layer names of PyTorch's VGGNet to names from the paper
|
59 |
+
## Need the layers for the content and style representations of an image
|
60 |
+
if layers is None:
|
61 |
+
layers = {'0': 'conv1_1',
|
62 |
+
'5': 'conv2_1',
|
63 |
+
'10': 'conv3_1',
|
64 |
+
'19': 'conv4_1',
|
65 |
+
'21': 'conv4_2', ## content representation
|
66 |
+
'28': 'conv5_1'}
|
67 |
+
|
68 |
+
|
69 |
+
## -- do not need to change the code below this line -- ##
|
70 |
+
features = {}
|
71 |
+
x = image
|
72 |
+
# model._modules is a dictionary holding each module in the model
|
73 |
+
for name, layer in model._modules.items():
|
74 |
+
x = layer(x)
|
75 |
+
if name in layers:
|
76 |
+
features[layers[name]] = x
|
77 |
+
|
78 |
+
return features
|
79 |
+
|
80 |
+
|
81 |
+
def gram_matrix(tensor):
|
82 |
+
""" Calculate the Gram Matrix of a given tensor
|
83 |
+
Gram Matrix: https://en.wikipedia.org/wiki/Gramian_matrix
|
84 |
+
"""
|
85 |
+
|
86 |
+
## get the batch_size, depth, height, and width of the Tensor
|
87 |
+
## reshape it, so we're multiplying the features for each channel
|
88 |
+
## calculate the gram matrix
|
89 |
+
# get the batch_size, depth, height, and width of the Tensor
|
90 |
+
b, d, h, w = tensor.size()
|
91 |
+
|
92 |
+
# reshape so we're multiplying the features for each channel
|
93 |
+
tensor = tensor.view(b * d, h * w)
|
94 |
+
|
95 |
+
# calculate the gram matrix
|
96 |
+
gram = torch.mm(tensor, tensor.t())
|
97 |
+
|
98 |
+
return gram
|
99 |
+
|
100 |
#Load VGG19 model
|
101 |
vgg = models.vgg19(pretrained=True).features
|
102 |
for param in vgg.parameters():
|