File size: 3,428 Bytes
5d82ba5
c6ad3ed
5d82ba5
 
 
c6ad3ed
5d82ba5
 
 
 
 
 
b0e960e
 
5d82ba5
 
 
 
 
 
 
 
 
 
 
b0e960e
5d82ba5
 
 
b0e960e
 
 
 
 
 
5d82ba5
b0e960e
 
 
 
 
 
5d82ba5
 
 
 
 
 
 
 
 
 
cb9a173
 
5d82ba5
 
 
b0e960e
 
 
 
5d82ba5
cb9a173
b0e960e
 
 
 
5d82ba5
 
 
 
 
 
 
 
 
 
b0e960e
 
5d82ba5
 
b0e960e
5d82ba5
b0e960e
5d82ba5
 
 
b0e960e
 
 
 
 
 
 
 
 
 
 
 
 
 
5d82ba5
b0e960e
5d82ba5
b0e960e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import os
import gradio as gr
import torch
import yaml
import numpy as np

from munch import munchify
import torchvision.transforms as transforms
from torchvision.transforms import functional as F
from diffusers import (
    AutoPipelineForInpainting,
)
from generate_dataset import outpainting_generator_rectangle, merge_images_horizontally
from ddim_with_prob import DDIMSchedulerCustom

# Preprocessing shared by the image and the mask: turn a tensor/ndarray into a
# PIL image, then resize it to 512x512 using Lanczos resampling.
_to_pil_step = transforms.ToPILImage()
_resize_step = transforms.Resize(
    (512, 512),
    interpolation=F.InterpolationMode.LANCZOS,
)
transform = transforms.Compose([_to_pil_step, _resize_step])

# Loaded pipelines keyed by device string, so the checkpoints are read from
# disk once per process instead of once per request.
_PIPELINE_CACHE = {}


def _load_pipelines(device):
    """Return the ``(pipe_ours, pipe_runway)`` inpainting pipelines on *device*.

    The first call for a given device loads both checkpoints and caches them;
    later calls reuse the cached objects.
    """
    cache_key = str(device)
    if cache_key not in _PIPELINE_CACHE:
        print('Loading pipeline')
        # float16 is only used on CUDA: diffusers warns that pipelines loaded
        # in float16 cannot run on a CPU device, so fall back to float32 there.
        dtype = torch.float16 if device.type == "cuda" else torch.float32

        pipe_ours = AutoPipelineForInpainting.from_pretrained(
            './model_ckpt', torch_dtype=dtype, variant='fp16')
        pipe_ours.scheduler = DDIMSchedulerCustom.from_config(pipe_ours.scheduler.config)

        pipe_runway = AutoPipelineForInpainting.from_pretrained(
            "runwayml/stable-diffusion-inpainting", torch_dtype=dtype, variant='fp16')

        _PIPELINE_CACHE[cache_key] = (pipe_ours.to(device), pipe_runway.to(device))
    return _PIPELINE_CACHE[cache_key]


def pref_inpainting(image,
                    box_width_ratio,
                    mask_random_start,
                    steps,
                    ):
    """Run both inpainting pipelines on *image* and return their results.

    Args:
        image: Input image handed to ``outpainting_generator_rectangle``
            (the UI supplies a PIL image).
        box_width_ratio: Box width as a percentage; divided by 100 before
            being passed to ``outpainting_generator_rectangle``.
        mask_random_start: Forwarded unchanged to
            ``outpainting_generator_rectangle``.
        steps: Number of denoising steps, passed to both pipelines as
            ``num_inference_steps``.

    Returns:
        A ``(res_ours, res_runway)`` tuple: the outputs of
        ``merge_images_horizontally`` applied to the resized input and the
        image generated by the ``./model_ckpt`` pipeline and by the
        ``runwayml/stable-diffusion-inpainting`` pipeline respectively.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")

    with open("./configs/paintreward_train_configs.yaml", encoding="utf-8") as file:
        config = munchify(yaml.safe_load(file))

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    pipe_ours, pipe_runway = _load_pipelines(device)

    color, mask = outpainting_generator_rectangle(image, box_width_ratio/100, mask_random_start)
    mask = mask.convert('L')

    # `transform` expects channel-first tensors: move the colour channels to
    # the front and give the single-channel mask a leading channel axis.
    # The tensors stay on the CPU because `transform` converts them straight
    # back to PIL images, which is what the pipelines are given.
    color = torch.from_numpy(np.array(color).transpose(2, 0, 1))
    mask = torch.from_numpy(np.array(mask)[None, ...])
    color, mask = transform(color), transform(mask)

    # Gradio sliders may deliver floats; the schedulers need an integer count.
    num_steps = int(steps)

    print('Running inference ours')
    res_ours = pipe_ours(prompt='', image=color, mask_image=mask,
                         eta=config.eta, num_inference_steps=num_steps).images[0]
    print('Running inference runway')
    res_runway = pipe_runway(prompt="", image=color, mask_image=mask,
                             num_inference_steps=num_steps).images[0]

    res_ours = merge_images_horizontally(color, res_ours)
    res_runway = merge_images_horizontally(color, res_runway)

    return res_ours, res_runway


# Input widgets, in the positional order of pref_inpainting's parameters.
inputs = [
    gr.Image(type="pil", image_mode="RGBA", label='Input Image'),
    gr.Slider(30, 45, value=35, step=1, label="box_width_ratio"),
    gr.Slider(0, 256, value=125, step=1, label="mask_random_start"),
    gr.Slider(30, 100, value=50, step=5, label="steps"),
]

# One output image per pipeline, matching pref_inpainting's return tuple.
outputs = [
    gr.Image(type="pil", image_mode="RGBA", label='PrefPaint', container=True, width="100%"),
    gr.Image(type="pil", image_mode="RGBA", label='RunwayPaint', container=True, width="100%"),
]

# os.listdir returns entries in arbitrary order and includes hidden files and
# sub-directories; sort for a stable gallery and keep only regular,
# non-hidden files so every example row points at a loadable file.
files = sorted(
    file_name
    for file_name in os.listdir("./assets")
    if not file_name.startswith(".") and os.path.isfile(f"./assets/{file_name}")
)
# Each example row supplies the image path plus the slider defaults.
examples = [
    [f"./assets/{file_name}", 35, 125, 50] for file_name in files
]


with gr.Blocks() as demo:
    # The Interface is created inside the Blocks context so it is rendered as
    # part of `demo`, which is the object that gets launched below.
    iface = gr.Interface(
        fn=pref_inpainting,
        inputs=inputs,
        outputs=outputs,
        title="Inpainting with Human Preference (Utilizing Free CPU Resources)",
        description="Upload an image and start your inpainting (currently only supporting outpainting masks; other mask types coming soon).",
        theme="default",
        examples=examples,
    )

demo.launch()