ideprado commited on
Commit
fc8839e
·
1 Parent(s): 9cbbe55

APG checkbox

Browse files
Files changed (1) hide show
  1. app.py +11 -2
app.py CHANGED
@@ -12,6 +12,8 @@ from f_lite import FLitePipeline
12
 
13
  # Trick required because it is not a native diffusers model
14
  from diffusers.pipelines.pipeline_loading_utils import LOADABLE_CLASSES, ALL_IMPORTABLE_CLASSES
 
 
15
  LOADABLE_CLASSES["f_lite"] = LOADABLE_CLASSES["f_lite.model"] = {"DiT": ["save_pretrained", "from_pretrained"]}
16
  ALL_IMPORTABLE_CLASSES["DiT"] = ["save_pretrained", "from_pretrained"]
17
 
@@ -26,7 +28,7 @@ else:
26
  logging.warning("GEMINI_API_KEY not found in environment variables. Prompt enrichment will not work.")
27
 
28
  device = "cuda" if torch.cuda.is_available() else "cpu"
29
- model_repo_id = "Freepik/F-Lite"
30
 
31
  if torch.cuda.is_available():
32
  torch_dtype = torch.bfloat16
@@ -135,6 +137,7 @@ def infer(
135
  guidance_scale,
136
  num_inference_steps,
137
  use_prompt_enrichment,
 
138
  progress=gr.Progress(track_tqdm=True),
139
  ):
140
  enriched_prompt_str = None
@@ -160,6 +163,7 @@ def infer(
160
  width=width,
161
  height=height,
162
  generator=generator,
 
163
  ).images[0]
164
 
165
  # Prepare Gradio updates for the enriched prompt display
@@ -287,6 +291,10 @@ with gr.Blocks(css=css, theme="ParityError/Interstellar") as demo:
287
  step=0.1,
288
  value=6,
289
  )
 
 
 
 
290
 
291
  num_inference_steps = gr.Slider(
292
  label="Number of inference steps",
@@ -334,6 +342,7 @@ with gr.Blocks(css=css, theme="ParityError/Interstellar") as demo:
334
  guidance_scale,
335
  num_inference_steps,
336
  use_prompt_enrichment,
 
337
  ],
338
  outputs=[result, seed, enriched_prompt_display, enriched_prompt_text, enrichment_error],
339
  )
@@ -342,4 +351,4 @@ with gr.Blocks(css=css, theme="ParityError/Interstellar") as demo:
342
  gr.Markdown("[F-Lite Model Card and Weights](https://huggingface.co/Freepik/F-Lite)")
343
 
344
  if __name__ == "__main__":
345
- demo.launch()
 
12
 
13
  # Trick required because it is not a native diffusers model
14
  from diffusers.pipelines.pipeline_loading_utils import LOADABLE_CLASSES, ALL_IMPORTABLE_CLASSES
15
+
16
+ from f_lite.pipeline import APGConfig
17
  LOADABLE_CLASSES["f_lite"] = LOADABLE_CLASSES["f_lite.model"] = {"DiT": ["save_pretrained", "from_pretrained"]}
18
  ALL_IMPORTABLE_CLASSES["DiT"] = ["save_pretrained", "from_pretrained"]
19
 
 
28
  logging.warning("GEMINI_API_KEY not found in environment variables. Prompt enrichment will not work.")
29
 
30
  device = "cuda" if torch.cuda.is_available() else "cpu"
31
+ model_repo_id = "./grpo_hf"
32
 
33
  if torch.cuda.is_available():
34
  torch_dtype = torch.bfloat16
 
137
  guidance_scale,
138
  num_inference_steps,
139
  use_prompt_enrichment,
140
+ enable_apg,
141
  progress=gr.Progress(track_tqdm=True),
142
  ):
143
  enriched_prompt_str = None
 
163
  width=width,
164
  height=height,
165
  generator=generator,
166
+ apg_config=APGConfig(enabled=enable_apg)
167
  ).images[0]
168
 
169
  # Prepare Gradio updates for the enriched prompt display
 
291
  step=0.1,
292
  value=6,
293
  )
294
+ enable_apg = gr.Checkbox(
295
+ label="Enable APG",
296
+ value=True,
297
+ )
298
 
299
  num_inference_steps = gr.Slider(
300
  label="Number of inference steps",
 
342
  guidance_scale,
343
  num_inference_steps,
344
  use_prompt_enrichment,
345
+ enable_apg,
346
  ],
347
  outputs=[result, seed, enriched_prompt_display, enriched_prompt_text, enrichment_error],
348
  )
 
351
  gr.Markdown("[F-Lite Model Card and Weights](https://huggingface.co/Freepik/F-Lite)")
352
 
353
  if __name__ == "__main__":
354
+ demo.launch() # server_name="0.0.0.0", share=True)