# style-transfer / app.py
# Latest commit: "updated image inputs" (6ccac38) by kuko6 - 3.91 kB
import gradio as gr
import torch
import torchvision.transforms.functional as TF
import torchvision.transforms as transforms
from src.model import Model
import os
# Use the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
def denorm_img(img: torch.Tensor):
    """Undo the per-channel normalization and clip the result to [0, 1].

    Computes ``img * std + mean`` with three per-channel constants that
    broadcast over the two trailing dimensions, so ``img`` needs a
    3-element channel dimension in position -3.

    Args:
        img: Normalized image tensor.

    Returns:
        A tensor of the same shape with every value clipped to [0, 1].
    """
    # Build the constants on the input's device so a CUDA tensor does not
    # trigger a CPU/GPU device-mismatch error.
    std = torch.tensor([0.229, 0.224, 0.225], device=img.device).reshape(-1, 1, 1)
    mean = torch.tensor([0.485, 0.456, 0.406], device=img.device).reshape(-1, 1, 1)
    return torch.clip(img * std + mean, min=0, max=1)
def main(inp1, inp2, alph, out_size=256):
    """Run style transfer on the two editor images and return the result.

    Args:
        inp1: Value of the "Style" image editor; its ``"composite"`` entry
            is used as the style image.
        inp2: Value of the "Content" image editor; its ``"composite"`` entry
            is used as the content image.
        alph: Value assigned to ``model.alpha`` before inference.
        out_size: Size handed to ``transforms.Resize``. Coerced to ``int``
            because the dropdown allows custom values, which may arrive as
            strings.

    Returns:
        The de-normalized output as a NumPy array with values in [0, 1],
        permuted with ``(1, 2, 0)`` (presumably channels-last, assuming the
        model returns a batch of channels-first images).

    Raises:
        gr.Error: If either image is missing or ``out_size`` is not a
            positive integer.
    """
    style_img = inp1["composite"] if inp1 else None
    content_img = inp2["composite"] if inp2 else None
    if style_img is None or content_img is None:
        # Surface a readable message in the UI instead of an opaque
        # TypeError from ``to_tensor(None)``.
        raise gr.Error("Please provide both a style image and a content image.")

    try:
        size = int(out_size)
    except (TypeError, ValueError) as err:
        raise gr.Error(f"Invalid output size: {out_size!r}") from err
    if size <= 0:
        raise gr.Error(f"Invalid output size: {out_size!r}")

    model = Model()
    model.load_state_dict(torch.load("./models/model_puddle.pt", map_location=torch.device(device)))
    # The inputs are moved to ``device`` below, so the weights must live
    # there too; otherwise the forward pass fails on CUDA machines.
    model.to(device)
    model.eval()
    model.alpha = alph

    norm = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    resize = transforms.Resize(size, antialias=True)

    # Same order as before: to tensor -> normalize -> resize -> add batch dim.
    style = resize(norm(TF.to_tensor(style_img))).unsqueeze(0).to(device)
    content = resize(norm(TF.to_tensor(content_img))).unsqueeze(0).to(device)

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        out = model(content, style)

    # Move to the CPU before de-normalizing and converting; ``.numpy()``
    # is not available for CUDA tensors.
    result = out[0].detach().cpu()
    return denorm_img(result).permute(1, 2, 0).numpy()
def update_crop_size(crop_size):
    """Build a component update that sets a square crop size.

    Args:
        crop_size: Edge length used for both elements of the crop-size pair.

    Returns:
        The ``gr.update`` payload carrying ``crop_size=(crop_size, crop_size)``.
    """
    square = (crop_size, crop_size)
    return gr.update(crop_size=square)
def _make_crop_editor(label):
    """Create an upload/clipboard image editor that only offers cropping."""
    return gr.ImageEditor(
        label=label,
        type="pil",
        image_mode="RGB",
        sources=["upload", "clipboard"],
        transforms="crop",
        crop_size=(256, 256),
        canvas_size=(512, 512),
        brush=False,
        eraser=False,
        layers=False,
    )


def _bundled(relative_path):
    """Resolve ``relative_path`` against the directory containing this file."""
    return os.path.join(os.path.dirname(__file__), relative_path)


# Page layout: editors and output on top, then the controls, then the
# example galleries, and finally the run button.
with gr.Blocks() as demo:
    gr.Markdown("# Style Transfer with AdaIN")

    with gr.Row(variant="compact", equal_height=False):
        inp1 = _make_crop_editor("Style")
        inp2 = _make_crop_editor("Content")
        out = gr.Image(type="pil", label="Output")

    with gr.Row():
        out_size = gr.Dropdown(
            label="Output size",
            info="Size of the output image",
            choices=[256, 512],
            value=256,
            multiselect=False,
            allow_custom_value=True,
            interactive=True,
        )
        # Keep both editors' crop size in sync with the chosen output size.
        for editor in (inp1, inp2):
            out_size.change(fn=update_crop_size, inputs=out_size, outputs=editor)

        alph = gr.Slider(
            0,
            1,
            value=1,
            label="Alpha",
            info="How much to change the original image",
            interactive=True,
            scale=3,
        )

    with gr.Row():
        with gr.Column():
            gr.Markdown("## Style Examples")
            gr.Examples(
                examples=[
                    _bundled("data/styles/25.jpg"),
                    _bundled("data/styles/2272.jpg"),
                    _bundled("data/styles/2314.jpg"),
                ],
                inputs=inp1,
            )

        with gr.Column():
            gr.Markdown("## Content Examples")
            gr.Examples(
                examples=[
                    # "data/content/bear.jpg" is intentionally left out.
                    _bundled("data/content/cat.jpg"),
                    _bundled("data/content/cow.jpg"),
                    _bundled("data/content/ducks.jpg"),
                ],
                inputs=inp2,
            )

    btn = gr.Button("Run")
    btn.click(fn=main, inputs=[inp1, inp2, alph, out_size], outputs=out)

demo.launch()