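"""Gradio demo app for style transfer with AdaIN.

Pick a style image and a content image, choose the output size and the
alpha (style strength), then press Run to stylize the content image.
"""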
import gradio as gr
import torch
import torchvision.transforms.functional as TF
import torchvision.transforms as transforms
from src.model import Model
import os

device = "cuda" if torch.cuda.is_available() else "cpu"


def denorm_img(img: torch.Tensor) -> torch.Tensor:
    """Undo ImageNet normalization and clamp the result to the valid [0, 1] range."""
    std = torch.tensor([0.229, 0.224, 0.225], device=img.device).reshape(-1, 1, 1)
    mean = torch.tensor([0.485, 0.456, 0.406], device=img.device).reshape(-1, 1, 1)
    return torch.clip(img * std + mean, min=0, max=1)
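

# The demo's title names AdaIN; as a reference for readers, the core operation
# looks roughly like the sketch below. This is an assumption about what
# src.model.Model does internally (its actual layers may differ); the app
# itself only uses the model as a black box and never calls this function.
def adain_sketch(content_feat: torch.Tensor, style_feat: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    """Re-normalize content features to match the style features' channel-wise statistics."""
    c_mean = content_feat.mean(dim=(2, 3), keepdim=True)
    c_std = content_feat.std(dim=(2, 3), keepdim=True) + eps
    s_mean = style_feat.mean(dim=(2, 3), keepdim=True)
    s_std = style_feat.std(dim=(2, 3), keepdim=True) + eps
    return s_std * (content_feat - c_mean) / c_std + s_mean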


def main(inp1, inp2, alph, out_size=256):
    # Load the trained checkpoint and move the model to the active device.
    model = Model()
    model.load_state_dict(torch.load("./models/model_puddle.pt", map_location=torch.device(device)))
    model.to(device)
    model.eval()

    # Alpha controls the strength of the stylization inside the model.
    model.alpha = alph

    # The image editors return dicts; "composite" holds the cropped image.
    style = TF.to_tensor(inp1["composite"])
    content = TF.to_tensor(inp2["composite"])

    # Normalize with ImageNet statistics and resize to the requested output
    # size. Custom dropdown values may arrive as strings, hence the int() cast.
    norm = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    resize = transforms.Resize(int(out_size), antialias=True)

    style, content = norm(style), norm(content)
    style, content = resize(style), resize(content)

    style = style.unsqueeze(0).to(device)
    content = content.unsqueeze(0).to(device)
    with torch.no_grad():
        out = model(content, style)

    # Denormalize and move to CPU before converting to an HWC numpy image.
    return denorm_img(out[0].cpu()).permute(1, 2, 0).numpy()

def update_crop_size(crop_size):
    # Build a component update that resizes an ImageEditor's crop box.
    return gr.update(crop_size=(crop_size, crop_size))

with gr.Blocks() as demo:
    gr.Markdown("# Style Transfer with AdaIN")
    with gr.Row(variant="compact", equal_height=False):
        inp1 = gr.ImageEditor(
            type="pil",
            sources=["upload", "clipboard"],
            crop_size=(256, 256),
            eraser=False,
            brush=False,
            layers=False,
            label="Style",
            image_mode="RGB",
            transforms="crop",
            canvas_size=(512, 512)
        )
        inp2 = gr.ImageEditor(
            type="pil",
            sources=["upload", "clipboard"],
            crop_size=(256, 256),
            eraser=False,
            brush=False,
            layers=False,
            label="Content",
            image_mode="RGB",
            transforms="crop",
            canvas_size=(512, 512)
        )
        out = gr.Image(type="pil", label="Output")
    
    with gr.Row():
        out_size = gr.Dropdown(
            choices=[256, 512],
            value=256,
            multiselect=False,
            interactive=True,
            allow_custom_value=True,
            label="Output size",
            info="Size of the output image"
        )
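        # allow_custom_value=True lets users type a size other than 256 or 512;
        # keep both editors' crop boxes in sync with whatever size is chosen.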
        out_size.change(fn=update_crop_size, inputs=out_size, outputs=inp1)
        out_size.change(fn=update_crop_size, inputs=out_size, outputs=inp2)

        alph = gr.Slider(0, 1, value=1, label="Alpha", info="Style strength: 0 keeps the content image unchanged, 1 applies the full style", interactive=True, scale=3)

    with gr.Row():
        with gr.Column():
            gr.Markdown("## Style Examples")
            gr.Examples(
                examples=[
                    os.path.join(os.path.dirname(__file__), "data/styles/25.jpg"),
                    os.path.join(os.path.dirname(__file__), "data/styles/2272.jpg"),
                    os.path.join(os.path.dirname(__file__), "data/styles/2314.jpg"),
                ],
                inputs=inp1,
            )
        with gr.Column():
            gr.Markdown("## Content Examples")
            gr.Examples(
                examples=[
                    os.path.join(os.path.dirname(__file__), "data/content/cat.jpg"),
                    os.path.join(os.path.dirname(__file__), "data/content/cow.jpg"),
                    os.path.join(os.path.dirname(__file__), "data/content/ducks.jpg"),
                ],
                inputs=inp2,
            )
    
    btn = gr.Button("Run")
    btn.click(fn=main, inputs=[inp1, inp2, alph, out_size], outputs=out)

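# Serves on http://127.0.0.1:7860 by default; demo.launch(share=True) would
# create a temporary public link instead.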
demo.launch()