Francesco Pochetti commited on
Commit
7a4e8ef
·
1 Parent(s): 6cc17a8

adding files

Browse files
README.md CHANGED
@@ -1,8 +1,8 @@
1
  ---
2
  title: Fast Neural Style Transfer
3
- emoji: 📊
4
- colorFrom: red
5
- colorTo: pink
6
  sdk: gradio
7
  app_file: app.py
8
  pinned: false
 
1
  ---
2
  title: Fast Neural Style Transfer
3
+ emoji: 🎨
4
+ colorFrom: green
5
+ colorTo: red
6
  sdk: gradio
7
  app_file: app.py
8
  pinned: false
app.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pydoc import describe
2
+ import re
3
+ import numpy as np
4
+ from PIL import Image
5
+ import torch
6
+ from torchvision import transforms
7
+ import gradio as gr
8
+ from model import TransformerNet
9
+
10
+ style_model = TransformerNet()
11
+ device=torch.device("cpu")
12
+
13
+ styles_map = {"Kandinsky, Several circles": "kand_circles.model",
14
+ "Haring, Dance": "haring_dance.model",
15
+ "Picasso, The weeping woman": "picasso_weeping.model",
16
+ "Van Gogh, Wheatfield with crows": "vangogh_crows.model"}
17
+
18
+ content_transform = transforms.Compose([
19
+ transforms.ToTensor(),
20
+ transforms.Lambda(lambda x: x.mul(255))
21
+ ])
22
+
23
+ def run(content_image, style):
24
+ content_image.thumbnail((1080, 1080))
25
+ img = content_transform(content_image)
26
+ img = img.unsqueeze(0).to(device)
27
+
28
+ model = styles_map[style]
29
+ state_dict = torch.load(f"./models/{model}")
30
+
31
+ for k in list(state_dict.keys()):
32
+ if re.search(r'in\d+\.running_(mean|var)$', k):
33
+ del state_dict[k]
34
+ style_model.load_state_dict(state_dict)
35
+ style_model.to(device)
36
+
37
+ with torch.no_grad():
38
+ output = style_model(img)
39
+
40
+ img = output[0].clone().clamp(0, 255).numpy()
41
+ img = img.transpose(1, 2, 0).astype("uint8")
42
+
43
+ img = Image.fromarray(img)
44
+ return img
45
+
46
+ content_image_input = gr.inputs.Image(label="Content Image", type="pil")
47
+ style_input = gr.inputs.Dropdown(list(styles_map.keys()), type="value", default="Kandinsky, Several circles", label="Style")
48
+
49
+ description="Fast Neural Style Transfer demo (trained from scratch!). Upload a content image. Select an artwork. Enjoy."
50
+ article="""
51
+ **References**\n\n
52
+ You can find <a href='https://francescopochetti.com/fast-neural-style-transfer-deploying-pytorch-models-to-aws-lambda/' target='_blank'>here</a> a post I put together
53
+ describing the approach I used to train models and deploy them on <a href='http://visualneurons.com/fast.html' target='_blank'>visualneurons.com</a> using AWS Lambda. \n
54
+ <a href='https://github.com/FraPochetti/examples/blob/master/fast_neural_style/neural_style/FastStyleTransferPytorch.ipynb' target='_blank'>Here</a> is instead the Jupyter notebook
55
+ with the training logic. \n
56
+ <br>
57
+ <hr>
58
+ **Kandinsky, Several circles**
59
+ <img src='https://style-transfer-webapptest.s3.eu-west-1.amazonaws.com/small_images_hf/Several_Circles.jpeg'>
60
+ <hr>
61
+ **Haring, Dance**
62
+ <img src='https://style-transfer-webapptest.s3.eu-west-1.amazonaws.com/small_images_hf/Haring.jpeg'>
63
+ <hr>
64
+ **Picasso, The weeping woman**
65
+ <img src='https://style-transfer-webapptest.s3.eu-west-1.amazonaws.com/small_images_hf/weeping.png'>
66
+ <hr>
67
+ **Van Gogh, Wheatfield with crows**
68
+ <img src='https://style-transfer-webapptest.s3.eu-west-1.amazonaws.com/small_images_hf/Wheatfield_with_crows.jpeg'>
69
+ """
70
+ example = ["dog.jpeg", "Kandinsky, Several circles"]
71
+
72
+ app_interface = gr.Interface(fn=run,
73
+ inputs=[content_image_input, style_input],
74
+ outputs="image",
75
+ title="Fast Neural Style Transfer",
76
+ description=description,
77
+ examples=[example],
78
+ article=article)
79
+ app_interface.launch()
dog.jpeg ADDED
model.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ class TransformerNet(torch.nn.Module):
4
+ def __init__(self):
5
+ super(TransformerNet, self).__init__()
6
+ # Initial convolution layers
7
+ self.conv1 = ConvLayer(3, 32, kernel_size=9, stride=1)
8
+ self.in1 = torch.nn.InstanceNorm2d(32, affine=True)
9
+ self.conv2 = ConvLayer(32, 64, kernel_size=3, stride=2)
10
+ self.in2 = torch.nn.InstanceNorm2d(64, affine=True)
11
+ self.conv3 = ConvLayer(64, 128, kernel_size=3, stride=2)
12
+ self.in3 = torch.nn.InstanceNorm2d(128, affine=True)
13
+ # Residual layers
14
+ self.res1 = ResidualBlock(128)
15
+ self.res2 = ResidualBlock(128)
16
+ self.res3 = ResidualBlock(128)
17
+ self.res4 = ResidualBlock(128)
18
+ self.res5 = ResidualBlock(128)
19
+ # Upsampling Layers
20
+ self.deconv1 = UpsampleConvLayer(128, 64, kernel_size=3, stride=1, upsample=2)
21
+ self.in4 = torch.nn.InstanceNorm2d(64, affine=True)
22
+ self.deconv2 = UpsampleConvLayer(64, 32, kernel_size=3, stride=1, upsample=2)
23
+ self.in5 = torch.nn.InstanceNorm2d(32, affine=True)
24
+ self.deconv3 = ConvLayer(32, 3, kernel_size=9, stride=1)
25
+ # Non-linearities
26
+ self.relu = torch.nn.ReLU()
27
+
28
+ def forward(self, X):
29
+ y = self.relu(self.in1(self.conv1(X)))
30
+ y = self.relu(self.in2(self.conv2(y)))
31
+ y = self.relu(self.in3(self.conv3(y)))
32
+ y = self.res1(y)
33
+ y = self.res2(y)
34
+ y = self.res3(y)
35
+ y = self.res4(y)
36
+ y = self.res5(y)
37
+ y = self.relu(self.in4(self.deconv1(y)))
38
+ y = self.relu(self.in5(self.deconv2(y)))
39
+ y = self.deconv3(y)
40
+ return y
41
+
42
+
43
+ class ConvLayer(torch.nn.Module):
44
+ def __init__(self, in_channels, out_channels, kernel_size, stride):
45
+ super(ConvLayer, self).__init__()
46
+ reflection_padding = kernel_size // 2
47
+ self.reflection_pad = torch.nn.ReflectionPad2d(reflection_padding)
48
+ self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride)
49
+
50
+ def forward(self, x):
51
+ out = self.reflection_pad(x)
52
+ out = self.conv2d(out)
53
+ return out
54
+
55
+
56
+ class ResidualBlock(torch.nn.Module):
57
+ """ResidualBlock
58
+ introduced in: https://arxiv.org/abs/1512.03385
59
+ recommended architecture: http://torch.ch/blog/2016/02/04/resnets.html
60
+ """
61
+
62
+ def __init__(self, channels):
63
+ super(ResidualBlock, self).__init__()
64
+ self.conv1 = ConvLayer(channels, channels, kernel_size=3, stride=1)
65
+ self.in1 = torch.nn.InstanceNorm2d(channels, affine=True)
66
+ self.conv2 = ConvLayer(channels, channels, kernel_size=3, stride=1)
67
+ self.in2 = torch.nn.InstanceNorm2d(channels, affine=True)
68
+ self.relu = torch.nn.ReLU()
69
+
70
+ def forward(self, x):
71
+ residual = x
72
+ out = self.relu(self.in1(self.conv1(x)))
73
+ out = self.in2(self.conv2(out))
74
+ out = out + residual
75
+ return out
76
+
77
+
78
+ class UpsampleConvLayer(torch.nn.Module):
79
+ """UpsampleConvLayer
80
+ Upsamples the input and then does a convolution. This method gives better results
81
+ compared to ConvTranspose2d.
82
+ ref: http://distill.pub/2016/deconv-checkerboard/
83
+ """
84
+
85
+ def __init__(self, in_channels, out_channels, kernel_size, stride, upsample=None):
86
+ super(UpsampleConvLayer, self).__init__()
87
+ self.upsample = upsample
88
+ reflection_padding = kernel_size // 2
89
+ self.reflection_pad = torch.nn.ReflectionPad2d(reflection_padding)
90
+ self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride)
91
+
92
+ def forward(self, x):
93
+ x_in = x
94
+ if self.upsample:
95
+ x_in = torch.nn.functional.interpolate(x_in, mode='nearest', scale_factor=self.upsample)
96
+ out = self.reflection_pad(x_in)
97
+ out = self.conv2d(out)
98
+ return out
models/haring_dance.model ADDED
Binary file (6.73 MB). View file
 
models/kand_circles.model ADDED
Binary file (6.73 MB). View file
 
models/picasso_weeping.model ADDED
Binary file (6.73 MB). View file
 
models/vangogh_crows.model ADDED
Binary file (6.73 MB). View file
 
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ pillow > 8.0.0,<9
2
+ torchvision==0.10.0
3
+ torch===1.9.0