Jeff850's picture
Update app.py
0711ded verified
raw
history blame
5.26 kB
import spaces
import os
import numpy as np
import gradio as gr
import json
import torch
from diffusers import DiffusionPipeline
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
# Run on the GPU when one is visible to torch, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# The Hugging Face access token is read from the 'waffles' environment variable.
hf_token = os.getenv('waffles')

# Fail fast with a clear message instead of an opaque error during download.
if not hf_token:
    raise ValueError("Hugging Face API token not found. Please set the 'waffles' environment variable.")

# Load the LoRA catalogue. JSON is UTF-8 by specification, so the encoding is
# pinned rather than left to the platform's locale default.
with open('loras.json', 'r', encoding='utf-8') as f:
    loras = json.load(f)

# Initialize the base model with authentication and place it on the device.
pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16,
    token=hf_token,
).to(device)

# Upper bound used for both the seed slider and random seed generation.
MAX_SEED = 2**32 - 1
@spaces.GPU(duration=90)
def run_lora(prompt, cfg_scale, steps, selected_repo, randomize_seed, seed, width, height, lora_scale):
    """Generate one image with the base pipeline plus the selected LoRA.

    Args:
        prompt: User prompt; the LoRA's trigger word is appended to it.
        cfg_scale: Guidance scale passed to the pipeline.
        steps: Number of inference steps.
        selected_repo: Value of the ``"repo"`` field of the chosen entry in
            ``loras``; falsy when nothing has been selected yet.
        randomize_seed: When true, ``seed`` is replaced by a random value.
        seed: Seed for the torch generator (ignored if ``randomize_seed``).
        width: Output width in pixels.
        height: Output height in pixels.
        lora_scale: LoRA strength, forwarded as ``joint_attention_kwargs["scale"]``.

    Returns:
        A ``(image, seed)`` tuple: the first generated image and the seed
        that was actually used.

    Raises:
        gr.Error: If no LoRA is selected or the selection is not in ``loras``.
    """
    if not selected_repo:
        raise gr.Error("You must select a LoRA before proceeding.")

    selected_lora = next((lora for lora in loras if lora["repo"] == selected_repo), None)
    if not selected_lora:
        raise gr.Error("Selected LoRA not found.")

    lora_path = selected_lora["repo"]
    trigger_word = selected_lora["trigger_word"]

    # Set random seed for reproducibility. UI sliders may deliver floats,
    # while Generator.manual_seed requires an int.
    if randomize_seed:
        seed = torch.randint(0, MAX_SEED, (1,)).item()
    seed = int(seed)
    generator = torch.Generator(device=device).manual_seed(seed)

    # Every call ends by parking the pipeline on the CPU, so it has to be
    # moved back before inference; otherwise only the first call would run
    # on the accelerator.
    pipe.to(device)

    try:
        # Load LoRA weights (optionally a specific weight file in the repo).
        if "weights" in selected_lora:
            pipe.load_lora_weights(lora_path, weight_name=selected_lora["weights"])
        else:
            pipe.load_lora_weights(lora_path)

        # Generate image
        image = pipe(
            prompt=f"{prompt} {trigger_word}",
            num_inference_steps=int(steps),
            guidance_scale=cfg_scale,
            width=int(width),
            height=int(height),
            generator=generator,
            joint_attention_kwargs={"scale": lora_scale},
        ).images[0]
    finally:
        # Always reset the model to CPU and unload the LoRA, even when
        # loading or inference fails, so adapters never stack across calls.
        pipe.to("cpu")
        pipe.unload_lora_weights()

    return image, seed
# Stylesheet handed to gr.Blocks: stretches the Generate button to the full
# row height and centres/sizes the title banner and its logo.
css = "\n".join(
    (
        "",  # leading newline of the stylesheet
        "#gen_btn{height: 100%}",
        "#title{text-align: center;}",
        "#title h1{font-size: 3em; display:inline-flex; align-items:center}",
        "#title img{width: 100px; margin-right: 0.5em}",
        "",  # trailing newline of the stylesheet
    )
)
def update_selection(index: gr.SelectData, width, height):
    """Refresh the prompt hint, selection label and canvas size for a gallery pick.

    Gradio only hands select-event details to a handler parameter that is
    type-hinted with ``gr.SelectData``; the listener's ``inputs`` list then
    fills ``width`` and ``height``. Without the annotation the two inputs
    would be bound to ``index`` and ``width`` and the call would fail.

    Args:
        index: The gallery select event (its ``.index`` is the position of
            the chosen item in ``loras``). A plain integer position is also
            accepted for direct calls.
        width: Current width slider value, returned unchanged unless the
            LoRA declares an aspect.
        height: Current height slider value, returned unchanged unless the
            LoRA declares an aspect.

    Returns:
        A 4-tuple: a prompt placeholder update, the "Selected" markdown
        text, the width and the height.
    """
    lora_index = index.index if isinstance(index, gr.SelectData) else index
    selected_lora = loras[lora_index]

    new_placeholder = f"Type a prompt for {selected_lora['title']}"
    lora_repo = selected_lora["repo"]
    updated_text = f"### Selected: [{lora_repo}]"

    # LoRAs may declare a preferred aspect; any other value keeps the
    # caller's current dimensions.
    aspect = selected_lora.get("aspect")
    if aspect == "portrait":
        width = 768
        height = 1024
    elif aspect == "landscape":
        width = 1024
        height = 768

    return gr.update(placeholder=new_placeholder), updated_text, width, height
def _remember_selected_repo(evt: gr.SelectData):
    """Return the repo id of the gallery item that was just selected."""
    return loras[evt.index]["repo"]


# Build the UI: title, prompt row, LoRA gallery + result, advanced settings.
with gr.Blocks(theme="Nymbo/Nymbo_Theme", css=css) as app:
    title = gr.HTML(
        """<h1><img src="https://huggingface.co/spaces/multimodalart/flux-lora-the-explorer/resolve/main/flux_lora.png" alt="LoRA"> FLUX LoRA the Explorer</h1>""",
        elem_id="title",
    )
    # Holds the repo id of the selected LoRA (run_lora looks entries up by
    # their "repo" field); None until the user picks something.
    selected_index = gr.State(None)

    with gr.Row():
        with gr.Column(scale=3):
            prompt = gr.Textbox(label="Prompt", lines=1, placeholder="Type a prompt after selecting a LoRA")
        with gr.Column(scale=1, elem_id="gen_column"):
            generate_button = gr.Button("Generate", variant="primary", elem_id="gen_btn")

    with gr.Row():
        with gr.Column(scale=3):
            selected_info = gr.Markdown("")
            gallery = gr.Gallery(
                [(item["image"], item["title"]) for item in loras],
                label="LoRA Gallery",
                allow_preview=False,
                columns=3,
            )
        with gr.Column(scale=4):
            result = gr.Image(label="Generated Image")

    with gr.Row():
        with gr.Accordion("Advanced Settings", open=False):
            with gr.Column():
                with gr.Row():
                    cfg_scale = gr.Slider(label="CFG Scale", minimum=1, maximum=20, step=0.5, value=3.5)
                    steps = gr.Slider(label="Steps", minimum=1, maximum=50, step=1, value=28)
                with gr.Row():
                    width = gr.Slider(label="Width", minimum=256, maximum=1536, step=64, value=1024)
                    height = gr.Slider(label="Height", minimum=256, maximum=1536, step=64, value=1024)
                with gr.Row():
                    randomize_seed = gr.Checkbox(True, label="Randomize seed")
                    seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0, randomize=True)
                    lora_scale = gr.Slider(label="LoRA Scale", minimum=0, maximum=1, step=0.01, value=0.95)

    # Update the prompt hint, selection label and canvas size on selection.
    gallery.select(
        update_selection,
        inputs=[width, height],
        outputs=[prompt, selected_info, width, height],
    )
    # Record which LoRA was picked. Without this listener the state stays
    # None forever and Generate always reports that no LoRA is selected.
    # It is registered separately so the state is stored independently of
    # the listener above.
    gallery.select(
        _remember_selected_repo,
        outputs=[selected_index],
    )

    generate_button.click(
        fn=run_lora,
        inputs=[prompt, cfg_scale, steps, selected_index, randomize_seed, seed, width, height, lora_scale],
        outputs=[result, seed],
    )

app.queue()
app.launch()