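"""Gradio demo for neural style transfer (Gatys et al., 2016).

A copy of the content image is optimized with Adam so that its VGG19
features match the content image at conv4_2 while the Gram matrices of its
conv1_1 through conv5_1 responses match those of the style image.
"""
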
import gradio as gr
import torch
import torch.optim as optim
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
import numpy as np
import requests
from io import BytesIO

def load_image(img_input, max_size=400, shape=None):
    if isinstance(img_input, np.ndarray):
        # Convert numpy array to PIL Image
        image = Image.fromarray(img_input.astype('uint8'), 'RGB')
    elif isinstance(img_input, str):
        if img_input.startswith(("http://", "https://")):
            response = requests.get(img_input)
            image = Image.open(BytesIO(response.content)).convert('RGB')
        else:
            image = Image.open(img_input).convert('RGB')
    else:
        raise ValueError("Unsupported input type. Expected numpy array or string.")
    
    # large images will slow down processing
    if max(image.size) > max_size:
        size = max_size
    else:
        size = max(image.size)
    
    if shape is not None:
        size = shape
        
    in_transform = transforms.Compose([
                        transforms.Resize(size),
                        transforms.ToTensor(),
                        transforms.Normalize((0.485, 0.456, 0.406), 
                                             (0.229, 0.224, 0.225))])
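    # (The mean/std used above are the standard ImageNet statistics that the
    # pretrained VGG19 expects its inputs to be normalized with.)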

    # discard the transparent, alpha channel (that's the :3) and add the batch dimension
    image = in_transform(image)[:3,:,:].unsqueeze(0)
    
    return image

# helper function for un-normalizing an image 
# and converting it from a Tensor image to a NumPy image for display
def im_convert(tensor):
    """ Display a tensor as an image. """
    
    image = tensor.to("cpu").clone().detach()
    image = image.numpy().squeeze()
    image = image.transpose(1,2,0)
    image = image * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406))
    image = image.clip(0, 1)

    return image

def get_features(image, model, layers=None):
    """ Run an image forward through a model and get the features for 
        a set of layers. Default layers are for VGGNet matching Gatys et al (2016)
    """
    
    # Map torchvision's VGG19 layer indices to the layer names used in the paper;
    # these layers give the content and style representations of an image.
    if layers is None:
        layers = {'0': 'conv1_1',
                  '5': 'conv2_1', 
                  '10': 'conv3_1', 
                  '19': 'conv4_1',
                  '21': 'conv4_2',  ## content representation
                  '28': 'conv5_1'}
        
        
    features = {}
    x = image
    # model._modules is a dictionary holding each module in the model
    for name, layer in model._modules.items():
        x = layer(x)
        if name in layers:
            features[layers[name]] = x
            
    return features


def gram_matrix(tensor):
    """ Calculate the Gram Matrix of a given tensor 
        Gram Matrix: https://en.wikipedia.org/wiki/Gramian_matrix
    """
    
    # get the batch_size, depth, height, and width of the Tensor
    b, d, h, w = tensor.size()
    
    # reshape so we're multiplying the features for each channel
    tensor = tensor.view(b * d, h * w)
    
    # calculate the gram matrix
    gram = torch.mm(tensor, tensor.t())
    
    return gram
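
# Shape sketch (assuming a single image in the batch): a conv activation of
# shape (1, d, h, w), e.g. torch.randn(1, 64, 100, 100), yields a (d, d) = (64, 64)
# Gram matrix of pairwise channel correlations.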

# Function to resize image while maintaining aspect ratio
def resize_image(image_path, max_size=400):
    img = Image.open(image_path).convert('RGB')
    ratio = max_size / max(img.size)
    new_size = tuple([int(x*ratio) for x in img.size])
    img = img.resize(new_size, Image.Resampling.LANCZOS)
    return np.array(img)

def create_grid(images, rows, cols):
    assert len(images) == rows * cols, "Number of images doesn't match the grid size"
    w, h = images[0].shape[1], images[0].shape[0]
    grid = np.zeros((h*rows, w*cols, 3), dtype=np.uint8)
    for i, img in enumerate(images):
        r, c = divmod(i, cols)
        # im_convert returns floats in [0, 1]; scale to uint8 before pasting,
        # otherwise the assignment truncates every pixel to 0 (a black grid)
        grid[r*h:(r+1)*h, c*w:(c+1)*w] = (np.clip(img, 0, 1) * 255).astype(np.uint8)
    return grid

def style_transfer(content_image, style_image, alpha, beta, conv1_1, conv2_1, conv3_1, conv4_1, conv5_1, steps):
    content = load_image(content_image).to(device)
    style = load_image(style_image, shape=content.shape[-2:]).to(device)
    
    content_features = get_features(content, vgg)
    style_features = get_features(style, vgg)
    style_grams = {layer: gram_matrix(style_features[layer]) for layer in style_features}
    
    target = content.clone().requires_grad_(True).to(device)
    
    style_weights = {
        'conv1_1': conv1_1,
        'conv2_1': conv2_1,
        'conv3_1': conv3_1,
        'conv4_1': conv4_1,
        'conv5_1': conv5_1
    }
    
    content_weight = alpha
    style_weight = beta * 1e6
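    # Note: the raw style loss ends up much smaller than the content loss, so the
    # style slider value is scaled by 1e6 above; the stylization strength is
    # effectively controlled by the alpha/beta ratio.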
    
    optimizer = optim.Adam([target], lr=0.003)
    
    intermediate_images = []
    show_every = max(1, steps // 9)  # aim for ~9 snapshots for the 3x3 grid

    for ii in range(1, steps+1):
        target_features = get_features(target, vgg)
        content_loss = torch.mean((target_features['conv4_2'] - content_features['conv4_2'])**2)
        
        style_loss = 0
        for layer in style_weights:
            target_feature = target_features[layer]
            target_gram = gram_matrix(target_feature)
            _, d, h, w = target_feature.shape
            style_gram = style_grams[layer]
            layer_style_loss = style_weights[layer] * torch.mean((target_gram - style_gram)**2)
            style_loss += layer_style_loss / (d * h * w)
        
        total_loss = content_weight * content_loss + style_weight * style_loss
        
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()
        
        if ii % show_every == 0:
            intermediate_images.append(im_convert(target))

    final_image = im_convert(target)
    # create_grid expects exactly 9 images: pad with the final result if there
    # are too few snapshots, keep the most recent 9 if there are too many
    while len(intermediate_images) < 9:
        intermediate_images.append(final_image)
    intermediate_grid = create_grid(intermediate_images[-9:], 3, 3)

    return final_image, intermediate_grid
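
# Example of calling style_transfer directly (outside the Gradio UI), assuming
# the bundled example assets and the default slider values:
#   content = resize_image("assets/content_1.jpg")
#   style = resize_image("assets/style_1.jpg")
#   result, grid = style_transfer(content, style, alpha=1.0, beta=0.1,
#                                 conv1_1=1.0, conv2_1=0.8, conv3_1=0.5,
#                                 conv4_1=0.3, conv5_1=0.1, steps=1000)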

def load_example(content, style, output):
    return content, style, output

# Example images

examples = [
    ["assets/content_1.jpg", "assets/style_1.jpg", "assets/result_1.png"],
    ["assets/content_2.jpg", "assets/style_2.jpg", "assets/result_2.png"],
    ["assets/content_3.png", "assets/style_3.jpg", "assets/result_3.png"],
]

# Load the VGG19 feature extractor and freeze its parameters
vgg = models.vgg19(pretrained=True).features
for param in vgg.parameters():
    param.requires_grad_(False)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vgg.to(device)


# Resize example images
resized_examples = [[resize_image(content), resize_image(style), resize_image(output)] for content, style, output in examples]

# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Neural Style Transfer")
    with gr.Row():
        with gr.Column():
            content_input = gr.Image(label="Content Image", type="numpy", image_mode="RGB", height=400, width=400)
            style_input = gr.Image(label="Style Image", type="numpy", image_mode="RGB", height=400, width=400)
        with gr.Column():
            output_image = gr.Image(label="Output Image")
            intermediate_output = gr.Image(label="Intermediate Results")
    
    with gr.Row():
        alpha_slider = gr.Slider(minimum=0, maximum=1, value=1, step=0.1, label="Content Weight (α)")
        beta_slider = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.1, label="Style Weight (β)")
    
    with gr.Row():
        conv1_1_slider = gr.Slider(minimum=0, maximum=1, value=1, step=0.1, label="Conv1_1 Weight")
        conv2_1_slider = gr.Slider(minimum=0, maximum=1, value=0.8, step=0.1, label="Conv2_1 Weight")
        conv3_1_slider = gr.Slider(minimum=0, maximum=1, value=0.5, step=0.1, label="Conv3_1 Weight")
        conv4_1_slider = gr.Slider(minimum=0, maximum=1, value=0.3, step=0.1, label="Conv4_1 Weight")
        conv5_1_slider = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.1, label="Conv5_1 Weight")
    
    steps_slider = gr.Slider(minimum=1, maximum=2000, value=1000, step=100, label="Number of Steps")
    
    run_button = gr.Button("Run Style Transfer")
    
    run_button.click(
        style_transfer,
        inputs=[
            content_input,
            style_input,
            alpha_slider,
            beta_slider,
            conv1_1_slider,
            conv2_1_slider,
            conv3_1_slider,
            conv4_1_slider,
            conv5_1_slider,
            steps_slider
        ],
        outputs=[output_image, intermediate_output]
    )
    
    gr.Examples(
        resized_examples,
        inputs=[content_input, style_input, output_image],
        outputs=[content_input, style_input, output_image],
        fn=load_example,
        cache_examples=True
    )

demo.launch()
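
# Running this file directly starts a local Gradio server; a temporary public
# URL can be requested with demo.launch(share=True) if needed.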