radames commited on
Commit
84519e8
·
1 Parent(s): 931cf15

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +16 -6
  2. lcm_txt2img/pipeline.py +1 -1
app.py CHANGED
@@ -37,7 +37,7 @@ if TORCH_COMPILE:
37
 
38
  def predict(prompt1, prompt2, merge_ratio, guidance, steps, sharpness, seed=1231231):
39
  torch.manual_seed(seed)
40
- img = pipe(
41
  prompt1=prompt1,
42
  prompt2=prompt2,
43
  sv=merge_ratio,
@@ -48,12 +48,19 @@ def predict(prompt1, prompt2, merge_ratio, guidance, steps, sharpness, seed=1231
48
  guidance_scale=guidance,
49
  lcm_origin_steps=50,
50
  output_type="pil",
51
- return_dict=False,
52
  )
53
- return img
 
 
 
 
 
 
 
54
 
55
 
56
- css="""
57
  #container{
58
  margin: 0 auto;
59
  max-width: 80rem;
@@ -74,7 +81,8 @@ with gr.Blocks(css=css) as demo:
74
  RTSD leverages the expertise provided by Latent Consistency Models (LCM). For more information about LCM,
75
  visit their website at [Latent Consistency Models](https://latent-consistency-models.github.io/).
76
 
77
- """, elem_id="intro"
 
78
  )
79
  with gr.Row():
80
  with gr.Column():
@@ -90,7 +98,9 @@ with gr.Blocks(css=css) as demo:
90
  sharpness = gr.Slider(
91
  value=1.0, minimum=0, maximum=1, step=0.001, label="Sharpness"
92
  )
93
- seed = gr.Slider(randomize=True, minimum=0, maximum=12013012031030, label="Seed")
 
 
94
  prompt1 = gr.Textbox(label="Prompt 1")
95
  prompt2 = gr.Textbox(label="Prompt 2")
96
  generate_bt = gr.Button("Generate")
 
37
 
38
  def predict(prompt1, prompt2, merge_ratio, guidance, steps, sharpness, seed=1231231):
39
  torch.manual_seed(seed)
40
+ results = pipe(
41
  prompt1=prompt1,
42
  prompt2=prompt2,
43
  sv=merge_ratio,
 
48
  guidance_scale=guidance,
49
  lcm_origin_steps=50,
50
  output_type="pil",
51
+ # return_dict=False,
52
  )
53
+ nsfw_content_detected = (
54
+ results.nsfw_content_detected[0]
55
+ if "nsfw_content_detected" in results
56
+ else False
57
+ )
58
+ if nsfw_content_detected:
59
+ raise gr.Error("NSFW content detected. Please try another prompt.")
60
+ return results.images[0]
61
 
62
 
63
+ css = """
64
  #container{
65
  margin: 0 auto;
66
  max-width: 80rem;
 
81
  RTSD leverages the expertise provided by Latent Consistency Models (LCM). For more information about LCM,
82
  visit their website at [Latent Consistency Models](https://latent-consistency-models.github.io/).
83
 
84
+ """,
85
+ elem_id="intro",
86
  )
87
  with gr.Row():
88
  with gr.Column():
 
98
  sharpness = gr.Slider(
99
  value=1.0, minimum=0, maximum=1, step=0.001, label="Sharpness"
100
  )
101
+ seed = gr.Slider(
102
+ randomize=True, minimum=0, maximum=12013012031030, label="Seed"
103
+ )
104
  prompt1 = gr.Textbox(label="Prompt 1")
105
  prompt2 = gr.Textbox(label="Prompt 2")
106
  generate_bt = gr.Button("Generate")
lcm_txt2img/pipeline.py CHANGED
@@ -308,7 +308,7 @@ class LatentConsistencyModelPipeline(DiffusionPipeline):
308
  #denoised = denoised.to(prompt_embeds.dtype)
309
  if not output_type == "latent":
310
  image = self.vae.decode(denoised / self.vae.config.scaling_factor, return_dict=False)[0]
311
- #image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
312
  has_nsfw_concept = None
313
  else:
314
  image = denoised
 
308
  #denoised = denoised.to(prompt_embeds.dtype)
309
  if not output_type == "latent":
310
  image = self.vae.decode(denoised / self.vae.config.scaling_factor, return_dict=False)[0]
311
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
312
  has_nsfw_concept = None
313
  else:
314
  image = denoised