File size: 7,619 Bytes
4135cc9
c1a650d
 
 
 
 
 
 
 
4135cc9
a667744
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ac18872
c1a650d
 
 
 
 
4135cc9
c1a650d
 
4135cc9
c1a650d
 
 
4135cc9
c1a650d
 
 
4135cc9
c1a650d
4135cc9
c1a650d
 
 
 
 
 
 
4135cc9
c1a650d
 
4135cc9
c1a650d
4135cc9
c1a650d
 
 
4135cc9
c1a650d
 
 
 
 
 
 
 
4135cc9
c1a650d
4135cc9
c1a650d
 
 
4135cc9
c1a650d
4135cc9
 
 
d2b1f3f
 
 
 
fb8a3f7
d2b1f3f
4135cc9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
import gradio as gr
import torch
import torch.optim as optim
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
import numpy as np
import requests
from io import BytesIO

def load_image(img_path, max_size=400, shape=None):
    """Load an image and convert it to a normalized, batched tensor.

    Args:
        img_path: Source of the image. Accepted forms:
            * an ``http://`` / ``https://`` URL (downloaded with requests);
            * a local file path;
            * an already-decoded ``PIL.Image.Image``;
            * a NumPy array (e.g. the uint8 H x W x C array a ``gr.Image``
              component passes to its callback when no ``type`` is set).
        max_size: Cap on the resize target used when ``shape`` is None.
            The target is ``min(max(image.size), max_size)`` and is handed
            to ``transforms.Resize`` as a single int.
            NOTE(review): with an int, torchvision's Resize scales the
            *smaller* edge to that value, so the longer edge may still
            exceed ``max_size`` -- confirm this is intended.
        shape: Optional explicit resize target (e.g. ``(height, width)``);
            when given it overrides ``max_size``.

    Returns:
        A tensor of shape (1, 3, H, W), normalized with the mean/std
        constants below.

    Raises:
        requests.RequestException: if a URL cannot be fetched (including
            non-2xx responses and timeouts).
    """
    if isinstance(img_path, Image.Image):
        image = img_path.convert('RGB')
    elif isinstance(img_path, np.ndarray):
        image = Image.fromarray(img_path).convert('RGB')
    elif str(img_path).startswith(("http://", "https://")):
        # Prefix test rather than substring test, so a local path that merely
        # contains "http" is not mistaken for a URL.
        response = requests.get(img_path, timeout=30)
        response.raise_for_status()
        image = Image.open(BytesIO(response.content)).convert('RGB')
    else:
        image = Image.open(img_path).convert('RGB')

    # Large images slow down processing, so cap the resize target unless the
    # caller asked for an explicit shape.
    if shape is not None:
        size = shape
    else:
        size = min(max(image.size), max_size)

    in_transform = transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406),
                             (0.229, 0.224, 0.225)),
    ])

    # Keep at most three channels (defensive: the image is already RGB) and
    # add the leading batch dimension.
    image = in_transform(image)[:3, :, :].unsqueeze(0)

    return image

# helper function for un-normalizing an image 
# and converting it from a Tensor image to a NumPy image for display
def im_convert(tensor):
    """Turn a normalized image tensor into a NumPy image array.

    The tensor is copied to the CPU, detached, squeezed, and transposed from
    channel-first to channel-last order. The per-channel normalization is
    then undone (multiply by std, add mean) and the values are clipped to
    the range [0, 1].
    """
    channel_std = np.array((0.229, 0.224, 0.225))
    channel_mean = np.array((0.485, 0.456, 0.406))

    # (1, C, H, W) -> (C, H, W) -> (H, W, C)
    pixels = tensor.to("cpu").clone().detach().numpy().squeeze().transpose(1, 2, 0)

    return (pixels * channel_std + channel_mean).clip(0, 1)

def get_features(image, model, layers=None):
    """Run an image forward through a model and collect selected activations.

    Args:
        image: Input passed to the first child module of ``model``.
        model: A module whose children (``model._modules``) are applied in
            order, each one receiving the previous one's output.
        layers: Mapping from child-module name (the string keys of
            ``model._modules``) to the label under which that module's
            output is stored. Defaults to the VGGNet layers matching
            Gatys et al (2016).

    Returns:
        Dict mapping each label to the output of the corresponding module.
        Names in ``layers`` that the model does not contain are omitted.
    """
    if layers is None:
        layers = {'0': 'conv1_1',
                  '5': 'conv2_1',
                  '10': 'conv3_1',
                  '19': 'conv4_1',
                  '21': 'conv4_2',  # content representation
                  '28': 'conv5_1'}

    features = {}
    # Module names still waiting to be captured; once this is empty the rest
    # of the network contributes nothing to the result, so we stop early.
    pending = set(layers)
    x = image
    # model._modules is a dictionary holding each module in the model
    for name, layer in model._modules.items():
        if not pending:
            break
        x = layer(x)
        if name in pending:
            features[layers[name]] = x
            pending.discard(name)

    return features


def gram_matrix(tensor):
    """Calculate the Gram matrix of a 4-D feature tensor.

    Gram Matrix: https://en.wikipedia.org/wiki/Gramian_matrix

    Args:
        tensor: Tensor whose size unpacks as
            (batch_size, depth, height, width).

    Returns:
        A (batch_size * depth, batch_size * depth) tensor holding the inner
        products between the flattened feature maps. With a batch size of 1
        this is the usual depth x depth Gram matrix.
    """
    # get the batch_size, depth, height, and width of the Tensor
    b, d, h, w = tensor.size()

    # Flatten each feature map into one row. reshape (unlike view) also
    # works for non-contiguous inputs such as channels-last or transposed
    # tensors, copying only when it has to.
    flat = tensor.reshape(b * d, h * w)

    # calculate the gram matrix
    return torch.mm(flat, flat.t())

# Load the pretrained VGG19 model; only its `features` sub-module is used.
# Pick the GPU when one is available, move the model there, and freeze every
# parameter so that optimisation never updates the network itself.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vgg = models.vgg19(pretrained=True).features.to(device)
for weight in vgg.parameters():
    weight.requires_grad_(False)

# The helper functions (load_image, im_convert, get_features, gram_matrix)
# are defined above; style_transfer below combines them with `vgg` and `device`.

def style_transfer(content_image, style_image, alpha, beta, conv1_1, conv2_1, conv3_1, conv4_1, conv5_1, steps):
    """Optimise a copy of the content image so that it takes on the style image's look.

    Args:
        content_image: Content image, passed to ``load_image``.
        style_image: Style image, passed to ``load_image`` and resized to the
            content tensor's last two dimensions.
        alpha: Weight of the content loss.
        beta: Weight of the style loss; multiplied by 1e6 before use.
        conv1_1, conv2_1, conv3_1, conv4_1, conv5_1: Per-layer weights applied
            to the style loss of the layer with the same name.
        steps: Number of Adam optimisation steps. Cast to ``int`` because a
            UI slider may deliver the value as a float.

    Returns:
        The optimised image as returned by ``im_convert``.
    """
    content = load_image(content_image).to(device)
    style = load_image(style_image, shape=content.shape[-2:]).to(device)

    style_weights = {
        'conv1_1': conv1_1,
        'conv2_1': conv2_1,
        'conv3_1': conv3_1,
        'conv4_1': conv4_1,
        'conv5_1': conv5_1,
    }

    content_weight = alpha
    style_weight = beta * 1e6

    content_features = get_features(content, vgg)
    style_features = get_features(style, vgg)
    # Gram matrices are only needed for the layers that carry a style weight.
    style_grams = {layer: gram_matrix(style_features[layer]) for layer in style_weights}

    # The optimisation starts from the content image. `content` already lives
    # on `device`, so the clone does too and stays a leaf tensor.
    target = content.clone().requires_grad_(True)

    optimizer = optim.Adam([target], lr=0.003)

    for _ in range(int(steps)):
        target_features = get_features(target, vgg)
        content_loss = torch.mean((target_features['conv4_2'] - content_features['conv4_2'])**2)

        style_loss = 0
        for layer, layer_weight in style_weights.items():
            target_feature = target_features[layer]
            target_gram = gram_matrix(target_feature)
            _, d, h, w = target_feature.shape
            layer_style_loss = layer_weight * torch.mean((target_gram - style_grams[layer])**2)
            # Normalise by the number of values in the layer's feature maps.
            style_loss += layer_style_loss / (d * h * w)

        total_loss = content_weight * content_loss + style_weight * style_loss

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

    return im_convert(target)

# Example images: each entry of `examples` is a [content URL, style URL] pair,
# built by pairing the two tuples below position by position.
_example_content_urls = (
    "https://huggingface.co/spaces/muneebable/vgg-style-transfer/resolve/main/assets/content_1.jpg",
    "https://huggingface.co/spaces/muneebable/vgg-style-transfer/resolve/main/assets/content_2.jpg",
    "https://huggingface.co/spaces/muneebable/vgg-style-transfer/resolve/main/assets/content_3.png",
)
_example_style_urls = (
    "https://huggingface.co/spaces/muneebable/vgg-style-transfer/resolve/main/assets/style_1.jpg",
    "https://huggingface.co/spaces/muneebable/vgg-style-transfer/resolve/main/assets/style_2.jpg",
    "https://huggingface.co/spaces/muneebable/vgg-style-transfer/resolve/main/assets/style_3.jpg",
)
examples = [
    [content_url, style_url]
    for content_url, style_url in zip(_example_content_urls, _example_style_urls)
]

# Gradio interface
def _weight_slider(label, value):
    """Create a slider ranging from 0 to 1 in steps of 0.1 with the given label and start value."""
    return gr.Slider(minimum=0, maximum=1, value=value, step=0.1, label=label)


with gr.Blocks() as demo:
    gr.Markdown("# Neural Style Transfer")

    # Left column: the two input images. Right column: the result.
    with gr.Row():
        with gr.Column():
            content_input = gr.Image(label="Content Image")
            style_input = gr.Image(label="Style Image")
        with gr.Column():
            output_image = gr.Image(label="Output Image")

    # Overall content / style trade-off.
    with gr.Row():
        alpha_slider = _weight_slider("Content Weight (α)", 1)
        beta_slider = _weight_slider("Style Weight (β)", 0.1)

    # One style weight per layer used in the style loss.
    with gr.Row():
        conv1_1_slider = _weight_slider("Conv1_1 Weight", 1)
        conv2_1_slider = _weight_slider("Conv2_1 Weight", 0.8)
        conv3_1_slider = _weight_slider("Conv3_1 Weight", 0.5)
        conv4_1_slider = _weight_slider("Conv4_1 Weight", 0.3)
        conv5_1_slider = _weight_slider("Conv5_1 Weight", 0.1)

    steps_slider = gr.Slider(minimum=100, maximum=2000, value=1000, step=100, label="Number of Steps")

    run_button = gr.Button("Run Style Transfer")

    # The order of this list must match the parameter order of style_transfer.
    transfer_inputs = [
        content_input,
        style_input,
        alpha_slider,
        beta_slider,
        conv1_1_slider,
        conv2_1_slider,
        conv3_1_slider,
        conv4_1_slider,
        conv5_1_slider,
        steps_slider,
    ]
    run_button.click(fn=style_transfer, inputs=transfer_inputs, outputs=output_image)

    # Clickable example pairs that fill in the two image inputs.
    gr.Examples(examples=examples, inputs=[content_input, style_input])

demo.launch()