import spaces
import os
import numpy as np
import gradio as gr
import json
import torch
from diffusers import DiffusionPipeline
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
# Define the device
device = "cuda" if torch.cuda.is_available() else "cpu"
# Use the 'waffles' environment variable as the access token
hf_token = os.getenv('waffles')
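# (Illustrative) when running locally, export the token first, e.g. `export waffles=hf_xxx...`;
# on a Hugging Face Space, add a secret named 'waffles' containing your token.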
# Ensure the token is loaded correctly
if not hf_token:
    raise ValueError("Hugging Face API token not found. Please set the 'waffles' environment variable.")
# Load LoRAs from JSON file
with open('loras.json', 'r') as f:
    loras = json.load(f)
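# Each entry in loras.json is expected to provide at least "repo" and "trigger_word",
# plus an optional "weights" file name. An illustrative entry (values are placeholders):
# {"repo": "XLabs-AI/flux-RealismLora", "trigger_word": "", "weights": "lora.safetensors"}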
# Initialize the base model with authentication and specify the device
pipe = DiffusionPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev",
torch_dtype=torch.bfloat16,
token=hf_token
).to(device)
# Define MAX_SEED
MAX_SEED = 2**32 - 1
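# Largest unsigned 32-bit integer, so generated seeds stay in a range common RNG APIs accept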
@spaces.GPU(duration=90)
def run_lora(prompt, cfg_scale, steps, selected_repo, randomize_seed, seed, width, height, lora_scale):
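    """Generate a single image with the selected FLUX LoRA applied.

    Loads the LoRA weights listed in loras.json for `selected_repo`, appends the
    LoRA's trigger word to the prompt, runs the FLUX.1-dev pipeline, and returns
    the generated image together with the seed that was actually used.
    """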
    if not selected_repo:
        raise gr.Error("You must select a LoRA before proceeding.")
    selected_lora = next((lora for lora in loras if lora["repo"] == selected_repo), None)
    if not selected_lora:
        raise gr.Error("Selected LoRA not found.")
    lora_path = selected_lora["repo"]
    trigger_word = selected_lora["trigger_word"]
    # Load LoRA weights
    if "weights" in selected_lora:
        pipe.load_lora_weights(lora_path, weight_name=selected_lora["weights"])
    else:
        pipe.load_lora_weights(lora_path)
    # Set random seed for reproducibility
    if randomize_seed:
        seed = torch.randint(0, MAX_SEED, (1,)).item()
    # Move the pipeline back onto the GPU; it is parked on the CPU after every run
    pipe.to(device)
    # Generate image
    generator = torch.Generator(device=device).manual_seed(seed)
    image = pipe(
        prompt=f"{prompt} {trigger_word}",
        num_inference_steps=steps,
        guidance_scale=cfg_scale,
        width=width,
        height=height,
        generator=generator,
        joint_attention_kwargs={"scale": lora_scale},
    ).images[0]
    # Reset the model to CPU and unload LoRA weights to free up memory
    pipe.to("cpu")
    pipe.unload_lora_weights()
    return image, seed
with gr.Blocks() as app:
    with gr.Row():
        with gr.Column(scale=3):
            prompt = gr.Textbox(label="Prompt", lines=5, placeholder="Enter your prompt")
            lora_dropdown = gr.Dropdown(
                label="Select LoRA",
                choices=[lora["repo"] for lora in loras],
                value="XLabs-AI/flux-RealismLora",
            )
        with gr.Column(scale=1):
            generate_button = gr.Button("Generate", variant="primary")
    with gr.Row():
        result = gr.Image(label="Generated Image")
        seed = gr.Number(label="Seed", value=0, interactive=False)
    with gr.Accordion("Advanced Settings", open=False):
        cfg_scale = gr.Slider(label="CFG Scale", minimum=1, maximum=20, step=0.5, value=3.5)
        steps = gr.Slider(label="Steps", minimum=1, maximum=50, step=1, value=28)
        width = gr.Slider(label="Width", minimum=256, maximum=1536, step=64, value=1024)
        height = gr.Slider(label="Height", minimum=256, maximum=1536, step=64, value=1024)
        randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
        lora_scale = gr.Slider(label="LoRA Scale", minimum=0, maximum=1, step=0.01, value=0.95)
    generate_button.click(
        run_lora,
        inputs=[prompt, cfg_scale, steps, lora_dropdown, randomize_seed, seed, width, height, lora_scale],
        outputs=[result, seed]
    )
app.launch()