import gradio as gr
import torch
import torch.optim as optim
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
import numpy as np
import requests
from io import BytesIO


def load_image(img_input, max_size=400, shape=None):
    """ Load an image from a numpy array, file path, or URL and return a
        normalized (1, 3, H, W) tensor. """
    if isinstance(img_input, np.ndarray):
        image = Image.fromarray(img_input.astype('uint8'), 'RGB')
    elif isinstance(img_input, str):
        if img_input.startswith(("http://", "https://")):
            response = requests.get(img_input)
            image = Image.open(BytesIO(response.content)).convert('RGB')
        else:
            image = Image.open(img_input).convert('RGB')
    else:
        raise ValueError("Unsupported input type. Expected numpy array or string.")

    # Cap the longer side at max_size to keep optimization fast
    if max(image.size) > max_size:
        size = max_size
    else:
        size = max(image.size)

    # An explicit shape overrides max_size (used to match the content image)
    if shape is not None:
        size = shape

    in_transform = transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(),
        # Standard ImageNet statistics expected by torchvision's pretrained VGG
        transforms.Normalize((0.485, 0.456, 0.406),
                             (0.229, 0.224, 0.225))])

    # Drop any alpha channel and add a batch dimension
    image = in_transform(image)[:3, :, :].unsqueeze(0)

    return image


def im_convert(tensor):
    """ Convert a normalized image tensor to a displayable numpy array. """
    image = tensor.to("cpu").clone().detach()
    image = image.numpy().squeeze()
    image = image.transpose(1, 2, 0)
    # Undo the ImageNet normalization applied in load_image
    image = image * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406))
    image = image.clip(0, 1)

    return image


def get_features(image, model, layers=None):
    """ Run an image forward through a model and get the features for
        a set of layers. Default layers are for VGGNet, matching Gatys et al. (2016).
    """
    if layers is None:
        layers = {'0': 'conv1_1',
                  '5': 'conv2_1',
                  '10': 'conv3_1',
                  '19': 'conv4_1',
                  '21': 'conv4_2',
                  '28': 'conv5_1'}

    features = {}
    x = image
    for name, layer in model._modules.items():
        x = layer(x)
        if name in layers:
            features[layers[name]] = x

    return features


def gram_matrix(tensor):
    """ Calculate the Gram Matrix of a given tensor
        Gram Matrix: https://en.wikipedia.org/wiki/Gramian_matrix
    """
    # batch size, depth (channels), height, and width of the feature maps
    b, d, h, w = tensor.size()

    # Flatten the spatial dimensions so each row is one feature map
    tensor = tensor.view(b * d, h * w)

    # Gram matrix: dot products between every pair of flattened feature maps
    gram = torch.mm(tensor, tensor.t())

    return gram


def resize_image(image_path, max_size=400):
    img = Image.open(image_path).convert('RGB')
    ratio = max_size / max(img.size)
    new_size = tuple([int(x * ratio) for x in img.size])
    img = img.resize(new_size, Image.Resampling.LANCZOS)
    return np.array(img)


def create_grid(images, rows, cols):
    assert len(images) == rows * cols, "Number of images doesn't match the grid size"
    w, h = images[0].shape[1], images[0].shape[0]
    grid = np.zeros((h * rows, w * cols, 3), dtype=np.uint8)
    for i, img in enumerate(images):
        # im_convert returns floats in [0, 1]; convert to uint8 before tiling,
        # otherwise the cast into the uint8 grid truncates everything to 0 or 1
        if img.dtype != np.uint8:
            img = (img * 255).astype(np.uint8)
        r, c = divmod(i, cols)
        grid[r * h:(r + 1) * h, c * w:(c + 1) * w] = img
    return grid

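
# Reference note (summarizing what style_transfer below computes, following
# Gatys et al., 2016):
#   content_loss = mean((F_target[conv4_2] - F_content[conv4_2])**2)
#   style_loss   = sum over layers l of  w_l * mean((G_target_l - G_style_l)**2) / (d_l * h_l * w_l)
#   total_loss   = alpha * content_loss + (beta * 1e6) * style_loss
# where G denotes the Gram matrix computed by gram_matrix above.
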

def style_transfer(content_image, style_image, alpha, beta, conv1_1, conv2_1, conv3_1, conv4_1, conv5_1, steps):
    content = load_image(content_image).to(device)
    # Resize the style image to the content image's spatial dimensions
    style = load_image(style_image, shape=content.shape[-2:]).to(device)

    content_features = get_features(content, vgg)
    style_features = get_features(style, vgg)
    style_grams = {layer: gram_matrix(style_features[layer]) for layer in style_features}

    # The target starts as a copy of the content image and is optimized directly
    target = content.clone().requires_grad_(True).to(device)

    style_weights = {
        'conv1_1': conv1_1,
        'conv2_1': conv2_1,
        'conv3_1': conv3_1,
        'conv4_1': conv4_1,
        'conv5_1': conv5_1
    }

    content_weight = alpha
    style_weight = beta * 1e6

    optimizer = optim.Adam([target], lr=0.003)

    # Capture at most nine snapshots so they tile into the 3x3 grid below
    intermediate_images = []
    show_every = max(1, steps // 9)

    for ii in range(1, steps + 1):
        target_features = get_features(target, vgg)
        content_loss = torch.mean((target_features['conv4_2'] - content_features['conv4_2'])**2)

        style_loss = 0
        for layer in style_weights:
            target_feature = target_features[layer]
            target_gram = gram_matrix(target_feature)
            _, d, h, w = target_feature.shape
            style_gram = style_grams[layer]
            layer_style_loss = style_weights[layer] * torch.mean((target_gram - style_gram)**2)
            style_loss += layer_style_loss / (d * h * w)

        total_loss = content_weight * content_loss + style_weight * style_loss

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        if ii % show_every == 0 and len(intermediate_images) < 9:
            intermediate_images.append(im_convert(target))

    final_image = im_convert(target)
    # Pad with the final image if fewer than nine snapshots were collected
    while len(intermediate_images) < 9:
        intermediate_images.append(final_image)
    intermediate_grid = create_grid(intermediate_images, 3, 3)

    return final_image, intermediate_grid


examples = [
    ["assets/content_1.jpg", "assets/style_1.jpg"],
    ["assets/content_2.jpg", "assets/style_2.jpg"],
    ["assets/content_3.png", "assets/style_3.jpg"],
]


# Load VGG-19 and keep only its convolutional feature extractor
vgg = models.vgg19(pretrained=True).features
# Freeze all weights; only the target image is optimized during style transfer
for param in vgg.parameters():
    param.requires_grad_(False)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vgg.to(device)
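# Note (depends on your environment): torchvision >= 0.13 deprecates the
# `pretrained` flag; on a recent torchvision the equivalent call would be
#   vgg = models.vgg19(weights=models.VGG19_Weights.DEFAULT).features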

resized_examples = [[resize_image(content), resize_image(style)] for content, style in examples]


with gr.Blocks() as demo:
    gr.Markdown("# Neural Style Transfer")
    with gr.Row():
        with gr.Column():
            content_input = gr.Image(label="Content Image", type="numpy", image_mode="RGB", height=400, width=400)
            style_input = gr.Image(label="Style Image", type="numpy", image_mode="RGB", height=400, width=400)
        with gr.Column():
            output_image = gr.Image(label="Output Image")

    with gr.Row():
        alpha_slider = gr.Slider(minimum=0, maximum=1, value=1, step=0.1, label="Content Weight (α)")
        beta_slider = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.1, label="Style Weight (β)")

    with gr.Row():
        conv1_1_slider = gr.Slider(minimum=0, maximum=1, value=1, step=0.1, label="Conv1_1 Weight")
        conv2_1_slider = gr.Slider(minimum=0, maximum=1, value=0.8, step=0.1, label="Conv2_1 Weight")
        conv3_1_slider = gr.Slider(minimum=0, maximum=1, value=0.5, step=0.1, label="Conv3_1 Weight")
        conv4_1_slider = gr.Slider(minimum=0, maximum=1, value=0.3, step=0.1, label="Conv4_1 Weight")
        conv5_1_slider = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.1, label="Conv5_1 Weight")

    steps_slider = gr.Slider(minimum=1, maximum=2000, value=1000, step=100, label="Number of Steps")

    run_button = gr.Button("Run Style Transfer")
    intermediate_output = gr.Image(label="Intermediate Results")

    run_button.click(
        style_transfer,
        inputs=[
            content_input,
            style_input,
            alpha_slider,
            beta_slider,
            conv1_1_slider,
            conv2_1_slider,
            conv3_1_slider,
            conv4_1_slider,
            conv5_1_slider,
            steps_slider
        ],
        outputs=[output_image, intermediate_output]
    )

    gr.Examples(
        resized_examples,
        inputs=[content_input, style_input]
    )

demo.launch()
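
# Usage sketch (assumes this script is saved as app.py and the assets/ example
# images exist locally):
#   python app.py
# Gradio then serves the demo locally (http://127.0.0.1:7860 by default); pass
# share=True to demo.launch() if you need a temporary public link.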