IDKiro commited on
Commit
7eb34be
·
verified ·
1 Parent(s): 455a9aa

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +80 -207
app.py CHANGED
@@ -1,8 +1,7 @@
 
1
  import random
2
  import numpy as np
3
  from PIL import Image
4
- import base64
5
- from io import BytesIO
6
 
7
  import torch
8
  import torchvision.transforms.functional as F
@@ -69,45 +68,40 @@ DEFAULT_STYLE_NAME = "No Style"
69
  MAX_SEED = np.iinfo(np.int32).max
70
 
71
 
72
- def pil_image_to_data_url(img, format="PNG"):
73
- buffered = BytesIO()
74
- img.save(buffered, format=format)
75
- img_str = base64.b64encode(buffered.getvalue()).decode()
76
- return f"data:image/{format.lower()};base64,{img_str}"
77
-
78
-
79
  def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
80
  if randomize_seed:
81
  seed = random.randint(0, MAX_SEED)
82
  return seed
83
 
84
 
 
85
  def run(
86
- image,
87
- prompt,
88
- prompt_template,
89
- style_name,
90
  controlnet_conditioning_scale,
91
  device_type="GPU",
92
- param_dtype='torch.float16',
93
  ):
94
  if device_type == "CPU":
95
- device = "cpu"
96
- param_dtype = 'torch.float32'
97
  else:
98
  device = "cuda"
99
-
100
- pipe.to(torch_device=device, torch_dtype=torch.float16 if param_dtype == 'torch.float16' else torch.float32)
 
 
 
101
 
102
  print(f"prompt: {prompt}")
103
  print("sketch updated")
104
  if image is None:
105
  ones = Image.new("L", (512, 512), 255)
106
- temp_url = pil_image_to_data_url(ones)
107
- return ones, gr.update(link=temp_url), gr.update(link=temp_url)
108
  prompt = prompt_template.replace("{prompt}", prompt)
109
- control_image = image.convert("RGB")
110
- control_image = Image.fromarray(255 - np.array(control_image))
111
 
112
  output_pil = pipe(
113
  prompt=prompt,
@@ -121,205 +115,84 @@ def run(
121
  controlnet_conditioning_scale=float(controlnet_conditioning_scale),
122
  ).images[0]
123
 
124
- input_sketch_url = pil_image_to_data_url(control_image)
125
- output_image_url = pil_image_to_data_url(output_pil)
126
- return (
127
- output_pil,
128
- gr.update(link=input_sketch_url),
129
- gr.update(link=output_image_url),
130
- )
131
-
132
-
133
- def update_canvas(use_line, use_eraser):
134
- if use_eraser:
135
- _color = "#ffffff"
136
- brush_size = 20
137
- if use_line:
138
- _color = "#000000"
139
- brush_size = 8
140
- return gr.update(brush_radius=brush_size, brush_color=_color, interactive=True)
141
-
142
-
143
- def upload_sketch(file):
144
- _img = Image.open(file.name)
145
- _img = _img.convert("L")
146
- return gr.update(value=_img, source="upload", interactive=True)
147
-
148
 
149
- scripts = """
150
- async () => {
151
- globalThis.theSketchDownloadFunction = () => {
152
- console.log("test")
153
- var link = document.createElement("a");
154
- dataUrl = document.getElementById('download_sketch').href
155
- link.setAttribute("href", dataUrl)
156
- link.setAttribute("download", "sketch.png")
157
- document.body.appendChild(link); // Required for Firefox
158
- link.click();
159
- document.body.removeChild(link); // Clean up
160
-
161
- // also call the output download function
162
- theOutputDownloadFunction();
163
- return false
164
- }
165
 
166
- globalThis.theOutputDownloadFunction = () => {
167
- console.log("test output download function")
168
- var link = document.createElement("a");
169
- dataUrl = document.getElementById('download_output').href
170
- link.setAttribute("href", dataUrl);
171
- link.setAttribute("download", "output.png");
172
- document.body.appendChild(link); // Required for Firefox
173
- link.click();
174
- document.body.removeChild(link); // Clean up
175
- return false
176
- }
177
-
178
- globalThis.UNDO_SKETCH_FUNCTION = () => {
179
- console.log("undo sketch function")
180
- var button_undo = document.querySelector('#input_image > div.image-container.svelte-p3y7hu > div.svelte-s6ybro > button:nth-child(1)');
181
- // Create a new 'click' event
182
- var event = new MouseEvent('click', {
183
- 'view': window,
184
- 'bubbles': true,
185
- 'cancelable': true
186
- });
187
- button_undo.dispatchEvent(event);
188
- }
189
-
190
- globalThis.DELETE_SKETCH_FUNCTION = () => {
191
- console.log("delete sketch function")
192
- var button_del = document.querySelector('#input_image > div.image-container.svelte-p3y7hu > div.svelte-s6ybro > button:nth-child(2)');
193
- // Create a new 'click' event
194
- var event = new MouseEvent('click', {
195
- 'view': window,
196
- 'bubbles': true,
197
- 'cancelable': true
198
- });
199
- button_del.dispatchEvent(event);
200
- }
201
-
202
- globalThis.togglePencil = () => {
203
- el_pencil = document.getElementById('my-toggle-pencil');
204
- el_pencil.classList.toggle('clicked');
205
- // simulate a click on the gradio button
206
- btn_gradio = document.querySelector("#cb-line > label > input");
207
- var event = new MouseEvent('click', {
208
- 'view': window,
209
- 'bubbles': true,
210
- 'cancelable': true
211
- });
212
- btn_gradio.dispatchEvent(event);
213
- if (el_pencil.classList.contains('clicked')) {
214
- document.getElementById('my-toggle-eraser').classList.remove('clicked');
215
- document.getElementById('my-div-pencil').style.backgroundColor = "gray";
216
- document.getElementById('my-div-eraser').style.backgroundColor = "white";
217
- }
218
- else {
219
- document.getElementById('my-toggle-eraser').classList.add('clicked');
220
- document.getElementById('my-div-pencil').style.backgroundColor = "white";
221
- document.getElementById('my-div-eraser').style.backgroundColor = "gray";
222
- }
223
-
224
- }
225
-
226
- globalThis.toggleEraser = () => {
227
- element = document.getElementById('my-toggle-eraser');
228
- element.classList.toggle('clicked');
229
- // simulate a click on the gradio button
230
- btn_gradio = document.querySelector("#cb-eraser > label > input");
231
- var event = new MouseEvent('click', {
232
- 'view': window,
233
- 'bubbles': true,
234
- 'cancelable': true
235
- });
236
- btn_gradio.dispatchEvent(event);
237
- if (element.classList.contains('clicked')) {
238
- document.getElementById('my-toggle-pencil').classList.remove('clicked');
239
- document.getElementById('my-div-pencil').style.backgroundColor = "white";
240
- document.getElementById('my-div-eraser').style.backgroundColor = "gray";
241
- }
242
- else {
243
- document.getElementById('my-toggle-pencil').classList.add('clicked');
244
- document.getElementById('my-div-pencil').style.backgroundColor = "gray";
245
- document.getElementById('my-div-eraser').style.backgroundColor = "white";
246
- }
247
- }
248
- }
249
- """
250
-
251
- with gr.Blocks(css="style.css") as demo:
252
  gr.Markdown("# SDXS-512-DreamShaper-Sketch")
253
  gr.Markdown("[SDXS: Real-Time One-Step Latent Diffusion Models with Image Conditions](https://arxiv.org/abs/2403.16627) | [GitHub](https://github.com/IDKiro/sdxs)")
254
- # these are hidden buttons that are used to trigger the canvas changes
255
- line = gr.Checkbox(label="line", value=False, elem_id="cb-line")
256
- eraser = gr.Checkbox(label="eraser", value=False, elem_id="cb-eraser")
257
  with gr.Row(elem_id="main_row"):
258
  with gr.Column(elem_id="column_input"):
259
  gr.Markdown("## INPUT", elem_id="input_header")
260
- image = gr.Image(
261
- source="canvas", tool="color-sketch", type="pil", image_mode="L",
262
- invert_colors=True, shape=(512, 512), brush_radius=8, height=440, width=440,
263
- brush_color="#000000", interactive=True, show_download_button=True, elem_id="input_image", show_label=False)
264
- download_sketch = gr.Button("Download sketch", scale=1, elem_id="download_sketch")
265
-
266
- gr.HTML("""
267
- <div class="button-row">
268
- <div id="my-div-pencil" class="pad2"> <button id="my-toggle-pencil" onclick="return togglePencil(this)"></button> </div>
269
- <div id="my-div-eraser" class="pad2"> <button id="my-toggle-eraser" onclick="return toggleEraser(this)"></button> </div>
270
- <div class="pad2"> <button id="my-button-undo" onclick="return UNDO_SKETCH_FUNCTION(this)"></button> </div>
271
- <div class="pad2"> <button id="my-button-clear" onclick="return DELETE_SKETCH_FUNCTION(this)"></button> </div>
272
- <div class="pad2"> <button href="TODO" download="image" id="my-button-down" onclick='return theSketchDownloadFunction()'></button> </div>
273
- </div>
274
- """)
275
  # gr.Markdown("## Prompt", elem_id="tools_header")
276
  prompt = gr.Textbox(label="Prompt", value="", show_label=True)
277
  with gr.Row():
278
- style = gr.Dropdown(label="Style", choices=STYLE_NAMES, value=DEFAULT_STYLE_NAME, scale=1)
279
- prompt_temp = gr.Textbox(label="Prompt Style Template", value=styles[DEFAULT_STYLE_NAME], scale=2, max_lines=1)
280
-
281
- controlnet_conditioning_scale = gr.Slider(label="Control Strength", minimum=0, maximum=1, step=0.01, value=0.8)
282
-
283
-
284
- device_choices = ['GPU','CPU']
285
- device_type = gr.Radio(device_choices, label='Device',
286
- value=device_choices[0],
287
- interactive=True,
288
- info='Many thanks to the community for the GPU!')
289
-
290
- dtype_choices = ['torch.float16','torch.float32']
291
- param_dtype = gr.Radio(dtype_choices,label='torch.weight_type',
292
- value=dtype_choices[0],
293
- interactive=True,
294
- info='To save GPU memory, use torch.float16. For better quality, use torch.float32.')
295
-
296
-
297
- with gr.Column(elem_id="column_process", min_width=50, scale=0.4):
298
- gr.Markdown("## SDXS-Sketch", elem_id="description")
299
- run_button = gr.Button("Run", min_width=50)
 
 
 
 
 
 
 
 
 
300
 
301
  with gr.Column(elem_id="column_output"):
302
  gr.Markdown("## OUTPUT", elem_id="output_header")
303
- result = gr.Image(label="Result", height=440, width=440, elem_id="output_image", show_label=False, show_download_button=True)
304
- download_output = gr.Button("Download output", elem_id="download_output")
305
- gr.Markdown("### Instructions")
306
- gr.Markdown("**1**. Enter a text prompt (e.g. cat)")
307
- gr.Markdown("**2**. Start sketching")
308
- gr.Markdown("**3**. Change the image style using a style template")
309
- gr.Markdown("**4**. Adjust the effect of sketch guidance using the slider")
310
-
311
-
312
- eraser.change(fn=lambda x: gr.update(value=not x), inputs=[eraser], outputs=[line]).then(update_canvas, [line, eraser], [image])
313
- line.change(fn=lambda x: gr.update(value=not x), inputs=[line], outputs=[eraser]).then(update_canvas, [line, eraser], [image])
314
-
315
- demo.load(None,None,None,_js=scripts)
316
- inputs = [image, prompt, prompt_temp, style, controlnet_conditioning_scale, device_type, param_dtype]
317
- outputs = [result, download_sketch, download_output]
318
- prompt.submit(fn=run, inputs=inputs, outputs=outputs)
 
 
 
 
 
319
  style.change(lambda x: styles[x], inputs=[style], outputs=[prompt_temp]).then(
320
- fn=run, inputs=inputs, outputs=outputs,)
321
- run_button.click(fn=run, inputs=inputs, outputs=outputs)
322
  image.change(run, inputs=inputs, outputs=outputs,)
 
323
 
324
  if __name__ == "__main__":
325
- demo.queue().launch(debug=True)
 
1
+ import spaces
2
  import random
3
  import numpy as np
4
  from PIL import Image
 
 
5
 
6
  import torch
7
  import torchvision.transforms.functional as F
 
68
  MAX_SEED = np.iinfo(np.int32).max
69
 
70
 
 
 
 
 
 
 
 
71
  def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
72
  if randomize_seed:
73
  seed = random.randint(0, MAX_SEED)
74
  return seed
75
 
76
 
77
+ @spaces.GPU
78
  def run(
79
+ image,
80
+ prompt,
81
+ prompt_template,
82
+ style_name,
83
  controlnet_conditioning_scale,
84
  device_type="GPU",
85
+ param_dtype="torch.float16",
86
  ):
87
  if device_type == "CPU":
88
+ device = "cpu"
89
+ param_dtype = "torch.float32"
90
  else:
91
  device = "cuda"
92
+
93
+ pipe.to(
94
+ torch_device=device,
95
+ torch_dtype=torch.float16 if param_dtype == "torch.float16" else torch.float32,
96
+ )
97
 
98
  print(f"prompt: {prompt}")
99
  print("sketch updated")
100
  if image is None:
101
  ones = Image.new("L", (512, 512), 255)
102
+ return ones
 
103
  prompt = prompt_template.replace("{prompt}", prompt)
104
+ control_image = Image.fromarray(255 - np.array(image["composite"])[:, :, -1])
 
105
 
106
  output_pil = pipe(
107
  prompt=prompt,
 
115
  controlnet_conditioning_scale=float(controlnet_conditioning_scale),
116
  ).images[0]
117
 
118
+ return output_pil
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
 
121
+ with gr.Blocks() as demo:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
122
  gr.Markdown("# SDXS-512-DreamShaper-Sketch")
123
  gr.Markdown("[SDXS: Real-Time One-Step Latent Diffusion Models with Image Conditions](https://arxiv.org/abs/2403.16627) | [GitHub](https://github.com/IDKiro/sdxs)")
 
 
 
124
  with gr.Row(elem_id="main_row"):
125
  with gr.Column(elem_id="column_input"):
126
  gr.Markdown("## INPUT", elem_id="input_header")
127
+ image = gr.Sketchpad(
128
+ type="pil",
129
+ image_mode="RGBA",
130
+ brush=gr.Brush(colors=["#000000"], color_mode="fixed", default_size=8),
131
+ crop_size=(512, 512),
132
+ )
133
+
 
 
 
 
 
 
 
 
134
  # gr.Markdown("## Prompt", elem_id="tools_header")
135
  prompt = gr.Textbox(label="Prompt", value="", show_label=True)
136
  with gr.Row():
137
+ style = gr.Dropdown(
138
+ label="Style", choices=STYLE_NAMES, value=DEFAULT_STYLE_NAME, scale=1
139
+ )
140
+ prompt_temp = gr.Textbox(
141
+ label="Prompt Style Template",
142
+ value=styles[DEFAULT_STYLE_NAME],
143
+ scale=2,
144
+ max_lines=1,
145
+ )
146
+
147
+ controlnet_conditioning_scale = gr.Slider(
148
+ label="Control Strength", minimum=0, maximum=1, step=0.01, value=0.8
149
+ )
150
+
151
+ device_choices = ["GPU", "CPU"]
152
+ device_type = gr.Radio(
153
+ device_choices,
154
+ label="Device",
155
+ value=device_choices[0],
156
+ interactive=True,
157
+ info="Many thanks to the community for the GPU!",
158
+ )
159
+
160
+ dtype_choices = ["torch.float16", "torch.float32"]
161
+ param_dtype = gr.Radio(
162
+ dtype_choices,
163
+ label="torch.weight_type",
164
+ value=dtype_choices[0],
165
+ interactive=True,
166
+ info="To save GPU memory, use torch.float16. For better quality, use torch.float32.",
167
+ )
168
 
169
  with gr.Column(elem_id="column_output"):
170
  gr.Markdown("## OUTPUT", elem_id="output_header")
171
+ result = gr.Image(
172
+ label="Result",
173
+ height=512,
174
+ width=512,
175
+ elem_id="output_image",
176
+ show_label=False,
177
+ show_download_button=True,
178
+ )
179
+
180
+ inputs = [
181
+ image,
182
+ prompt,
183
+ prompt_temp,
184
+ style,
185
+ controlnet_conditioning_scale,
186
+ device_type,
187
+ param_dtype,
188
+ ]
189
+ outputs = [result]
190
+
191
+ prompt.change(fn=run, inputs=inputs, outputs=outputs)
192
  style.change(lambda x: styles[x], inputs=[style], outputs=[prompt_temp]).then(
193
+ fn=run, inputs=inputs, outputs=outputs,)
 
194
  image.change(run, inputs=inputs, outputs=outputs,)
195
+ controlnet_conditioning_scale.change(run, inputs=inputs, outputs=outputs,)
196
 
197
  if __name__ == "__main__":
198
+ demo.queue().launch()