noamrot committed on
Commit b4741f5
1 Parent(s): bd0a0fe

Upload 2 files


Uploaded scripts

Files changed (2)
  1. app.py +237 -0
  2. requirements.txt +5 -0
app.py ADDED
@@ -0,0 +1,237 @@
import math
import random
import gradio as gr
import torch
from PIL import Image, ImageOps
from diffusers import StableDiffusionInstructPix2PixPipeline
import spaces  # Hugging Face Spaces helper; provides the @spaces.GPU decorator for ZeroGPU

help_text = """
Considerations while editing:
1. The Base-Model, trained on the PIPE dataset, is great for some tasks, while the Finetuned-MB-Model, fine-tuned on the MagicBrush dataset, can be better for others. Please try both until you are satisfied.
2. Image CFG controls how closely the output sticks to the original image. Higher values keep the image more consistent with the original.
3. Text CFG controls how strongly the edit follows the instruction. Higher values lead to more changes in the image.
4. Using different seed values will produce varied outputs.
5. Increasing the number of steps can enhance the results.
6. The Stable Diffusion autoencoder struggles with small faces in images.
"""
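# These knobs map directly onto the InstructPix2Pix call in generate() below:
# Text CFG -> guidance_scale, Image CFG -> image_guidance_scale, Steps -> num_inference_steps.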

article = """
<p style='text-align: center'>
<a href='https://arxiv.org/abs/2404.18212' target='_blank'>
Paint by Inpaint: Learning to Add Image Objects by Removing Them First</a>
</p>
"""

description = """
<p style="text-align: center;">
Gradio demo for <strong>Paint by Inpaint: Learning to Add Image Objects by Removing Them First</strong>; visit our <a href='https://rotsteinnoam.github.io/Paint-by-Inpaint/' target='_blank'>project page</a>. <br>
The demo covers both models trained for image object addition on the <a href='https://huggingface.co/datasets/paint-by-inpaint/PIPE' target='_blank'>PIPE dataset</a> and models trained on other datasets for general editing. <br>
</p>
"""

# Base models
object_addition_base_model_id = "paint-by-inpaint/add-base"
general_editing_base_model_id = "paint-by-inpaint/general-base"

# MagicBrush finetuned models
object_addition_finetuned_model_id = "paint-by-inpaint/add-finetuned-mb"
general_editing_finetuned_model_id = "paint-by-inpaint/general-finetuned-mb"

# The fp16 weights below assume a CUDA device; the CPU fallback exists but will be very slow
device = "cuda" if torch.cuda.is_available() else "cpu"

def load_model(model_id):
    return StableDiffusionInstructPix2PixPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to(device)

# All four pipelines are loaded eagerly at startup and stay resident in memory
pipe_object_addition_base = load_model(object_addition_base_model_id)
pipe_object_addition_finetuned = load_model(object_addition_finetuned_model_id)

pipe_general_editing_base = load_model(general_editing_base_model_id)
pipe_general_editing_finetuned = load_model(general_editing_finetuned_model_id)
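# If memory becomes tight, diffusers allows reusing components across pipelines.
# A minimal sketch (untested here; it assumes the checkpoints share the same VAE and
# text encoder, which is typical for InstructPix2Pix-style SD-1.5 fine-tunes):
#   pipe_shared = StableDiffusionInstructPix2PixPipeline.from_pretrained(
#       object_addition_finetuned_model_id,
#       vae=pipe_object_addition_base.vae,
#       text_encoder=pipe_object_addition_base.text_encoder,
#       torch_dtype=torch.float16,
#   ).to(device)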

# Request a ZeroGPU slot for up to 15 seconds per call
@spaces.GPU(duration=15)
def generate(
    input_image: Image.Image,
    instruction: str,
    model_choice: int,
    steps: int,
    randomize_seed: bool,
    seed: int,
    text_cfg_scale: float,
    image_cfg_scale: float,
    task_type: str,
):
    seed = random.randint(0, 100000) if randomize_seed else seed
    # model_choice comes from a gr.Radio with type="index": 0 = Base-Model, 1 = Finetuned-MB-Model
    if task_type == "object_addition":
        pipe = pipe_object_addition_base if model_choice == 0 else pipe_object_addition_finetuned
    else:
        pipe = pipe_general_editing_base if model_choice == 0 else pipe_general_editing_finetuned

    # Resize so the image is close to 512 px on the long side while both
    # dimensions land on multiples of 64, as the SD latent space requires
    width, height = input_image.size
    factor = 512 / max(width, height)
    factor = math.ceil(min(width, height) * factor / 64) * 64 / min(width, height)
    width = int((width * factor) // 64) * 64
    height = int((height * factor) // 64) * 64
    input_image = ImageOps.fit(input_image, (width, height), method=Image.Resampling.LANCZOS)
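    # Worked example: a 1200x800 input gives factor = 512/1200 ~= 0.427; the short
    # side is then rounded up to a multiple of 64 (800 * 0.427 ~= 341 -> 384), so
    # factor becomes 384/800 = 0.48 and the image is fit to 576x384.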

    # With no instruction, return the (resized) input unchanged; the return shape
    # and order must match the outputs list wired to the Generate buttons below
    if instruction == "":
        return [seed, text_cfg_scale, image_cfg_scale, input_image]

    generator = torch.manual_seed(seed)
    edited_image = pipe(
        instruction, image=input_image,
        guidance_scale=text_cfg_scale, image_guidance_scale=image_cfg_scale,
        num_inference_steps=steps, generator=generator,
    ).images[0]
    return [seed, text_cfg_scale, image_cfg_scale, edited_image]
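
# Default values restored by the Reset buttons; the order matches the outputs lists
# wired below: [steps, randomize_seed, seed, text_cfg_scale, image_cfg_scale, edited_image]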
def reset():
    return [50, "Randomize Seed", 2024, 7.5, 1.5, None]

# elem_id yields an HTML id attribute, so the CSS targets "#compact-box", not a class
with gr.Blocks(css="#compact-box .gr-row { margin-bottom: 5px; } #compact-box .gr-number input, #compact-box .gr-radio label { padding: 5px 10px; }") as demo:
    gr.HTML("""
    <div style="text-align: center;">
    <h1 style="font-weight: 900; margin-bottom: 7px;">Paint by Inpaint</h1>
    {description}
    </div>
    """.format(description=description))

    with gr.Tabs():
        with gr.Tab("Object Addition"):
            with gr.Row():
                with gr.Column():
                    input_image = gr.Image(label="Input Image", type="pil", interactive=True)
                    instruction = gr.Textbox(lines=1, label="Addition Instruction", interactive=True, max_lines=1, placeholder="Enter addition instruction here")

                    # type="index" makes the callback receive 0 (Base-Model) or 1 (Finetuned-MB-Model)
                    model_choice = gr.Radio(
                        ["Base-Model", "Finetuned-MB-Model"],
                        value="Base-Model",
                        type="index",
                        label="Choose Model",
                        interactive=True,
                    )

                    with gr.Group(elem_id="compact-box"):
                        with gr.Row():
                            steps = gr.Number(value=50, precision=0, label="Steps", interactive=True)

                        with gr.Column():
                            with gr.Row():
                                seed = gr.Number(value=2024, precision=0, label="Seed", interactive=True)
                                randomize_seed = gr.Radio(
                                    ["Fix Seed", "Randomize Seed"],
                                    value="Randomize Seed",
                                    type="index",
                                    show_label=False,
                                    interactive=True,
                                )

                        with gr.Row():
                            text_cfg_scale = gr.Number(value=7.5, label="Text CFG", interactive=True)
                            image_cfg_scale = gr.Number(value=1.5, label="Image CFG", interactive=True)

                    with gr.Row():
                        generate_button = gr.Button("Generate")
                        reset_button = gr.Button("Reset")

                with gr.Column():
                    edited_image = gr.Image(label="Edited Image", type="pil", interactive=False)

            generate_button.click(
                fn=lambda *args: generate(*args, task_type="object_addition"),
                inputs=[
                    input_image,
                    instruction,
                    model_choice,
                    steps,
                    randomize_seed,
                    seed,
                    text_cfg_scale,
                    image_cfg_scale,
                ],
                outputs=[seed, text_cfg_scale, image_cfg_scale, edited_image],
            )
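            # The lambda pins task_type; Gradio passes only the components listed in
            # inputs, in order, so generate() receives them as positional arguments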
            reset_button.click(
                fn=reset,
                inputs=[],
                outputs=[steps, randomize_seed, seed, text_cfg_scale, image_cfg_scale, edited_image],
            )

        # This tab mirrors the Object Addition tab, with its own set of components
        with gr.Tab("General Editing"):
            with gr.Row():
                with gr.Column():
                    input_image_editing = gr.Image(label="Input Image", type="pil", interactive=True)
                    instruction_editing = gr.Textbox(lines=1, label="Editing Instruction", interactive=True, max_lines=1, placeholder="Enter editing instruction here")

                    model_choice_editing = gr.Radio(
                        ["Base-Model", "Finetuned-MB-Model"],
                        value="Base-Model",
                        type="index",
                        label="Choose Model",
                        interactive=True,
                    )

                    with gr.Group(elem_id="compact-box"):
                        with gr.Row():
                            steps_editing = gr.Number(value=50, precision=0, label="Steps", interactive=True)

                        with gr.Column():
                            with gr.Row():
                                seed_editing = gr.Number(value=2024, precision=0, label="Seed", interactive=True)
                                randomize_seed_editing = gr.Radio(
                                    ["Fix Seed", "Randomize Seed"],
                                    value="Randomize Seed",
                                    type="index",
                                    show_label=False,
                                    interactive=True,
                                )

                        with gr.Row():
                            text_cfg_scale_editing = gr.Number(value=7.5, label="Text CFG", interactive=True)
                            image_cfg_scale_editing = gr.Number(value=1.5, label="Image CFG", interactive=True)

                    with gr.Row():
                        generate_button_editing = gr.Button("Generate")
                        reset_button_editing = gr.Button("Reset")

                with gr.Column():
                    edited_image_editing = gr.Image(label="Edited Image", type="pil", interactive=False)

            generate_button_editing.click(
                fn=lambda *args: generate(*args, task_type="general_editing"),
                inputs=[
                    input_image_editing,
                    instruction_editing,
                    model_choice_editing,
                    steps_editing,
                    randomize_seed_editing,
                    seed_editing,
                    text_cfg_scale_editing,
                    image_cfg_scale_editing,
                ],
                outputs=[seed_editing, text_cfg_scale_editing, image_cfg_scale_editing, edited_image_editing],
            )
            reset_button_editing.click(
                fn=reset,
                inputs=[],
                outputs=[steps_editing, randomize_seed_editing, seed_editing, text_cfg_scale_editing, image_cfg_scale_editing, edited_image_editing],
            )

    gr.Markdown(help_text)

    # Example inputs for the Object Addition tab; the image files are expected
    # under an examples/ directory in the Space repository
    examples = [
        ["examples/messi.jpeg", "Add a royal silver crown"],
        ["examples/coffee.jpg", "Add steamed milk"],
    ]

    gr.Examples(
        examples=examples,
        inputs=[input_image, instruction],
        outputs=[edited_image],
    )
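    # Without fn= (and example caching), clicking an example only fills the
    # inputs; the outputs list above is not used to precompute results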

    gr.HTML(article)

demo.queue()
demo.launch(share=False, max_threads=1)
requirements.txt ADDED
@@ -0,0 +1,5 @@
gradio==4.36.0
torch==2.2.0
Pillow==10.2.0
diffusers
spaces==0.28.3
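# diffusers is left unpinned above; pinning it to a known-good release would make the build reproducible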