Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -1,8 +1,7 @@
|
|
|
|
1 |
import random
|
2 |
import numpy as np
|
3 |
from PIL import Image
|
4 |
-
import base64
|
5 |
-
from io import BytesIO
|
6 |
|
7 |
import torch
|
8 |
import torchvision.transforms.functional as F
|
@@ -69,45 +68,40 @@ DEFAULT_STYLE_NAME = "No Style"
|
|
69 |
MAX_SEED = np.iinfo(np.int32).max
|
70 |
|
71 |
|
72 |
-
def pil_image_to_data_url(img, format="PNG"):
|
73 |
-
buffered = BytesIO()
|
74 |
-
img.save(buffered, format=format)
|
75 |
-
img_str = base64.b64encode(buffered.getvalue()).decode()
|
76 |
-
return f"data:image/{format.lower()};base64,{img_str}"
|
77 |
-
|
78 |
-
|
79 |
def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
|
80 |
if randomize_seed:
|
81 |
seed = random.randint(0, MAX_SEED)
|
82 |
return seed
|
83 |
|
84 |
|
|
|
85 |
def run(
|
86 |
-
image,
|
87 |
-
prompt,
|
88 |
-
prompt_template,
|
89 |
-
style_name,
|
90 |
controlnet_conditioning_scale,
|
91 |
device_type="GPU",
|
92 |
-
param_dtype=
|
93 |
):
|
94 |
if device_type == "CPU":
|
95 |
-
device = "cpu"
|
96 |
-
param_dtype =
|
97 |
else:
|
98 |
device = "cuda"
|
99 |
-
|
100 |
-
pipe.to(
|
|
|
|
|
|
|
101 |
|
102 |
print(f"prompt: {prompt}")
|
103 |
print("sketch updated")
|
104 |
if image is None:
|
105 |
ones = Image.new("L", (512, 512), 255)
|
106 |
-
|
107 |
-
return ones, gr.update(link=temp_url), gr.update(link=temp_url)
|
108 |
prompt = prompt_template.replace("{prompt}", prompt)
|
109 |
-
control_image =
|
110 |
-
control_image = Image.fromarray(255 - np.array(control_image))
|
111 |
|
112 |
output_pil = pipe(
|
113 |
prompt=prompt,
|
@@ -121,205 +115,84 @@ def run(
|
|
121 |
controlnet_conditioning_scale=float(controlnet_conditioning_scale),
|
122 |
).images[0]
|
123 |
|
124 |
-
|
125 |
-
output_image_url = pil_image_to_data_url(output_pil)
|
126 |
-
return (
|
127 |
-
output_pil,
|
128 |
-
gr.update(link=input_sketch_url),
|
129 |
-
gr.update(link=output_image_url),
|
130 |
-
)
|
131 |
-
|
132 |
-
|
133 |
-
def update_canvas(use_line, use_eraser):
|
134 |
-
if use_eraser:
|
135 |
-
_color = "#ffffff"
|
136 |
-
brush_size = 20
|
137 |
-
if use_line:
|
138 |
-
_color = "#000000"
|
139 |
-
brush_size = 8
|
140 |
-
return gr.update(brush_radius=brush_size, brush_color=_color, interactive=True)
|
141 |
-
|
142 |
-
|
143 |
-
def upload_sketch(file):
|
144 |
-
_img = Image.open(file.name)
|
145 |
-
_img = _img.convert("L")
|
146 |
-
return gr.update(value=_img, source="upload", interactive=True)
|
147 |
-
|
148 |
|
149 |
-
scripts = """
|
150 |
-
async () => {
|
151 |
-
globalThis.theSketchDownloadFunction = () => {
|
152 |
-
console.log("test")
|
153 |
-
var link = document.createElement("a");
|
154 |
-
dataUrl = document.getElementById('download_sketch').href
|
155 |
-
link.setAttribute("href", dataUrl)
|
156 |
-
link.setAttribute("download", "sketch.png")
|
157 |
-
document.body.appendChild(link); // Required for Firefox
|
158 |
-
link.click();
|
159 |
-
document.body.removeChild(link); // Clean up
|
160 |
-
|
161 |
-
// also call the output download function
|
162 |
-
theOutputDownloadFunction();
|
163 |
-
return false
|
164 |
-
}
|
165 |
|
166 |
-
|
167 |
-
console.log("test output download function")
|
168 |
-
var link = document.createElement("a");
|
169 |
-
dataUrl = document.getElementById('download_output').href
|
170 |
-
link.setAttribute("href", dataUrl);
|
171 |
-
link.setAttribute("download", "output.png");
|
172 |
-
document.body.appendChild(link); // Required for Firefox
|
173 |
-
link.click();
|
174 |
-
document.body.removeChild(link); // Clean up
|
175 |
-
return false
|
176 |
-
}
|
177 |
-
|
178 |
-
globalThis.UNDO_SKETCH_FUNCTION = () => {
|
179 |
-
console.log("undo sketch function")
|
180 |
-
var button_undo = document.querySelector('#input_image > div.image-container.svelte-p3y7hu > div.svelte-s6ybro > button:nth-child(1)');
|
181 |
-
// Create a new 'click' event
|
182 |
-
var event = new MouseEvent('click', {
|
183 |
-
'view': window,
|
184 |
-
'bubbles': true,
|
185 |
-
'cancelable': true
|
186 |
-
});
|
187 |
-
button_undo.dispatchEvent(event);
|
188 |
-
}
|
189 |
-
|
190 |
-
globalThis.DELETE_SKETCH_FUNCTION = () => {
|
191 |
-
console.log("delete sketch function")
|
192 |
-
var button_del = document.querySelector('#input_image > div.image-container.svelte-p3y7hu > div.svelte-s6ybro > button:nth-child(2)');
|
193 |
-
// Create a new 'click' event
|
194 |
-
var event = new MouseEvent('click', {
|
195 |
-
'view': window,
|
196 |
-
'bubbles': true,
|
197 |
-
'cancelable': true
|
198 |
-
});
|
199 |
-
button_del.dispatchEvent(event);
|
200 |
-
}
|
201 |
-
|
202 |
-
globalThis.togglePencil = () => {
|
203 |
-
el_pencil = document.getElementById('my-toggle-pencil');
|
204 |
-
el_pencil.classList.toggle('clicked');
|
205 |
-
// simulate a click on the gradio button
|
206 |
-
btn_gradio = document.querySelector("#cb-line > label > input");
|
207 |
-
var event = new MouseEvent('click', {
|
208 |
-
'view': window,
|
209 |
-
'bubbles': true,
|
210 |
-
'cancelable': true
|
211 |
-
});
|
212 |
-
btn_gradio.dispatchEvent(event);
|
213 |
-
if (el_pencil.classList.contains('clicked')) {
|
214 |
-
document.getElementById('my-toggle-eraser').classList.remove('clicked');
|
215 |
-
document.getElementById('my-div-pencil').style.backgroundColor = "gray";
|
216 |
-
document.getElementById('my-div-eraser').style.backgroundColor = "white";
|
217 |
-
}
|
218 |
-
else {
|
219 |
-
document.getElementById('my-toggle-eraser').classList.add('clicked');
|
220 |
-
document.getElementById('my-div-pencil').style.backgroundColor = "white";
|
221 |
-
document.getElementById('my-div-eraser').style.backgroundColor = "gray";
|
222 |
-
}
|
223 |
-
|
224 |
-
}
|
225 |
-
|
226 |
-
globalThis.toggleEraser = () => {
|
227 |
-
element = document.getElementById('my-toggle-eraser');
|
228 |
-
element.classList.toggle('clicked');
|
229 |
-
// simulate a click on the gradio button
|
230 |
-
btn_gradio = document.querySelector("#cb-eraser > label > input");
|
231 |
-
var event = new MouseEvent('click', {
|
232 |
-
'view': window,
|
233 |
-
'bubbles': true,
|
234 |
-
'cancelable': true
|
235 |
-
});
|
236 |
-
btn_gradio.dispatchEvent(event);
|
237 |
-
if (element.classList.contains('clicked')) {
|
238 |
-
document.getElementById('my-toggle-pencil').classList.remove('clicked');
|
239 |
-
document.getElementById('my-div-pencil').style.backgroundColor = "white";
|
240 |
-
document.getElementById('my-div-eraser').style.backgroundColor = "gray";
|
241 |
-
}
|
242 |
-
else {
|
243 |
-
document.getElementById('my-toggle-pencil').classList.add('clicked');
|
244 |
-
document.getElementById('my-div-pencil').style.backgroundColor = "gray";
|
245 |
-
document.getElementById('my-div-eraser').style.backgroundColor = "white";
|
246 |
-
}
|
247 |
-
}
|
248 |
-
}
|
249 |
-
"""
|
250 |
-
|
251 |
-
with gr.Blocks(css="style.css") as demo:
|
252 |
gr.Markdown("# SDXS-512-DreamShaper-Sketch")
|
253 |
gr.Markdown("[SDXS: Real-Time One-Step Latent Diffusion Models with Image Conditions](https://arxiv.org/abs/2403.16627) | [GitHub](https://github.com/IDKiro/sdxs)")
|
254 |
-
# these are hidden buttons that are used to trigger the canvas changes
|
255 |
-
line = gr.Checkbox(label="line", value=False, elem_id="cb-line")
|
256 |
-
eraser = gr.Checkbox(label="eraser", value=False, elem_id="cb-eraser")
|
257 |
with gr.Row(elem_id="main_row"):
|
258 |
with gr.Column(elem_id="column_input"):
|
259 |
gr.Markdown("## INPUT", elem_id="input_header")
|
260 |
-
image = gr.
|
261 |
-
|
262 |
-
|
263 |
-
|
264 |
-
|
265 |
-
|
266 |
-
|
267 |
-
<div class="button-row">
|
268 |
-
<div id="my-div-pencil" class="pad2"> <button id="my-toggle-pencil" onclick="return togglePencil(this)"></button> </div>
|
269 |
-
<div id="my-div-eraser" class="pad2"> <button id="my-toggle-eraser" onclick="return toggleEraser(this)"></button> </div>
|
270 |
-
<div class="pad2"> <button id="my-button-undo" onclick="return UNDO_SKETCH_FUNCTION(this)"></button> </div>
|
271 |
-
<div class="pad2"> <button id="my-button-clear" onclick="return DELETE_SKETCH_FUNCTION(this)"></button> </div>
|
272 |
-
<div class="pad2"> <button href="TODO" download="image" id="my-button-down" onclick='return theSketchDownloadFunction()'></button> </div>
|
273 |
-
</div>
|
274 |
-
""")
|
275 |
# gr.Markdown("## Prompt", elem_id="tools_header")
|
276 |
prompt = gr.Textbox(label="Prompt", value="", show_label=True)
|
277 |
with gr.Row():
|
278 |
-
style = gr.Dropdown(
|
279 |
-
|
280 |
-
|
281 |
-
|
282 |
-
|
283 |
-
|
284 |
-
|
285 |
-
|
286 |
-
|
287 |
-
|
288 |
-
|
289 |
-
|
290 |
-
|
291 |
-
|
292 |
-
|
293 |
-
|
294 |
-
|
295 |
-
|
296 |
-
|
297 |
-
|
298 |
-
|
299 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
300 |
|
301 |
with gr.Column(elem_id="column_output"):
|
302 |
gr.Markdown("## OUTPUT", elem_id="output_header")
|
303 |
-
result = gr.Image(
|
304 |
-
|
305 |
-
|
306 |
-
|
307 |
-
|
308 |
-
|
309 |
-
|
310 |
-
|
311 |
-
|
312 |
-
|
313 |
-
|
314 |
-
|
315 |
-
|
316 |
-
|
317 |
-
|
318 |
-
|
|
|
|
|
|
|
|
|
|
|
319 |
style.change(lambda x: styles[x], inputs=[style], outputs=[prompt_temp]).then(
|
320 |
-
|
321 |
-
run_button.click(fn=run, inputs=inputs, outputs=outputs)
|
322 |
image.change(run, inputs=inputs, outputs=outputs,)
|
|
|
323 |
|
324 |
if __name__ == "__main__":
|
325 |
-
demo.queue().launch(
|
|
|
1 |
+
import spaces
|
2 |
import random
|
3 |
import numpy as np
|
4 |
from PIL import Image
|
|
|
|
|
5 |
|
6 |
import torch
|
7 |
import torchvision.transforms.functional as F
|
|
|
68 |
MAX_SEED = np.iinfo(np.int32).max
|
69 |
|
70 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
71 |
def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
|
72 |
if randomize_seed:
|
73 |
seed = random.randint(0, MAX_SEED)
|
74 |
return seed
|
75 |
|
76 |
|
77 |
+
@spaces.GPU
|
78 |
def run(
|
79 |
+
image,
|
80 |
+
prompt,
|
81 |
+
prompt_template,
|
82 |
+
style_name,
|
83 |
controlnet_conditioning_scale,
|
84 |
device_type="GPU",
|
85 |
+
param_dtype="torch.float16",
|
86 |
):
|
87 |
if device_type == "CPU":
|
88 |
+
device = "cpu"
|
89 |
+
param_dtype = "torch.float32"
|
90 |
else:
|
91 |
device = "cuda"
|
92 |
+
|
93 |
+
pipe.to(
|
94 |
+
torch_device=device,
|
95 |
+
torch_dtype=torch.float16 if param_dtype == "torch.float16" else torch.float32,
|
96 |
+
)
|
97 |
|
98 |
print(f"prompt: {prompt}")
|
99 |
print("sketch updated")
|
100 |
if image is None:
|
101 |
ones = Image.new("L", (512, 512), 255)
|
102 |
+
return ones
|
|
|
103 |
prompt = prompt_template.replace("{prompt}", prompt)
|
104 |
+
control_image = Image.fromarray(255 - np.array(image["composite"])[:, :, -1])
|
|
|
105 |
|
106 |
output_pil = pipe(
|
107 |
prompt=prompt,
|
|
|
115 |
controlnet_conditioning_scale=float(controlnet_conditioning_scale),
|
116 |
).images[0]
|
117 |
|
118 |
+
return output_pil
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
119 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
120 |
|
121 |
+
with gr.Blocks() as demo:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
122 |
gr.Markdown("# SDXS-512-DreamShaper-Sketch")
|
123 |
gr.Markdown("[SDXS: Real-Time One-Step Latent Diffusion Models with Image Conditions](https://arxiv.org/abs/2403.16627) | [GitHub](https://github.com/IDKiro/sdxs)")
|
|
|
|
|
|
|
124 |
with gr.Row(elem_id="main_row"):
|
125 |
with gr.Column(elem_id="column_input"):
|
126 |
gr.Markdown("## INPUT", elem_id="input_header")
|
127 |
+
image = gr.Sketchpad(
|
128 |
+
type="pil",
|
129 |
+
image_mode="RGBA",
|
130 |
+
brush=gr.Brush(colors=["#000000"], color_mode="fixed", default_size=8),
|
131 |
+
crop_size=(512, 512),
|
132 |
+
)
|
133 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
134 |
# gr.Markdown("## Prompt", elem_id="tools_header")
|
135 |
prompt = gr.Textbox(label="Prompt", value="", show_label=True)
|
136 |
with gr.Row():
|
137 |
+
style = gr.Dropdown(
|
138 |
+
label="Style", choices=STYLE_NAMES, value=DEFAULT_STYLE_NAME, scale=1
|
139 |
+
)
|
140 |
+
prompt_temp = gr.Textbox(
|
141 |
+
label="Prompt Style Template",
|
142 |
+
value=styles[DEFAULT_STYLE_NAME],
|
143 |
+
scale=2,
|
144 |
+
max_lines=1,
|
145 |
+
)
|
146 |
+
|
147 |
+
controlnet_conditioning_scale = gr.Slider(
|
148 |
+
label="Control Strength", minimum=0, maximum=1, step=0.01, value=0.8
|
149 |
+
)
|
150 |
+
|
151 |
+
device_choices = ["GPU", "CPU"]
|
152 |
+
device_type = gr.Radio(
|
153 |
+
device_choices,
|
154 |
+
label="Device",
|
155 |
+
value=device_choices[0],
|
156 |
+
interactive=True,
|
157 |
+
info="Many thanks to the community for the GPU!",
|
158 |
+
)
|
159 |
+
|
160 |
+
dtype_choices = ["torch.float16", "torch.float32"]
|
161 |
+
param_dtype = gr.Radio(
|
162 |
+
dtype_choices,
|
163 |
+
label="torch.weight_type",
|
164 |
+
value=dtype_choices[0],
|
165 |
+
interactive=True,
|
166 |
+
info="To save GPU memory, use torch.float16. For better quality, use torch.float32.",
|
167 |
+
)
|
168 |
|
169 |
with gr.Column(elem_id="column_output"):
|
170 |
gr.Markdown("## OUTPUT", elem_id="output_header")
|
171 |
+
result = gr.Image(
|
172 |
+
label="Result",
|
173 |
+
height=512,
|
174 |
+
width=512,
|
175 |
+
elem_id="output_image",
|
176 |
+
show_label=False,
|
177 |
+
show_download_button=True,
|
178 |
+
)
|
179 |
+
|
180 |
+
inputs = [
|
181 |
+
image,
|
182 |
+
prompt,
|
183 |
+
prompt_temp,
|
184 |
+
style,
|
185 |
+
controlnet_conditioning_scale,
|
186 |
+
device_type,
|
187 |
+
param_dtype,
|
188 |
+
]
|
189 |
+
outputs = [result]
|
190 |
+
|
191 |
+
prompt.change(fn=run, inputs=inputs, outputs=outputs)
|
192 |
style.change(lambda x: styles[x], inputs=[style], outputs=[prompt_temp]).then(
|
193 |
+
fn=run, inputs=inputs, outputs=outputs,)
|
|
|
194 |
image.change(run, inputs=inputs, outputs=outputs,)
|
195 |
+
controlnet_conditioning_scale.change(run, inputs=inputs, outputs=outputs,)
|
196 |
|
197 |
if __name__ == "__main__":
|
198 |
+
demo.queue().launch()
|