aifartist commited on
Commit
8541d70
·
verified ·
1 Parent(s): 53fe545

Update gradio-app.py

Browse files

Show the generation time.
Change resolution to 1024x1024

Files changed (1) hide show
  1. gradio-app.py +21 -12
gradio-app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import gradio as gr
2
  from PIL import Image
3
  import torch
@@ -6,6 +7,8 @@ import os
6
 
7
  SAFETY_CHECKER = os.environ.get("SAFETY_CHECKER", None)
8
  TORCH_COMPILE = os.environ.get("TORCH_COMPILE", None)
 
 
9
 
10
  if SAFETY_CHECKER:
11
  pipe = DiffusionPipeline.from_pretrained(
@@ -37,19 +40,24 @@ if TORCH_COMPILE:
37
 
38
  def predict(prompt1, prompt2, merge_ratio, guidance, steps, sharpness, seed=1231231):
39
  torch.manual_seed(seed)
 
40
  results = pipe(
41
  prompt1=prompt1,
42
  prompt2=prompt2,
43
  sv=merge_ratio,
44
  sharpness=sharpness,
45
- width=512,
46
- height=512,
47
  num_inference_steps=steps,
48
  guidance_scale=guidance,
49
  lcm_origin_steps=50,
50
  output_type="pil",
51
  # return_dict=False,
52
  )
 
 
 
 
53
  nsfw_content_detected = (
54
  results.nsfw_content_detected[0]
55
  if "nsfw_content_detected" in results
@@ -57,7 +65,7 @@ def predict(prompt1, prompt2, merge_ratio, guidance, steps, sharpness, seed=1231
57
  )
58
  if nsfw_content_detected:
59
  raise gr.Error("NSFW content detected. Please try another prompt.")
60
- return results.images[0]
61
 
62
 
63
  css = """
@@ -103,6 +111,7 @@ with gr.Blocks(css=css) as demo:
103
  )
104
  prompt1 = gr.Textbox(label="Prompt 1")
105
  prompt2 = gr.Textbox(label="Prompt 2")
 
106
  generate_bt = gr.Button("Generate")
107
 
108
  inputs = [prompt1, prompt2, merge_ratio, guidance, steps, sharpness, seed]
@@ -122,18 +131,18 @@ with gr.Blocks(css=css) as demo:
122
  ],
123
  fn=predict,
124
  inputs=inputs,
125
- outputs=image,
126
  )
127
- generate_bt.click(fn=predict, inputs=inputs, outputs=image, show_progress=False)
128
- seed.change(fn=predict, inputs=inputs, outputs=image, show_progress=False)
129
  merge_ratio.change(
130
- fn=predict, inputs=inputs, outputs=image, show_progress=False
131
  )
132
- guidance.change(fn=predict, inputs=inputs, outputs=image, show_progress=False)
133
- steps.change(fn=predict, inputs=inputs, outputs=image, show_progress=False)
134
- sharpness.change(fn=predict, inputs=inputs, outputs=image, show_progress=False)
135
- prompt1.change(fn=predict, inputs=inputs, outputs=image, show_progress=False)
136
- prompt2.change(fn=predict, inputs=inputs, outputs=image, show_progress=False)
137
 
138
  demo.queue()
139
  if __name__ == "__main__":
 
1
+ import time
2
  import gradio as gr
3
  from PIL import Image
4
  import torch
 
7
 
8
  SAFETY_CHECKER = os.environ.get("SAFETY_CHECKER", None)
9
  TORCH_COMPILE = os.environ.get("TORCH_COMPILE", None)
10
+ # I have no idea where the env is config'ed. Thus:
11
+ TORCH_COMPILE = True
12
 
13
  if SAFETY_CHECKER:
14
  pipe = DiffusionPipeline.from_pretrained(
 
40
 
41
  def predict(prompt1, prompt2, merge_ratio, guidance, steps, sharpness, seed=1231231):
42
  torch.manual_seed(seed)
43
+ tm0 = time.time()
44
  results = pipe(
45
  prompt1=prompt1,
46
  prompt2=prompt2,
47
  sv=merge_ratio,
48
  sharpness=sharpness,
49
+ width=1024,
50
+ height=1024,
51
  num_inference_steps=steps,
52
  guidance_scale=guidance,
53
  lcm_origin_steps=50,
54
  output_type="pil",
55
  # return_dict=False,
56
  )
57
+ torch.cuda.synchronize()
58
+ tmval = f"time = {time.time()-tm0}"
59
+ print(f"time = {time.time()-tm0}")
60
+
61
  nsfw_content_detected = (
62
  results.nsfw_content_detected[0]
63
  if "nsfw_content_detected" in results
 
65
  )
66
  if nsfw_content_detected:
67
  raise gr.Error("NSFW content detected. Please try another prompt.")
68
+ return results.images[0], tmval
69
 
70
 
71
  css = """
 
111
  )
112
  prompt1 = gr.Textbox(label="Prompt 1")
113
  prompt2 = gr.Textbox(label="Prompt 2")
114
+ msg = gr.Textbox(label="Message", interactive=False)
115
  generate_bt = gr.Button("Generate")
116
 
117
  inputs = [prompt1, prompt2, merge_ratio, guidance, steps, sharpness, seed]
 
131
  ],
132
  fn=predict,
133
  inputs=inputs,
134
+ outputs=[image, 'XXX'],
135
  )
136
+ generate_bt.click(fn=predict, inputs=inputs, outputs=[image, msg], show_progress=False)
137
+ seed.change(fn=predict, inputs=inputs, outputs=[image, msg], show_progress=False)
138
  merge_ratio.change(
139
+ fn=predict, inputs=inputs, outputs=[image, msg], show_progress=False
140
  )
141
+ guidance.change(fn=predict, inputs=inputs, outputs=[image, msg], show_progress=False)
142
+ steps.change(fn=predict, inputs=inputs, outputs=[image, msg], show_progress=False)
143
+ sharpness.change(fn=predict, inputs=inputs, outputs=[image, msg], show_progress=False)
144
+ prompt1.change(fn=predict, inputs=inputs, outputs=[image, msg], show_progress=False)
145
+ prompt2.change(fn=predict, inputs=inputs, outputs=[image, msg], show_progress=False)
146
 
147
  demo.queue()
148
  if __name__ == "__main__":