muneebable's picture
Update app.py
5aacb28 verified
raw
history blame
8.89 kB
import gradio as gr
import torch
import torch.optim as optim
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
import numpy as np
import requests
from io import BytesIO
def load_image(img_input, max_size=400, shape=None):
    """Load an image and turn it into a normalized VGG input tensor.

    Args:
        img_input: An image as a NumPy array (e.g. from a Gradio
            ``gr.Image(type="numpy")`` component), an ``http://``/``https://``
            URL, or a local file path.
        max_size: Upper bound, in pixels, for the *longer* side of the image.
            Larger images are scaled down with their aspect ratio preserved;
            smaller images keep their native size.
        shape: Optional explicit ``(height, width)`` to resize to, e.g. the
            spatial shape of the content tensor so that content and style
            images line up. When given, it overrides ``max_size``.

    Returns:
        A float tensor of shape ``(1, 3, H, W)`` normalized with the
        ImageNet mean/std used below.

    Raises:
        ValueError: If ``img_input`` is neither a NumPy array nor a string.
        requests.HTTPError: If downloading a URL returns an error status.
    """
    if isinstance(img_input, np.ndarray):
        # Letting PIL infer the mode and then converting also handles
        # grayscale (H, W) and RGBA (H, W, 4) arrays, not only RGB ones.
        image = Image.fromarray(img_input.astype('uint8')).convert('RGB')
    elif isinstance(img_input, str):
        if img_input.startswith(("http://", "https://")):
            response = requests.get(img_input, timeout=30)
            response.raise_for_status()
            image = Image.open(BytesIO(response.content)).convert('RGB')
        else:
            image = Image.open(img_input).convert('RGB')
    else:
        raise ValueError("Unsupported input type. Expected numpy array or string.")

    if shape is not None:
        size = tuple(shape)
    else:
        # Large images slow down optimization, so cap the longer side at
        # max_size. An explicit (h, w) is needed: an int passed to
        # transforms.Resize sets the *shorter* side, which would upscale small
        # images and leave the longer side of wide images above max_size.
        width, height = image.size
        scale = min(1.0, max_size / max(width, height))
        size = (max(1, round(height * scale)), max(1, round(width * scale)))

    in_transform = transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406),
                             (0.229, 0.224, 0.225))])

    # Keep only the first three channels (defensive; the image is already
    # RGB) and add the batch dimension.
    image = in_transform(image)[:3, :, :].unsqueeze(0)
    return image
# Helper for turning a network tensor back into a displayable NumPy image.
def im_convert(tensor):
    """Undo the ImageNet normalization and return an (H, W, 3) float image.

    The tensor is detached and copied to the CPU, so it may live on any
    device and may require grad. Singleton axes are squeezed away, so a
    ``(1, 3, H, W)`` batch of one becomes ``(3, H, W)`` before the channel
    axis is moved last. The result is clipped to the displayable range
    [0, 1].
    """
    channel_std = np.array((0.229, 0.224, 0.225))
    channel_mean = np.array((0.485, 0.456, 0.406))
    chw = tensor.detach().cpu().clone().numpy().squeeze()
    hwc = np.transpose(chw, (1, 2, 0))
    return np.clip(hwc * channel_std + channel_mean, 0, 1)
def get_features(image, model, layers=None):
    """Run ``image`` through ``model`` child by child and collect activations.

    Args:
        image: Input tensor fed to the first child module.
        model: A module whose children are applied in registration order
            (here the ``features`` part of VGG19). Children are looked up by
            their registration names (``'0'``, ``'1'``, ...).
        layers: Mapping from child name to the label under which that child's
            output is stored. The default selects the VGG19 layers from
            Gatys et al. (2016): conv1_1 ... conv5_1 for style and conv4_2 for
            content.

    Returns:
        dict mapping each label whose child exists to that child's output.
    """
    if layers is None:
        layers = {
            '0': 'conv1_1',
            '5': 'conv2_1',
            '10': 'conv3_1',
            '19': 'conv4_1',
            '21': 'conv4_2',  # content representation
            '28': 'conv5_1',
        }

    # NOTE(review): each stored tensor is the same object that is handed to
    # the next child. If that child modifies its input in place (torchvision's
    # VGG uses ReLU(inplace=True) -- confirm), the stored value ends up as the
    # post-activation output. Every child is run, even past the last requested
    # layer, so stopping early would change what gets stored.
    collected = {}
    activation = image
    for child_name, child in model._modules.items():
        activation = child(activation)
        if child_name in layers:
            collected[layers[child_name]] = activation
    return collected
def gram_matrix(tensor):
    """Return the Gram matrix of a 4-D feature-map tensor.

    The ``(b, d, h, w)`` input is flattened into a ``(b*d, h*w)`` matrix F,
    and the result is ``F @ F^T``: the inner product between every pair of
    channel activation maps. With a batch size of 1 this is a ``(d, d)``
    matrix.

    Gram matrix: https://en.wikipedia.org/wiki/Gramian_matrix
    """
    batch, channels, height, width = tensor.size()
    flat = tensor.view(batch * channels, height * width)
    return flat @ flat.t()
# Function to resize image while maintaining aspect ratio
def resize_image(image_path, max_size=400):
    """Load an image file and scale it so its longer side equals ``max_size``.

    The aspect ratio is preserved. Images smaller than ``max_size`` are
    scaled *up*, not only down.

    Args:
        image_path: Path of the image file to load.
        max_size: Target length, in pixels, of the longer side.

    Returns:
        The resized image as a NumPy array of its RGB pixels.
    """
    # Close the underlying file deterministically instead of relying on PIL
    # to release it after loading.
    with Image.open(image_path) as src:
        img = src.convert('RGB')
    ratio = max_size / max(img.size)
    # Guard against a 0-pixel side for extreme aspect ratios.
    new_size = tuple(max(1, int(dim * ratio)) for dim in img.size)
    img = img.resize(new_size, Image.Resampling.LANCZOS)
    return np.array(img)
def create_grid(images, rows, cols):
    """Tile equally sized images into a single ``rows`` x ``cols`` mosaic.

    Images are placed row-major: left to right, then top to bottom.

    Args:
        images: Sequence of exactly ``rows * cols`` arrays, each with the
            shape of ``images[0]``, either (H, W) or (H, W, C).
        rows: Number of grid rows.
        cols: Number of grid columns.

    Returns:
        An array of shape ``(H*rows, W*cols, ...)`` with the dtype of
        ``images[0]``.

    Raises:
        ValueError: If ``len(images) != rows * cols`` or ``images`` is empty.
    """
    if len(images) != rows * cols:
        raise ValueError("Number of images doesn't match the grid size")
    if len(images) == 0:
        raise ValueError("create_grid needs at least one image")

    first = np.asarray(images[0])
    h, w = first.shape[:2]
    # Keep the input dtype. A fixed uint8 buffer would truncate float images
    # in [0, 1] (as produced by im_convert) to all zeros.
    grid = np.zeros((h * rows, w * cols) + first.shape[2:], dtype=first.dtype)
    for i, img in enumerate(images):
        r, c = divmod(i, cols)
        grid[r * h:(r + 1) * h, c * w:(c + 1) * w] = img
    return grid
def style_transfer(content_image, style_image, alpha, beta, conv1_1, conv2_1, conv3_1, conv4_1, conv5_1, steps):
    """Run Gatys-style neural style transfer and return the result.

    A copy of the content image is optimized with Adam. The loss pulls its
    ``conv4_2`` features toward the content image and the Gram matrices of
    its ``conv1_1`` ... ``conv5_1`` features toward the style image.

    Args:
        content_image: Content image (anything ``load_image`` accepts).
        style_image: Style image; it is resized to the content image's shape.
        alpha: Weight of the content loss.
        beta: Weight of the style loss (multiplied by 1e6 internally).
        conv1_1, conv2_1, conv3_1, conv4_1, conv5_1: Per-layer style weights.
        steps: Number of optimization steps. It is cast to int, and values
            below 1 are treated as 1.

    Returns:
        ``(final_image, intermediate_grid)``:
        - ``final_image``: the stylized image as an (H, W, 3) uint8 array.
        - ``intermediate_grid``: a 3x3 mosaic of 9 snapshots taken at evenly
          spaced steps. The last snapshot is the final image.
    """
    # Sliders may deliver floats, range() needs an int, and at least one step
    # is required to produce an output image.
    steps = max(1, int(steps))

    content = load_image(content_image).to(device)
    # Resize the style image to the content image's (H, W).
    style = load_image(style_image, shape=content.shape[-2:]).to(device)

    content_features = get_features(content, vgg)
    style_features = get_features(style, vgg)
    style_grams = {layer: gram_matrix(style_features[layer]) for layer in style_features}

    # Optimize a copy of the content image. `content` is already on `device`,
    # so the clone is a leaf tensor that Adam can update directly.
    target = content.clone().requires_grad_(True)

    style_weights = {
        'conv1_1': conv1_1,
        'conv2_1': conv2_1,
        'conv3_1': conv3_1,
        'conv4_1': conv4_1,
        'conv5_1': conv5_1,
    }
    content_weight = alpha
    # Presumably scaled up so the (heavily normalized) style loss is
    # comparable to the content loss.
    style_weight = beta * 1e6

    optimizer = optim.Adam([target], lr=0.003)

    # Exactly n_snapshots snapshots, one per cell of the 3x3 grid, at evenly
    # spaced steps; the last checkpoint is always `steps`. A modulo schedule
    # (ii % (steps // 9)) divides by zero for steps < 9 and usually yields 10
    # snapshots, which create_grid(..., 3, 3) rejects. For steps < 9 some
    # checkpoints coincide and that snapshot is repeated.
    n_snapshots = 9
    checkpoints = [max(1, k * steps // n_snapshots) for k in range(1, n_snapshots + 1)]
    intermediate_images = []

    for ii in range(1, steps + 1):
        target_features = get_features(target, vgg)
        content_loss = torch.mean((target_features['conv4_2'] - content_features['conv4_2']) ** 2)

        style_loss = 0
        for layer, weight in style_weights.items():
            target_feature = target_features[layer]
            target_gram = gram_matrix(target_feature)
            _, d, h, w = target_feature.shape
            layer_style_loss = weight * torch.mean((target_gram - style_grams[layer]) ** 2)
            style_loss += layer_style_loss / (d * h * w)

        total_loss = content_weight * content_loss + style_weight * style_loss
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        hits = checkpoints.count(ii)
        if hits:
            # uint8 so the snapshot displays correctly and tiles into a grid
            # without float-to-integer truncation.
            snapshot = (im_convert(target) * 255).round().astype(np.uint8)
            intermediate_images.extend([snapshot] * hits)

    final_image = intermediate_images[-1]
    intermediate_grid = create_grid(intermediate_images, 3, 3)
    return final_image, intermediate_grid
def load_example(content, style, output):
    """Echo an example row back unchanged; used as the gr.Examples callback.

    Returns:
        tuple: ``(content, style, output)`` in the same order.
    """
    row = (content, style, output)
    return row
# Example rows for gr.Examples: (content image, style image, precomputed result).
examples = [
    ["assets/content_1.jpg", "assets/style_1.jpg", "assets/result_1.png"],
    ["assets/content_2.jpg", "assets/style_2.jpg", "assets/result_2.png"],
    ["assets/content_3.png", "assets/style_3.jpg", "assets/result_3.png"],
]

# Load the convolutional part of VGG19 with ImageNet weights. The
# `weights=` API replaces the deprecated `pretrained=True`, which maps to
# these same IMAGENET1K_V1 weights. The network is used as a frozen feature
# extractor, so it is put in eval mode and its parameters receive no grads.
vgg = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1).features.eval()
for param in vgg.parameters():
    param.requires_grad_(False)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vgg.to(device)

# Scale every example image so its longer side is 400 px (resize_image default).
resized_examples = [[resize_image(content), resize_image(style), resize_image(output)] for content, style, output in examples]
# Gradio interface: inputs on the left, results on the right, loss-weight
# controls below, then the run button and cached examples.
with gr.Blocks() as demo:
    gr.Markdown("# Neural Style Transfer")

    with gr.Row():
        with gr.Column():
            content_input = gr.Image(label="Content Image", type="numpy", image_mode="RGB", height=400, width=400)
            style_input = gr.Image(label="Style Image", type="numpy", image_mode="RGB", height=400, width=400)
        with gr.Column():
            output_image = gr.Image(label="Output Image")
            intermediate_output = gr.Image(label="Intermediate Results")

    # Overall balance between the content and style losses.
    with gr.Row():
        alpha_slider = gr.Slider(minimum=0, maximum=1, value=1, step=0.1, label="Content Weight (α)")
        beta_slider = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.1, label="Style Weight (β)")

    # Per-layer style weights, in the positional order style_transfer expects
    # (conv1_1 ... conv5_1), with their default values.
    layer_slider_specs = [
        ("Conv1_1 Weight", 1),
        ("Conv2_1 Weight", 0.8),
        ("Conv3_1 Weight", 0.5),
        ("Conv4_1 Weight", 0.3),
        ("Conv5_1 Weight", 0.1),
    ]
    with gr.Row():
        layer_sliders = [
            gr.Slider(minimum=0, maximum=1, value=default, step=0.1, label=label)
            for label, default in layer_slider_specs
        ]

    steps_slider = gr.Slider(minimum=1, maximum=2000, value=1000, step=100, label="Number of Steps")
    run_button = gr.Button("Run Style Transfer")

    transfer_inputs = [content_input, style_input, alpha_slider, beta_slider, *layer_sliders, steps_slider]
    run_button.click(
        style_transfer,
        inputs=transfer_inputs,
        outputs=[output_image, intermediate_output],
    )

    gr.Examples(
        resized_examples,
        inputs=[content_input, style_input, output_image],
        outputs=[content_input, style_input, output_image],
        fn=load_example,
        cache_examples=True,
    )

demo.launch()