SIGMitch committed
Commit 95b1a30 · verified · 1 Parent(s): 1b18660

Update app.py

Files changed (1):
  1. app.py +229 -39

app.py CHANGED
@@ -1,51 +1,241 @@
  import gradio as gr
  import spaces
-
  import torch
- from diffusers import FluxPipeline
- from diffusers import FluxImg2ImgPipeline
- from diffusers.utils import load_image


- from huggingface_hub.utils import RepositoryNotFoundError

- pipeline = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.float16).to("cuda")
- pipelineImg = FluxImg2ImgPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.float16).to("cuda")



- @spaces.GPU(duration=70)
- def generate(image, prompt, negative_prompt, width, height, sample_steps, lora_id):
-     try:
-         # pipeline.load_lora_weights(lora_id)
-         init_image = load_image(image).resize((1024, 1024))
-         pipelineImg.load_lora_weights(lora_id)
-     except RepositoryNotFoundError:
-         raise ValueError(f"Recieved invalid FLUX LoRA.")
-
-     return pipeline(prompt=f"{prompt}\nDO NOT INCLUDE {negative_prompt}", image=init_image, width=width, height=height, num_inference_steps=sample_steps, generator=torch.Generator("cpu").manual_seed(42), guidance_scale=7).images[0]

- with gr.Blocks() as interface:
      with gr.Column():
          with gr.Row():
-             with gr.Column():
-                 image = gr.Image(label="Input image", show_label=False, type="filepath")
-                 prompt = gr.Textbox(label="Prompt", info="What do you want?", value="Keanu Reeves holding a neon sign reading 'Hello, world!', 32k HDR, paparazzi", lines=4, interactive=True)
-                 negative_prompt = gr.Textbox(label="Negative Prompt", info="What do you want to exclude from the image?", value="ugly, low quality", lines=4, interactive=True)
-             with gr.Column():
-                 generate_button = gr.Button("Generate")
-                 output = gr.Image()
-         with gr.Row():
-             with gr.Accordion(label="Advanced Settings", open=False):
-                 with gr.Row():
-                     with gr.Column():
-                         width = gr.Slider(label="Width", info="The width in pixels of the generated image.", value=512, minimum=128, maximum=4096, step=64, interactive=True)
-                         height = gr.Slider(label="Height", info="The height in pixels of the generated image.", value=512, minimum=128, maximum=4096, step=64, interactive=True)
-                     with gr.Column():
-                         sampling_steps = gr.Slider(label="Sampling Steps", info="The number of denoising steps.", value=20, minimum=4, maximum=50, step=1, interactive=True)
-                         lora_id = gr.Textbox(label="Adapter Repository", info="ID of the FLUX LoRA", value="pepper13/fluxfw")
-
-     generate_button.click(fn=generate, inputs=[image, prompt, negative_prompt, width, height, sampling_steps, lora_id], outputs=[output])
-
- if __name__ == "__main__":
-     interface.launch()
+
+ from typing import Tuple
+
+ import requests
+ import random
+ import numpy as np
  import gradio as gr
  import spaces
  import torch
+ from PIL import Image
+ from diffusers import FluxInpaintPipeline
+
+ MARKDOWN = """
+ # FLUX.1 Inpainting 🔥
+ Shoutout to the [Black Forest Labs](https://huggingface.co/black-forest-labs) team for
+ creating this amazing model, and a big thanks to [Gothos](https://github.com/Gothos)
+ for taking it to the next level by enabling inpainting with FLUX.
+ """
+
+ MAX_SEED = np.iinfo(np.int32).max
+ IMAGE_SIZE = 1024
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+
+
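+ # Helper for the examples below: treat near-black pixels (mean RGB below
+ # `threshold`) as background and make them fully transparent.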
+ def remove_background(image: Image.Image, threshold: int = 50) -> Image.Image:
+     image = image.convert("RGBA")
+     data = image.getdata()
+     new_data = []
+     for item in data:
+         avg = sum(item[:3]) / 3
+         if avg < threshold:
+             new_data.append((0, 0, 0, 0))
+         else:
+             new_data.append(item)
+
+     image.putdata(new_data)
+     return image
+
+
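+ # Prebuilt ImageEditor payloads (background, mask layer, composite), fetched
+ # from remote PNGs at import time and fed to gr.Examples below.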
+ EXAMPLES = [
+     [
+         {
+             "background": Image.open(requests.get("https://media.roboflow.com/spaces/doge-2-image.png", stream=True).raw),
+             "layers": [remove_background(Image.open(requests.get("https://media.roboflow.com/spaces/doge-2-mask-2.png", stream=True).raw))],
+             "composite": Image.open(requests.get("https://media.roboflow.com/spaces/doge-2-composite-2.png", stream=True).raw),
+         },
+         "little lion",
+         42,
+         False,
+         0.85,
+         30
+     ],
+     [
+         {
+             "background": Image.open(requests.get("https://media.roboflow.com/spaces/doge-2-image.png", stream=True).raw),
+             "layers": [remove_background(Image.open(requests.get("https://media.roboflow.com/spaces/doge-2-mask-3.png", stream=True).raw))],
+             "composite": Image.open(requests.get("https://media.roboflow.com/spaces/doge-2-composite-3.png", stream=True).raw),
+         },
+         "tribal tattoos",
+         42,
+         False,
+         0.85,
+         30
+     ]
+ ]
+
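+ # Load the FLUX.1-schnell inpainting pipeline once at startup, in bfloat16,
+ # on the GPU when one is available (see DEVICE above).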
+ pipe = FluxInpaintPipeline.from_pretrained(
+     "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16).to(DEVICE)
+
+
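+ # Scale the input so its longer side is at most IMAGE_SIZE, then round each
+ # side down to a multiple of 32, a granularity the FLUX pipeline accepts.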
+ def resize_image_dimensions(
+     original_resolution_wh: Tuple[int, int],
+     maximum_dimension: int = IMAGE_SIZE
+ ) -> Tuple[int, int]:
+     width, height = original_resolution_wh
+
+     # if width <= maximum_dimension and height <= maximum_dimension:
+     #     width = width - (width % 32)
+     #     height = height - (height % 32)
+     #     return width, height
+
+     if width > height:
+         scaling_factor = maximum_dimension / width
+     else:
+         scaling_factor = maximum_dimension / height
+
+     new_width = int(width * scaling_factor)
+     new_height = int(height * scaling_factor)
+
+     new_width = new_width - (new_width % 32)
+     new_height = new_height - (new_height % 32)
+
+     return new_width, new_height
+
+
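+ # Main inference entry point. On ZeroGPU Spaces, @spaces.GPU(duration=100)
+ # requests a GPU allocation of up to ~100 s per call.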
+ @spaces.GPU(duration=100)
+ def process(
+     input_image_editor: dict,
+     input_text: str,
+     seed_slicer: int,
+     randomize_seed_checkbox: bool,
+     strength_slider: float,
+     num_inference_steps_slider: int,
+     progress=gr.Progress(track_tqdm=True)
+ ):
+     if not input_text:
+         gr.Info("Please enter a text prompt.")
+         return None, None
+
+     image = input_image_editor['background']
+     mask = input_image_editor['layers'][0]
+
+     if not image:
+         gr.Info("Please upload an image.")
+         return None, None
+
+     if not mask:
+         gr.Info("Please draw a mask on the image.")
+         return None, None
+
+     width, height = resize_image_dimensions(original_resolution_wh=image.size)
+     resized_image = image.resize((width, height), Image.LANCZOS)
+     resized_mask = mask.resize((width, height), Image.LANCZOS)
+
+     if randomize_seed_checkbox:
+         seed_slicer = random.randint(0, MAX_SEED)
+     generator = torch.Generator().manual_seed(seed_slicer)
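+     # The "SIGMitch/KIT" LoRA is loaded on every request; hoisting this next
+     # to the pipeline setup above would avoid re-loading the adapter per call.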
+     pipe.load_lora_weights("SIGMitch/KIT")
+     result = pipe(
+         prompt=input_text,
+         image=resized_image,
+         mask_image=resized_mask,
+         width=width,
+         height=height,
+         strength=strength_slider,
+         generator=generator,
+         num_inference_steps=num_inference_steps_slider
+     ).images[0]
+     print('INFERENCE DONE')
+     return result, resized_mask
+
+
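+ # UI: an ImageEditor (upload/webcam) with a fixed white brush for painting the
+ # mask, a prompt row, advanced settings, and a debug view of the resized mask.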
+ with gr.Blocks() as demo:
+     gr.Markdown(MARKDOWN)
+     with gr.Row():
          with gr.Column():
+             input_image_editor_component = gr.ImageEditor(
+                 label='Image',
+                 type='pil',
+                 sources=["upload", "webcam"],
+                 image_mode='RGB',
+                 layers=False,
+                 brush=gr.Brush(colors=["#FFFFFF"], color_mode="fixed"))
+
              with gr.Row():
+                 input_text_component = gr.Text(
+                     label="Prompt",
+                     show_label=False,
+                     max_lines=1,
+                     placeholder="Enter your prompt",
+                     container=False,
+                 )
+                 submit_button_component = gr.Button(
+                     value='Submit', variant='primary', scale=0)
+
+             with gr.Accordion("Advanced Settings", open=False):
+                 seed_slicer_component = gr.Slider(
+                     label="Seed",
+                     minimum=0,
+                     maximum=MAX_SEED,
+                     step=1,
+                     value=42,
+                 )
+
+                 randomize_seed_checkbox_component = gr.Checkbox(
+                     label="Randomize seed", value=True)
+
+                 with gr.Row():
+                     strength_slider_component = gr.Slider(
+                         label="Strength",
+                         info="Indicates extent to transform the reference `image`. "
+                              "Must be between 0 and 1. `image` is used as a starting "
+                              "point and more noise is added the higher the `strength`.",
+                         minimum=0,
+                         maximum=1,
+                         step=0.01,
+                         value=0.85,
+                     )
+
+                     num_inference_steps_slider_component = gr.Slider(
+                         label="Number of inference steps",
+                         info="The number of denoising steps. More denoising steps "
+                              "usually lead to a higher quality image at the expense of slower inference.",
+                         minimum=1,
+                         maximum=50,
+                         step=1,
+                         value=20,
+                     )
+         with gr.Column():
+             output_image_component = gr.Image(
+                 type='pil', image_mode='RGB', label='Generated image', format="png")
+             with gr.Accordion("Debug", open=False):
+                 output_mask_component = gr.Image(
+                     type='pil', image_mode='RGB', label='Input mask', format="png")
+     with gr.Row():
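+         # Click-to-run examples; with cache_examples=True their outputs are
+         # generated and cached when the Space builds.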
+         gr.Examples(
+             fn=process,
+             examples=EXAMPLES,
+             inputs=[
+                 input_image_editor_component,
+                 input_text_component,
+                 seed_slicer_component,
+                 randomize_seed_checkbox_component,
+                 strength_slider_component,
+                 num_inference_steps_slider_component
+             ],
+             outputs=[
+                 output_image_component,
+                 output_mask_component
+             ],
+             run_on_click=True,
+             cache_examples=True
+         )
+
+     submit_button_component.click(
+         fn=process,
+         inputs=[
+             input_image_editor_component,
+             input_text_component,
+             seed_slicer_component,
+             randomize_seed_checkbox_component,
+             strength_slider_component,
+             num_inference_steps_slider_component
+         ],
+         outputs=[
+             output_image_component,
+             output_mask_component
+         ]
+     )
+
+ demo.launch(debug=False, show_error=True)