APG checkbox
Browse files
app.py
CHANGED
@@ -12,6 +12,8 @@ from f_lite import FLitePipeline
|
|
12 |
|
13 |
# Trick required because it is not a native diffusers model
|
14 |
from diffusers.pipelines.pipeline_loading_utils import LOADABLE_CLASSES, ALL_IMPORTABLE_CLASSES
|
|
|
|
|
15 |
LOADABLE_CLASSES["f_lite"] = LOADABLE_CLASSES["f_lite.model"] = {"DiT": ["save_pretrained", "from_pretrained"]}
|
16 |
ALL_IMPORTABLE_CLASSES["DiT"] = ["save_pretrained", "from_pretrained"]
|
17 |
|
@@ -26,7 +28,7 @@ else:
|
|
26 |
logging.warning("GEMINI_API_KEY not found in environment variables. Prompt enrichment will not work.")
|
27 |
|
28 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
29 |
-
model_repo_id = "Freepik/F-Lite"
|
30 |
|
31 |
if torch.cuda.is_available():
|
32 |
torch_dtype = torch.bfloat16
|
@@ -135,6 +137,7 @@ def infer(
|
|
135 |
guidance_scale,
|
136 |
num_inference_steps,
|
137 |
use_prompt_enrichment,
|
|
|
138 |
progress=gr.Progress(track_tqdm=True),
|
139 |
):
|
140 |
enriched_prompt_str = None
|
@@ -160,6 +163,7 @@ def infer(
|
|
160 |
width=width,
|
161 |
height=height,
|
162 |
generator=generator,
|
|
|
163 |
).images[0]
|
164 |
|
165 |
# Prepare Gradio updates for the enriched prompt display
|
@@ -287,6 +291,10 @@ with gr.Blocks(css=css, theme="ParityError/Interstellar") as demo:
|
|
287 |
step=0.1,
|
288 |
value=6,
|
289 |
)
|
|
|
|
|
|
|
|
|
290 |
|
291 |
num_inference_steps = gr.Slider(
|
292 |
label="Number of inference steps",
|
@@ -334,6 +342,7 @@ with gr.Blocks(css=css, theme="ParityError/Interstellar") as demo:
|
|
334 |
guidance_scale,
|
335 |
num_inference_steps,
|
336 |
use_prompt_enrichment,
|
|
|
337 |
],
|
338 |
outputs=[result, seed, enriched_prompt_display, enriched_prompt_text, enrichment_error],
|
339 |
)
|
@@ -342,4 +351,4 @@ with gr.Blocks(css=css, theme="ParityError/Interstellar") as demo:
|
|
342 |
gr.Markdown("[F-Lite Model Card and Weights](https://huggingface.co/Freepik/F-Lite)")
|
343 |
|
344 |
if __name__ == "__main__":
|
345 |
-
demo.launch()
|
|
|
12 |
|
13 |
# Trick required because it is not a native diffusers model
|
14 |
from diffusers.pipelines.pipeline_loading_utils import LOADABLE_CLASSES, ALL_IMPORTABLE_CLASSES
|
15 |
+
|
16 |
+
from f_lite.pipeline import APGConfig
|
17 |
LOADABLE_CLASSES["f_lite"] = LOADABLE_CLASSES["f_lite.model"] = {"DiT": ["save_pretrained", "from_pretrained"]}
|
18 |
ALL_IMPORTABLE_CLASSES["DiT"] = ["save_pretrained", "from_pretrained"]
|
19 |
|
|
|
28 |
logging.warning("GEMINI_API_KEY not found in environment variables. Prompt enrichment will not work.")
|
29 |
|
30 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
31 |
+
model_repo_id = "./grpo_hf"
|
32 |
|
33 |
if torch.cuda.is_available():
|
34 |
torch_dtype = torch.bfloat16
|
|
|
137 |
guidance_scale,
|
138 |
num_inference_steps,
|
139 |
use_prompt_enrichment,
|
140 |
+
enable_apg,
|
141 |
progress=gr.Progress(track_tqdm=True),
|
142 |
):
|
143 |
enriched_prompt_str = None
|
|
|
163 |
width=width,
|
164 |
height=height,
|
165 |
generator=generator,
|
166 |
+
apg_config=APGConfig(enabled=enable_apg)
|
167 |
).images[0]
|
168 |
|
169 |
# Prepare Gradio updates for the enriched prompt display
|
|
|
291 |
step=0.1,
|
292 |
value=6,
|
293 |
)
|
294 |
+
enable_apg = gr.Checkbox(
|
295 |
+
label="Enable APG",
|
296 |
+
value=True,
|
297 |
+
)
|
298 |
|
299 |
num_inference_steps = gr.Slider(
|
300 |
label="Number of inference steps",
|
|
|
342 |
guidance_scale,
|
343 |
num_inference_steps,
|
344 |
use_prompt_enrichment,
|
345 |
+
enable_apg,
|
346 |
],
|
347 |
outputs=[result, seed, enriched_prompt_display, enriched_prompt_text, enrichment_error],
|
348 |
)
|
|
|
351 |
gr.Markdown("[F-Lite Model Card and Weights](https://huggingface.co/Freepik/F-Lite)")
|
352 |
|
353 |
if __name__ == "__main__":
|
354 |
+
demo.launch() # server_name="0.0.0.0", share=True)
|