import gradio as gr
import torch
import torch.optim as optim
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
import numpy as np
import requests
from io import BytesIO

def load_image(img_input, max_size=400, shape=None):
    if isinstance(img_input, np.ndarray):
        # Convert numpy array (as delivered by gr.Image) to a PIL Image
        image = Image.fromarray(img_input.astype('uint8'), 'RGB')
    elif isinstance(img_input, str):
        if img_input.startswith(("http://", "https://")):
            response = requests.get(img_input)
            image = Image.open(BytesIO(response.content)).convert('RGB')
        else:
            image = Image.open(img_input).convert('RGB')
    else:
        raise ValueError("Unsupported input type. Expected numpy array or string.")

    # Large images will slow down processing, so cap the size
    if max(image.size) > max_size:
        size = max_size
    else:
        size = max(image.size)

    if shape is not None:
        size = shape

    in_transform = transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406),
                             (0.229, 0.224, 0.225))])

    # Discard the transparent alpha channel (that's the :3) and add the batch dimension
    image = in_transform(image)[:3, :, :].unsqueeze(0)

    return image

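# Note: load_image returns a normalized float tensor of shape (1, 3, H, W) using the
# ImageNet mean/std expected by torchvision's pretrained VGG19; im_convert below
# undoes that normalization for display.
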
# Helper function for un-normalizing an image
# and converting it from a Tensor image to a NumPy image for display
def im_convert(tensor):
    """ Display a tensor as an image. """
    image = tensor.to("cpu").clone().detach()
    image = image.numpy().squeeze()
    image = image.transpose(1, 2, 0)
    image = image * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406))
    image = image.clip(0, 1)
    return image

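# im_convert reverses Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) by
# multiplying by the std and adding the mean back, then clips to [0, 1] for display.
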
def get_features(image, model, layers=None):
    """ Run an image forward through a model and get the features for
        a set of layers. Default layers are for VGGNet matching Gatys et al (2016).
    """
    # Layers used for the content and style representations of an image
    if layers is None:
        layers = {'0': 'conv1_1',
                  '5': 'conv2_1',
                  '10': 'conv3_1',
                  '19': 'conv4_1',
                  '21': 'conv4_2',  # content representation
                  '28': 'conv5_1'}

    features = {}
    x = image
    # model._modules is a dictionary holding each module in the model
    for name, layer in model._modules.items():
        x = layer(x)
        if name in layers:
            features[layers[name]] = x

    return features

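# The numeric keys above ('0', '5', '10', ...) are the module indices inside
# vgg19().features; conv4_2 supplies the content representation and the conv*_1
# layers supply the style representation, following Gatys et al. (2016).
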
def gram_matrix(tensor):
    """ Calculate the Gram Matrix of a given tensor
        Gram Matrix: https://en.wikipedia.org/wiki/Gramian_matrix
    """
    # get the batch_size, depth, height, and width of the Tensor
    b, d, h, w = tensor.size()
    # reshape so we're multiplying the features for each channel
    tensor = tensor.view(b * d, h * w)
    # calculate the gram matrix
    gram = torch.mm(tensor, tensor.t())

    return gram

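# For a feature map F reshaped to (d, h*w), the Gram matrix is G = F @ F.T,
# i.e. the pairwise correlations between channel activations; the batch size is
# assumed to be 1 here, so flattening to (b*d, h*w) is equivalent.
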
# Function to resize an image while maintaining its aspect ratio
def resize_image(image_path, max_size=400):
    img = Image.open(image_path).convert('RGB')
    ratio = max_size / max(img.size)
    new_size = tuple(int(x * ratio) for x in img.size)
    img = img.resize(new_size, Image.Resampling.LANCZOS)
    return np.array(img)

def create_grid(images, rows, cols):
    assert len(images) == rows * cols, "Number of images doesn't match the grid size"
    h, w = images[0].shape[0], images[0].shape[1]
    grid = np.zeros((h * rows, w * cols, 3), dtype=np.uint8)
    for i, img in enumerate(images):
        # im_convert returns floats in [0, 1]; scale to uint8 before pasting into the grid
        if img.dtype != np.uint8:
            img = (np.clip(img, 0, 1) * 255).astype(np.uint8)
        r, c = divmod(i, cols)
        grid[r * h:(r + 1) * h, c * w:(c + 1) * w] = img
    return grid

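# create_grid assumes every image has the same (H, W, 3) shape; the snapshots
# produced by im_convert satisfy this since they all come from the same target tensor.
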
def style_transfer(content_image, style_image, alpha, beta, conv1_1, conv2_1, conv3_1, conv4_1, conv5_1, steps):
    content = load_image(content_image).to(device)
    style = load_image(style_image, shape=content.shape[-2:]).to(device)

    content_features = get_features(content, vgg)
    style_features = get_features(style, vgg)
    style_grams = {layer: gram_matrix(style_features[layer]) for layer in style_features}

    target = content.clone().requires_grad_(True).to(device)

    style_weights = {
        'conv1_1': conv1_1,
        'conv2_1': conv2_1,
        'conv3_1': conv3_1,
        'conv4_1': conv4_1,
        'conv5_1': conv5_1
    }
    content_weight = alpha
    style_weight = beta * 1e6

    optimizer = optim.Adam([target], lr=0.003)

    intermediate_images = []
    show_every = max(steps // 9, 1)  # aim for ~9 intermediate snapshots (3x3 grid)

    for ii in range(1, steps + 1):
        target_features = get_features(target, vgg)
        content_loss = torch.mean((target_features['conv4_2'] - content_features['conv4_2']) ** 2)

        style_loss = 0
        for layer in style_weights:
            target_feature = target_features[layer]
            target_gram = gram_matrix(target_feature)
            _, d, h, w = target_feature.shape
            style_gram = style_grams[layer]
            layer_style_loss = style_weights[layer] * torch.mean((target_gram - style_gram) ** 2)
            style_loss += layer_style_loss / (d * h * w)

        total_loss = content_weight * content_loss + style_weight * style_loss

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        if ii % show_every == 0 or ii == steps:
            intermediate_images.append(im_convert(target))

    final_image = intermediate_images[-1]

    # Keep exactly nine snapshots for the 3x3 grid; pad with blank frames if the run was short
    grid_images = intermediate_images[-9:]
    while len(grid_images) < 9:
        grid_images.append(np.zeros_like(grid_images[0]))
    intermediate_grid = create_grid(grid_images, 3, 3)

    return final_image, intermediate_grid

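# The objective optimized above is total_loss = alpha * L_content + (beta * 1e6) * L_style,
# where L_content is the MSE between the conv4_2 features of target and content, and
# L_style is the weighted MSE between Gram matrices, scaled by 1/(d*h*w) per layer.
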
def load_example(content, style, output):
    return content, style, output


# Example images
examples = [
    ["assets/content_1.jpg", "assets/style_1.jpg", "assets/result_1.png"],
    ["assets/content_2.jpg", "assets/style_2.jpg", "assets/result_2.png"],
    ["assets/content_3.png", "assets/style_3.jpg", "assets/result_3.png"],
]

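# These paths are expected to exist in the Space's repository (assets/ folder);
# they are resized below before being handed to gr.Examples.
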
# Load VGG19 model (feature extractor only) and freeze its weights
vgg = models.vgg19(pretrained=True).features
for param in vgg.parameters():
    param.requires_grad_(False)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vgg.to(device)

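# Note: recent torchvision releases prefer the explicit weights API, e.g.
# models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1); pretrained=True still
# loads the same ImageNet weights but may emit a deprecation warning.
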
# Resize example images
resized_examples = [[resize_image(content), resize_image(style), resize_image(output)] for content, style, output in examples]

# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Neural Style Transfer")

    with gr.Row():
        with gr.Column():
            content_input = gr.Image(label="Content Image", type="numpy", image_mode="RGB", height=400, width=400)
            style_input = gr.Image(label="Style Image", type="numpy", image_mode="RGB", height=400, width=400)
        with gr.Column():
            output_image = gr.Image(label="Output Image")
            intermediate_output = gr.Image(label="Intermediate Results")

    with gr.Row():
        alpha_slider = gr.Slider(minimum=0, maximum=1, value=1, step=0.1, label="Content Weight (α)")
        beta_slider = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.1, label="Style Weight (β)")

    with gr.Row():
        conv1_1_slider = gr.Slider(minimum=0, maximum=1, value=1, step=0.1, label="Conv1_1 Weight")
        conv2_1_slider = gr.Slider(minimum=0, maximum=1, value=0.8, step=0.1, label="Conv2_1 Weight")
        conv3_1_slider = gr.Slider(minimum=0, maximum=1, value=0.5, step=0.1, label="Conv3_1 Weight")
        conv4_1_slider = gr.Slider(minimum=0, maximum=1, value=0.3, step=0.1, label="Conv4_1 Weight")
        conv5_1_slider = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.1, label="Conv5_1 Weight")

    steps_slider = gr.Slider(minimum=1, maximum=2000, value=1000, step=100, label="Number of Steps")

    run_button = gr.Button("Run Style Transfer")

    run_button.click(
        style_transfer,
        inputs=[
            content_input,
            style_input,
            alpha_slider,
            beta_slider,
            conv1_1_slider,
            conv2_1_slider,
            conv3_1_slider,
            conv4_1_slider,
            conv5_1_slider,
            steps_slider
        ],
        outputs=[output_image, intermediate_output]
    )

    gr.Examples(
        resized_examples,
        inputs=[content_input, style_input, output_image],
        outputs=[content_input, style_input, output_image],
        fn=load_example,
        cache_examples=True
    )

demo.launch()