MegaTronX's picture
Update app.py
a6cafd5 verified
import spaces
import gradio as gr
import torch
from diffusers import FluxPipeline, FluxTransformer2DModel, FlowMatchEulerDiscreteScheduler
from huggingface_hub import hf_hub_download
from PIL import Image
import numpy as np
import random
# Only initialize GPU after spaces import
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Constants
BASE_MODEL = "black-forest-labs/FLUX.1-dev"
LORA_MODEL = "MegaTronX/SuicideGirl-FLUX"
LORA_WEIGHT_NAME = "SuicideGirls.safetensors"
LORA_SCALE = 0.8
# Upper bound for random seeds and for the seed slider in the UI.
MAX_SEED = np.iinfo(np.int32).max

# Build the base FLUX pipeline in bfloat16, load the LoRA weights and fuse
# them into the base model once at startup with a fixed strength.
pipe = FluxPipeline.from_pretrained(BASE_MODEL, torch_dtype=torch.bfloat16)
pipe.load_lora_weights(LORA_MODEL, weight_name=LORA_WEIGHT_NAME)
pipe.fuse_lora(lora_scale=LORA_SCALE)
# Move to the device detected above instead of hard-coding "cuda", so the
# app can still start (slowly) on a machine without CUDA.
pipe.to(device)
@spaces.GPU
def generate_image(
    prompt,
    width=768,
    height=1024,
    guidance_scale=3.5,
    num_inference_steps=24,
    seed=-1,
    num_images=1,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate images for *prompt* with the module-level LoRA-fused pipeline.

    Args:
        prompt: Text prompt passed to the pipeline.
        width: Output image width in pixels.
        height: Output image height in pixels.
        guidance_scale: Guidance strength passed to the pipeline.
        num_inference_steps: Number of denoising steps.
        seed: RNG seed. Any negative value (the UI uses -1) selects a random
            seed in ``[0, MAX_SEED]``.
        num_images: Number of images to generate for the prompt.
        progress: Gradio progress tracker; ``track_tqdm=True`` lets Gradio
            display the pipeline's tqdm progress.

    Returns:
        A ``(images, seed)`` tuple: the PIL images produced by the pipeline
        and the seed actually used, so a random run can be reproduced.
    """
    # Gradio numeric components may deliver floats; the generator seed, step
    # count, image count and pixel sizes must be plain ints.
    seed = int(seed)
    if seed < 0:
        seed = random.randint(0, MAX_SEED)
    # CPU generator: the seed maps to the same noise regardless of the
    # device the pipeline itself runs on.
    generator = torch.Generator().manual_seed(seed)
    images = pipe(
        prompt,
        width=int(width),
        height=int(height),
        guidance_scale=float(guidance_scale),
        num_inference_steps=int(num_inference_steps),
        generator=generator,
        output_type="pil",
        max_sequence_length=512,
        num_images_per_prompt=int(num_images),
    ).images
    return images, seed
# Gradio Interface
with gr.Blocks() as demo:
    gr.HTML("<h1><center>Flux LoRA Image Generator</center></h1>")

    # Prompt entry plus the button that triggers generation.
    with gr.Group():
        prompt = gr.Textbox(lines=3, label="Enter Your Prompt")
        generate_button = gr.Button("Generate", variant="primary")

    # Results: the generated images next to the seed that produced them.
    with gr.Row():
        image_output = gr.Gallery(preview=True, columns=2, label="Generated Images")
        seed_output = gr.Number(label="Seed Used")

    # Collapsed generation parameters; slider values match generate_image's defaults.
    with gr.Accordion("Advanced Options", open=False):
        width = gr.Slider(value=768, minimum=512, maximum=1280, step=8, label="Width")
        height = gr.Slider(value=1024, minimum=512, maximum=1280, step=8, label="Height")
        guidance_scale = gr.Slider(value=3.5, minimum=0, maximum=50, step=0.1, label="Guidance Scale")
        num_inference_steps = gr.Slider(value=24, minimum=1, maximum=50, step=1, label="Inference Steps")
        seed = gr.Slider(value=-1, minimum=-1, maximum=MAX_SEED, step=1, label="Seed (-1 for random)")
        num_images = gr.Slider(value=1, minimum=1, maximum=4, step=1, label="Number of Images")

    # Order must match generate_image's positional parameters.
    generation_inputs = [
        prompt,
        width,
        height,
        guidance_scale,
        num_inference_steps,
        seed,
        num_images,
    ]
    generation_outputs = [image_output, seed_output]
    generate_button.click(
        fn=generate_image,
        inputs=generation_inputs,
        outputs=generation_outputs,
    )
# Start the Gradio server only when this file is executed as a script, so
# importing the module (e.g. from tooling or tests) does not launch a server.
if __name__ == "__main__":
    demo.launch()