File size: 3,826 Bytes
b947961
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5f0eb84
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.models as models
from torchvision.utils import save_image
import numpy as np
import gradio as gr

# Side length, in pixels, that every input image is resized to before it is fed to VGG-19.
# NOTE(review): torchvision documents 224 as the standard VGG input size, so 244 may be a
# typo. The value is kept as-is because changing it alters the output resolution — confirm.
IMAGE_SIZE = 244

# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Build VGG-19 without downloading weights, then fill it from the local checkpoint file.
cnn = models.vgg19(weights=None)
# map_location="cpu" makes the load independent of the device the checkpoint was saved
# from, so a file containing CUDA tensors still loads on a CPU-only host. The network is
# moved to `device` later, when it is wrapped for feature extraction.
# NOTE: torch.load unpickles the file, so the checkpoint must come from a trusted source.
state_dict = torch.load("vgg19-dcbb9e9d.pth", map_location="cpu")
cnn.load_state_dict(state_dict)

class VGG(nn.Module):
  """Truncated VGG-19 that exposes selected intermediate activations.

  The module wraps the first 29 entries of the module-level ``cnn.features``.
  A forward pass does not return the final output; it returns the tensors
  produced at the positions listed in ``self.layers``.
  """

  def __init__(self):
    super().__init__()
    # Positions (as strings) of the modules whose outputs are used as representations.
    self.layers = ['0', '5', '10', '19', '28']
    # Everything after position 28 is irrelevant here, so the stack is cut off.
    self.model = cnn.features[:29]

  def forward(self, x):
    """Feed ``x`` through the truncated stack and return the tapped outputs as a list."""
    tapped = []
    wanted = self.layers

    for position, module in enumerate(self.model):
      x = module(x)

      # The final output is not the goal; only the outputs at the wanted positions are kept.
      if str(position) in wanted:
        tapped.append(x)

    return tapped
  

# Preprocessing applied to each image received from the UI: convert it to a
# tensor first, then resize it to IMAGE_SIZE x IMAGE_SIZE.
gradio_transforms = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Resize([IMAGE_SIZE, IMAGE_SIZE]),
    ]
)


def sanitize_inputs(epochs, lr, cl, sl):
  """Validate the numeric hyperparameters before training starts.

  Args:
    epochs: Number of optimisation steps; must be an ``int`` that is at least 1.
    lr: Learning rate; must satisfy ``0 < lr <= 1``.
    cl: Content loss weight; must lie in ``[0, 1]``.
    sl: Style loss weight; must lie in ``[0, 1]``.

  Returns:
    ``None`` when every value is acceptable. Otherwise a two-item list
    ``[message, None]``: the message describes the first problem found and
    the ``None`` fills the slot where the generated image would go.
  """
  # A missing value would make the comparisons below raise TypeError, so it is
  # reported as a normal validation error instead (presumably an empty number
  # field in the UI arrives as None — verify against the Gradio version in use).
  labelled = (
      ("Epochs", epochs),
      ("Learning rate", lr),
      ("Content loss weight", cl),
      ("Style loss weight", sl),
  )
  for label, value in labelled:
    if value is None:
      return [f"{label} must be provided", None]

  if epochs < 1:
    return ["Epochs must be positive", None]
  if not isinstance(epochs, int):
    return ["Epochs must be an integer", None]
  # "not lr > 0" rejects zero (a zero learning rate never changes the image,
  # and zero is not "positive") and also NaN, which fails every comparison.
  if not lr > 0:
    return ["Learning rate must be positive", None]
  if lr > 1:
    return ["Learning rate must be less than one", None]
  # The chained comparisons are False for NaN, so NaN weights are rejected too.
  if not 0 <= cl <= 1:
    return ["Content loss weight must be between 0 and 1", None]
  if not 0 <= sl <= 1:
    return ["Style loss weight must be between 0 and 1", None]

  return None

def _gram_matrix(feature_map):
  """Return the channel-by-channel Gram matrix of a single-image feature map."""
  _, channels, height, width = feature_map.shape
  flattened = feature_map.view(channels, height * width)
  return flattened.mm(flattened.t())


def train(Epochs, Learning_Rate, Content_Loss, Style_Loss, Content_Image, Style_Image):
  """Run neural style transfer and save the resulting image to disk.

  Starting from a copy of the content image, the pixels are optimised with
  Adam so that the VGG activations stay close to the content image while the
  Gram matrices of those activations move towards those of the style image.

  Args:
    Epochs: Number of optimisation steps.
    Learning_Rate: Learning rate handed to Adam.
    Content_Loss: Weight applied to the content loss term.
    Style_Loss: Weight applied to the style loss term.
    Content_Image: Image supplying the content, in a form accepted by
      ``gradio_transforms``.
    Style_Image: Image supplying the style, in a form accepted by
      ``gradio_transforms``.

  Returns:
    A two-item list ``[message, image_path]``. On a validation failure the
    message describes the problem and the path is ``None``. On success the
    path is ``"generated_gradio.png"``; the file name is fixed, so each
    successful call overwrites the previous result.
  """
  errors = sanitize_inputs(Epochs, Learning_Rate, Content_Loss, Style_Loss)
  if errors is not None:
    return errors

  # Without both images the transform below would fail with an opaque error.
  if Content_Image is None or Style_Image is None:
    return ["Both a content image and a style image are required", None]

  content = gradio_transforms(Content_Image).unsqueeze(0).to(device)
  style = gradio_transforms(Style_Image).unsqueeze(0).to(device)
  # `content` already lives on `device`, so its clone does too and stays a
  # leaf tensor that the optimizer can update.
  generated = content.clone().requires_grad_(True)

  model = VGG().to(device).eval()
  # Only the image is optimised. Freezing the network weights stops backward()
  # from computing and accumulating weight gradients that are never used.
  model.requires_grad_(False)
  optimizer = optim.Adam([generated], lr=Learning_Rate)

  # The targets depend only on the fixed content and style images, so they are
  # computed once, without autograd tracking, instead of on every step.
  with torch.no_grad():
    content_features = model(content)
    style_grams = [_gram_matrix(feat) for feat in model(style)]

  for _ in range(Epochs):
    generated_features = model(generated)

    style_loss = 0
    content_loss = 0

    for gen_feat, cont_feat, style_gram in zip(generated_features, content_features, style_grams):
      content_loss += torch.mean((gen_feat - cont_feat) ** 2)
      style_loss += torch.mean((_gram_matrix(gen_feat) - style_gram) ** 2)

    total_loss = Content_Loss * content_loss + Style_Loss * style_loss
    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()

  save_image(generated, "generated_gradio.png")

  return ["No errors! Enjoy your new image!", "generated_gradio.png"]

# Help text displayed in the web UI.
_DESCRIPTION = (
    "Perform neural style transfer on images of your choice!  "
    "Provide a content image that contains the content you want to transform "
    "and a style image that contains the style you want to emulate.\n\n"
    "Note: Huggingface requires users to pay to gain access to GPUs, "
    "so this model is hosted on a cpu. "
    "Training for many epochs will take a VERY long time.  "
    "Using a larger learning rate (e.g., 0.01) can help reduce the number of "
    "epochs you need."
)

# Expose `train` through a Gradio form: four numeric fields followed by two
# images as inputs, and a status label plus the generated image as outputs.
demo = gr.Interface(
    fn=train,
    inputs=["number"] * 4 + ["image"] * 2,
    outputs=[gr.Label(label="Error Messages"), gr.Image(label="Generated Image")],
    title="Neural Style Transfer",
    description=_DESCRIPTION,
    theme=gr.themes.Soft(),
)

# Launch the app with debug enabled.
demo.launch(debug=True)