Update gradio-app.py
Show the generation time.
Change resolution to 1024x1024
gradio-app.py CHANGED (+21 -12)
@@ -1,3 +1,4 @@
+import time
 import gradio as gr
 from PIL import Image
 import torch
@@ -6,6 +7,8 @@ import os
 
 SAFETY_CHECKER = os.environ.get("SAFETY_CHECKER", None)
 TORCH_COMPILE = os.environ.get("TORCH_COMPILE", None)
+# I have no idea where the env is config'ed. Thus:
+TORCH_COMPILE = True
 
 if SAFETY_CHECKER:
     pipe = DiffusionPipeline.from_pretrained(
@@ -37,19 +40,24 @@ if TORCH_COMPILE:
 
 def predict(prompt1, prompt2, merge_ratio, guidance, steps, sharpness, seed=1231231):
     torch.manual_seed(seed)
+    tm0 = time.time()
     results = pipe(
         prompt1=prompt1,
         prompt2=prompt2,
         sv=merge_ratio,
         sharpness=sharpness,
-        width=
-        height=
+        width=1024,
+        height=1024,
         num_inference_steps=steps,
         guidance_scale=guidance,
         lcm_origin_steps=50,
         output_type="pil",
         # return_dict=False,
     )
+    torch.cuda.synchronize()
+    tmval = f"time = {time.time()-tm0}"
+    print(f"time = {time.time()-tm0}")
+
     nsfw_content_detected = (
         results.nsfw_content_detected[0]
         if "nsfw_content_detected" in results
@@ -57,7 +65,7 @@ def predict(prompt1, prompt2, merge_ratio, guidance, steps, sharpness, seed=1231
     )
     if nsfw_content_detected:
         raise gr.Error("NSFW content detected. Please try another prompt.")
-    return results.images[0]
+    return results.images[0], tmval
 
 
 css = """
@@ -103,6 +111,7 @@ with gr.Blocks(css=css) as demo:
         )
         prompt1 = gr.Textbox(label="Prompt 1")
         prompt2 = gr.Textbox(label="Prompt 2")
+        msg = gr.Textbox(label="Message", interactive=False)
         generate_bt = gr.Button("Generate")
 
         inputs = [prompt1, prompt2, merge_ratio, guidance, steps, sharpness, seed]
@@ -122,18 +131,18 @@ with gr.Blocks(css=css) as demo:
             ],
             fn=predict,
             inputs=inputs,
-            outputs=image,
+            outputs=[image, 'XXX'],
         )
-        generate_bt.click(fn=predict, inputs=inputs, outputs=image, show_progress=False)
-        seed.change(fn=predict, inputs=inputs, outputs=image, show_progress=False)
+        generate_bt.click(fn=predict, inputs=inputs, outputs=[image, msg], show_progress=False)
+        seed.change(fn=predict, inputs=inputs, outputs=[image, msg], show_progress=False)
         merge_ratio.change(
-            fn=predict, inputs=inputs, outputs=image, show_progress=False
+            fn=predict, inputs=inputs, outputs=[image, msg], show_progress=False
        )
-        guidance.change(fn=predict, inputs=inputs, outputs=image, show_progress=False)
-        steps.change(fn=predict, inputs=inputs, outputs=image, show_progress=False)
-        sharpness.change(fn=predict, inputs=inputs, outputs=image, show_progress=False)
-        prompt1.change(fn=predict, inputs=inputs, outputs=image, show_progress=False)
-        prompt2.change(fn=predict, inputs=inputs, outputs=image, show_progress=False)
+        guidance.change(fn=predict, inputs=inputs, outputs=[image, msg], show_progress=False)
+        steps.change(fn=predict, inputs=inputs, outputs=[image, msg], show_progress=False)
+        sharpness.change(fn=predict, inputs=inputs, outputs=[image, msg], show_progress=False)
+        prompt1.change(fn=predict, inputs=inputs, outputs=[image, msg], show_progress=False)
+        prompt2.change(fn=predict, inputs=inputs, outputs=[image, msg], show_progress=False)
 
 demo.queue()
 if __name__ == "__main__":
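
The timing/message pattern introduced by this commit can be sketched in isolation as follows. This is a minimal sketch, not the Space's actual code: the diffusion call is replaced by a placeholder time.sleep, and the component names (prompt, out, msg) are illustrative. The real app additionally calls torch.cuda.synchronize() before reading the clock so that asynchronous GPU work is included in the measured time. Note also that outputs=[image, 'XXX'] in the gr.Examples block passes a string where Gradio normally expects a component, so it may need the same [image, msg] pair used in the event listeners.

# Minimal sketch of returning a generation-time message alongside the result.
# Assumptions: placeholder generation step, illustrative component names.
import time

import gradio as gr


def predict(prompt):
    tm0 = time.time()
    time.sleep(0.5)  # stand-in for the pipe(...) generation call
    elapsed = f"time = {time.time() - tm0:.2f}s"
    # Return one value per output component: the result and the timing message.
    return f"result for: {prompt}", elapsed


with gr.Blocks() as demo:
    prompt = gr.Textbox(label="Prompt")
    out = gr.Textbox(label="Output")
    msg = gr.Textbox(label="Message", interactive=False)
    gr.Button("Generate").click(fn=predict, inputs=prompt, outputs=[out, msg])

if __name__ == "__main__":
    demo.launch()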