# style-transfer / app.py
# kuko6's picture
# updated image inputs
# 6ccac38
import gradio as gr
import torch
import torchvision.transforms.functional as TF
import torchvision.transforms as transforms
from src.model import Model
import os
# Run on the GPU when CUDA is usable; otherwise stay on the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
def denorm_img(img: torch.Tensor):
    """Undo per-channel mean/std normalisation and clamp the result to [0, 1].

    The constants are the same mean/std values that ``main`` passes to
    ``transforms.Normalize``, so this is the inverse of that step.

    Args:
        img: Normalised image tensor whose channel axis broadcasts against a
            ``(3, 1, 1)`` tensor (e.g. a ``(3, H, W)`` image).

    Returns:
        A tensor on the same device as ``img`` with every value clipped to
        the range [0, 1].
    """
    # Build the constants on img's device: mixing a CUDA image with CPU
    # constants would raise a device-mismatch error.
    std = torch.tensor([0.229, 0.224, 0.225], device=img.device).reshape(-1, 1, 1)
    mean = torch.tensor([0.485, 0.456, 0.406], device=img.device).reshape(-1, 1, 1)
    return torch.clip(img * std + mean, min=0, max=1)
# Holds the loaded model so the checkpoint is read from disk only once per
# process instead of on every button click.
_MODEL_CACHE = {}


def _get_model():
    """Return the style-transfer model, loading the checkpoint on first use."""
    if "model" not in _MODEL_CACHE:
        model = Model()
        model.load_state_dict(
            torch.load("./models/model_puddle.pt", map_location=torch.device(device))
        )
        # The inputs are moved to `device` below, so the weights must live
        # there too (assumes Model is a torch.nn.Module -- it already exposes
        # load_state_dict/eval).
        model.to(device)
        model.eval()
        _MODEL_CACHE["model"] = model
    return _MODEL_CACHE["model"]


def _editor_image(editor_value, label):
    """Return the "composite" image of an ImageEditor value or raise a UI error."""
    image = editor_value.get("composite") if editor_value else None
    if image is None:
        # Without this check the failure is an opaque TypeError from to_tensor.
        raise gr.Error(f"Please provide a {label} image.")
    return image


def main(inp1, inp2, alph, out_size=256):
    """Stylise the content image with the style image and return the result.

    Args:
        inp1: Value of the "Style" image editor; its ``"composite"`` entry is
            used as the style image.
        inp2: Value of the "Content" image editor; its ``"composite"`` entry
            is used as the content image.
        alph: Value assigned to ``model.alpha`` before the forward pass
            (the UI describes it as how much to change the original image).
        out_size: Size handed to ``transforms.Resize``; with an int this
            resizes the shorter image edge. Coerced with ``int()`` because
            the dropdown allows custom (string) values.

    Returns:
        The model output for the first batch item, de-normalised and clipped
        to [0, 1], as a channels-last NumPy array.

    Raises:
        gr.Error: If either editor holds no image.
    """
    style = TF.to_tensor(_editor_image(inp1, "style"))
    content = TF.to_tensor(_editor_image(inp2, "content"))

    norm = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    resize = transforms.Resize(int(out_size), antialias=True)

    # Same order as before: normalise first, then resize, then add the batch
    # dimension and move to the compute device.
    style = resize(norm(style)).unsqueeze(0).to(device)
    content = resize(norm(content)).unsqueeze(0).to(device)

    model = _get_model()
    model.alpha = alph

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        out = model(content, style)

    # Bring the result back to the CPU before converting to NumPy; .numpy()
    # fails on CUDA tensors.
    result = out[0].detach().cpu()
    return denorm_img(result).permute(1, 2, 0).numpy()
def update_crop_size(crop_size):
    """Build a Gradio update that sets a square crop of ``crop_size`` per side."""
    square = (crop_size,) * 2
    return gr.update(crop_size=square)
# NOTE(review): the original indentation of this section was lost, so the
# nesting of the `with` blocks below is reconstructed -- confirm the layout
# (which components sit in which Row/Column) against the running app.

# Directory containing this script; example images are resolved against it.
_APP_DIR = os.path.dirname(__file__)


def _example_paths(relative_paths):
    """Turn paths relative to this script into full example-file paths."""
    return [os.path.join(_APP_DIR, rel) for rel in relative_paths]


def _crop_editor(label):
    """Create an upload/clipboard image editor restricted to cropping."""
    return gr.ImageEditor(
        label=label,
        type="pil",
        image_mode="RGB",
        sources=["upload", "clipboard"],
        transforms="crop",
        crop_size=(256, 256),
        canvas_size=(512, 512),
        eraser=False,
        brush=False,
        layers=False,
    )


with gr.Blocks() as demo:
    gr.Markdown("# Style Transfer with AdaIN")

    # Inputs and result.
    with gr.Row(variant="compact", equal_height=False):
        inp1 = _crop_editor("Style")
        inp2 = _crop_editor("Content")
        out = gr.Image(type="pil", label="Output")

    # Controls.
    with gr.Row():
        out_size = gr.Dropdown(
            choices=[256, 512],
            value=256,
            multiselect=False,
            interactive=True,
            allow_custom_value=True,
            label="Output size",
            info="Size of the output image",
        )
        # Update the crop size of both editors whenever the output size
        # changes (one listener per editor, style editor first).
        for editor in (inp1, inp2):
            out_size.change(fn=update_crop_size, inputs=out_size, outputs=editor)

        alph = gr.Slider(
            0,
            1,
            value=1,
            label="Alpha",
            info="How much to change the original image",
            interactive=True,
            scale=3,
        )

    # Clickable example images for each editor.
    with gr.Row():
        with gr.Column():
            gr.Markdown("## Style Examples")
            gr.Examples(
                examples=_example_paths(
                    [
                        "data/styles/25.jpg",
                        "data/styles/2272.jpg",
                        "data/styles/2314.jpg",
                    ]
                ),
                inputs=inp1,
            )
        with gr.Column():
            gr.Markdown("## Content Examples")
            gr.Examples(
                examples=_example_paths(
                    [
                        # "data/content/bear.jpg",  (disabled in the original)
                        "data/content/cat.jpg",
                        "data/content/cow.jpg",
                        "data/content/ducks.jpg",
                    ]
                ),
                inputs=inp2,
            )

    btn = gr.Button("Run")
    btn.click(fn=main, inputs=[inp1, inp2, alph, out_size], outputs=out)

demo.launch()