import gradio as gr
import torch
import torch.optim as optim
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
import numpy as np
import requests
from io import BytesIO
def load_image(img_input, max_size=400, shape=None):
if isinstance(img_input, np.ndarray):
# Convert numpy array to PIL Image
image = Image.fromarray(img_input.astype('uint8'), 'RGB')
elif isinstance(img_input, str):
if "http" in img_input:
response = requests.get(img_input)
image = Image.open(BytesIO(response.content)).convert('RGB')
else:
image = Image.open(img_input).convert('RGB')
else:
raise ValueError("Unsupported input type. Expected numpy array or string.")
# large images will slow down processing
if max(image.size) > max_size:
size = max_size
else:
size = max(image.size)
if shape is not None:
size = shape
in_transform = transforms.Compose([
transforms.Resize(size),
transforms.ToTensor(),
transforms.Normalize((0.485, 0.456, 0.406),
(0.229, 0.224, 0.225))])
    # discard any alpha channel (keep only the first 3 channels) and add the batch dimension
image = in_transform(image)[:3,:,:].unsqueeze(0)
return image
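# Rough sketch of what load_image produces (shapes assume a hypothetical 800x600 RGB input;
# transforms.Resize(400) scales the shorter side to 400 while keeping the aspect ratio):
#   >>> img = load_image("content.jpg")   # hypothetical local path
#   >>> img.shape                         # -> roughly torch.Size([1, 3, 400, 533])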
# helper function for un-normalizing an image
# and converting it from a Tensor image to a NumPy image for display
def im_convert(tensor):
""" Display a tensor as an image. """
image = tensor.to("cpu").clone().detach()
image = image.numpy().squeeze()
image = image.transpose(1,2,0)
image = image * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406))
image = image.clip(0, 1)
return image
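# Sketch: im_convert reverses the ImageNet normalization, so a (1, 3, H, W) tensor from
# load_image comes back as an (H, W, 3) float array in [0, 1], suitable for Gradio or matplotlib:
#   >>> im_convert(load_image("content.jpg")).shape   # hypothetical path -> (H, W, 3)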
def get_features(image, model, layers=None):
""" Run an image forward through a model and get the features for
        a set of layers. Default layers are for VGG19, matching Gatys et al. (2016).
"""
    ## Mapping of PyTorch VGG19 layer indices to the layer names used in the paper;
    ## these layers provide the content and style representations of an image
if layers is None:
layers = {'0': 'conv1_1',
'5': 'conv2_1',
'10': 'conv3_1',
'19': 'conv4_1',
'21': 'conv4_2', ## content representation
'28': 'conv5_1'}
    ## pass the image through the network, saving the activations of the requested layers
features = {}
x = image
# model._modules is a dictionary holding each module in the model
for name, layer in model._modules.items():
x = layer(x)
if name in layers:
features[layers[name]] = x
return features
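# Sketch of the returned dict (with the default layer mapping above):
#   >>> feats = get_features(load_image("content.jpg").to(device), vgg)   # hypothetical path
#   >>> sorted(feats.keys())
#   ['conv1_1', 'conv2_1', 'conv3_1', 'conv4_1', 'conv4_2', 'conv5_1']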
def gram_matrix(tensor):
""" Calculate the Gram Matrix of a given tensor
Gram Matrix: https://en.wikipedia.org/wiki/Gramian_matrix
"""
# get the batch_size, depth, height, and width of the Tensor
b, d, h, w = tensor.size()
# reshape so we're multiplying the features for each channel
tensor = tensor.view(b * d, h * w)
# calculate the gram matrix
gram = torch.mm(tensor, tensor.t())
return gram
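# Sketch: for a batch of one feature map, the Gram matrix captures channel-to-channel
# correlations, so its size depends only on the channel depth:
#   >>> feats = torch.randn(1, 64, 200, 200)   # made-up feature map, just for illustration
#   >>> gram_matrix(feats).shape               # -> torch.Size([64, 64])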
# Function to resize image while maintaining aspect ratio
def resize_image(image_path, max_size=400):
img = Image.open(image_path).convert('RGB')
ratio = max_size / max(img.size)
new_size = tuple([int(x*ratio) for x in img.size])
img = img.resize(new_size, Image.Resampling.LANCZOS)
return np.array(img)
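# Sketch: resize_image keeps the aspect ratio, so a hypothetical 1200x800 photo becomes a
# 400x266 (W x H) image, returned as a (266, 400, 3) uint8 array.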
def create_grid(images, rows, cols):
assert len(images) == rows * cols, "Number of images doesn't match the grid size"
w, h = images[0].shape[1], images[0].shape[0]
grid = np.zeros((h*rows, w*cols, 3), dtype=np.uint8)
for i, img in enumerate(images):
r, c = divmod(i, cols)
grid[r*h:(r+1)*h, c*w:(c+1)*w] = img
return grid
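# Sketch: nine (H, W, 3) uint8 images tile row by row into a single (3H, 3W, 3) array:
#   >>> tiles = [np.zeros((100, 100, 3), dtype=np.uint8)] * 9   # made-up blank tiles
#   >>> create_grid(tiles, 3, 3).shape                          # -> (300, 300, 3)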
def style_transfer(content_image, style_image, alpha, beta, conv1_1, conv2_1, conv3_1, conv4_1, conv5_1, steps):
content = load_image(content_image).to(device)
style = load_image(style_image, shape=content.shape[-2:]).to(device)
content_features = get_features(content, vgg)
style_features = get_features(style, vgg)
style_grams = {layer: gram_matrix(style_features[layer]) for layer in style_features}
target = content.clone().requires_grad_(True).to(device)
style_weights = {
'conv1_1': conv1_1,
'conv2_1': conv2_1,
'conv3_1': conv3_1,
'conv4_1': conv4_1,
'conv5_1': conv5_1
}
content_weight = alpha
style_weight = beta * 1e6
optimizer = optim.Adam([target], lr=0.003)
intermediate_images = []
    show_every = max(1, steps // 9)  # snapshot roughly 9 intermediate images (avoids division by zero for small step counts)
for ii in range(1, steps+1):
target_features = get_features(target, vgg)
content_loss = torch.mean((target_features['conv4_2'] - content_features['conv4_2'])**2)
style_loss = 0
for layer in style_weights:
target_feature = target_features[layer]
target_gram = gram_matrix(target_feature)
_, d, h, w = target_feature.shape
style_gram = style_grams[layer]
layer_style_loss = style_weights[layer] * torch.mean((target_gram - style_gram)**2)
style_loss += layer_style_loss / (d * h * w)
total_loss = content_weight * content_loss + style_weight * style_loss
optimizer.zero_grad()
total_loss.backward()
optimizer.step()
        if ii % show_every == 0 or ii == steps:
            intermediate_images.append(im_convert(target))
    final_image = intermediate_images[-1]
    # keep the last 9 snapshots (padded with copies of the final image if fewer) so they fill the 3x3 grid
    grid_images = intermediate_images[-9:]
    grid_images += [grid_images[-1]] * (9 - len(grid_images))
    intermediate_grid = create_grid(grid_images, 3, 3)
return final_image, intermediate_grid
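# Note: the objective optimized above follows Gatys et al. (2016):
#   total_loss = alpha * MSE(target_conv4_2, content_conv4_2)
#              + beta * 1e6 * sum_l w_l * MSE(G_l(target), G_l(style)) / (d_l * h_l * w_l)
# where G_l is the Gram matrix at layer l and w_l are the per-layer weights exposed as sliders in the UI.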
def load_example(content, style, output):
return content, style, output
# Example images
examples = [
["assets/content_1.jpg", "assets/style_1.jpg", "assets/result_1.png"],
["assets/content_2.jpg", "assets/style_2.jpg", "assets/result_2.png"],
["assets/content_3.png", "assets/style_3.jpg", "assets/result_3.png"],
]
# Load the pretrained VGG19 feature extractor (only the convolutional "features" block is used)
vgg = models.vgg19(pretrained=True).features
for param in vgg.parameters():
param.requires_grad_(False)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vgg.to(device)
# Resize example images
resized_examples = [[resize_image(content), resize_image(style), resize_image(output)] for content, style, output in examples]
# Gradio interface
with gr.Blocks() as demo:
gr.Markdown("# Neural Style Transfer")
with gr.Row():
with gr.Column():
content_input = gr.Image(label="Content Image", type="numpy", image_mode="RGB", height=400, width=400)
style_input = gr.Image(label="Style Image", type="numpy", image_mode="RGB", height=400, width=400)
with gr.Column():
output_image = gr.Image(label="Output Image")
intermediate_output = gr.Image(label="Intermediate Results")
with gr.Row():
alpha_slider = gr.Slider(minimum=0, maximum=1, value=1, step=0.1, label="Content Weight (α)")
beta_slider = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.1, label="Style Weight (β)")
with gr.Row():
conv1_1_slider = gr.Slider(minimum=0, maximum=1, value=1, step=0.1, label="Conv1_1 Weight")
conv2_1_slider = gr.Slider(minimum=0, maximum=1, value=0.8, step=0.1, label="Conv2_1 Weight")
conv3_1_slider = gr.Slider(minimum=0, maximum=1, value=0.5, step=0.1, label="Conv3_1 Weight")
conv4_1_slider = gr.Slider(minimum=0, maximum=1, value=0.3, step=0.1, label="Conv4_1 Weight")
conv5_1_slider = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.1, label="Conv5_1 Weight")
steps_slider = gr.Slider(minimum=1, maximum=2000, value=1000, step=100, label="Number of Steps")
run_button = gr.Button("Run Style Transfer")
run_button.click(
style_transfer,
inputs=[
content_input,
style_input,
alpha_slider,
beta_slider,
conv1_1_slider,
conv2_1_slider,
conv3_1_slider,
conv4_1_slider,
conv5_1_slider,
steps_slider
],
outputs=[output_image, intermediate_output]
)
gr.Examples(
resized_examples,
inputs=[content_input, style_input, output_image],
outputs=[content_input, style_input, output_image],
fn=load_example,
cache_examples=True
)
demo.launch()