soonfactory4 / app.py
AlekseyCalvin's picture
Update app.py
5741f84 verified
raw
history blame
8.23 kB
import gradio as gr
import json
import logging
import torch
from PIL import Image
import spaces
from diffusers import DiffusionPipeline
import copy
import random
import time
from huggingface_hub import hf_hub_download
from accelerate.utils import set_module_tensor_to_device, compute_module_sizes
from accelerate import init_empty_weights
from convert_nf4_flux import replace_with_bnb_linear, create_quantized_param, check_quantized_param
from diffusers import FluxTransformer2DModel, FluxPipeline
import safetensors.torch
import gc
import torch
# --- Transformer loading ---------------------------------------------------
# Working dtype for the transformer, plus a feature check for float8 support
# in the installed torch build.
dtype = torch.bfloat16
is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn")

# Download the transformer checkpoint from the Hub and read it into memory.
ckpt_path = hf_hub_download("ABDALLALSWAITI/Maxwell", filename="diffusion_pytorch_model.safetensors")
original_state_dict = safetensors.torch.load_file(ckpt_path)

# Instantiate the model from its config under `init_empty_weights`, so the
# real tensors are only supplied by the loading loop below.
with init_empty_weights():
    config = FluxTransformer2DModel.load_config("ABDALLALSWAITI/Maxwell")
    model = FluxTransformer2DModel.from_config(config).to(dtype)
    # A set gives O(1) membership tests in the loop below.
    expected_state_dict_keys = set(model.state_dict().keys())

# NOTE(review): `replace_with_bnb_linear` is imported at the top of the file
# but never called before this loop -- confirm whether the checkpoint needs
# the linear layers swapped for their quantized counterparts first.

# Load the state dict into the quantized model
for param_name, param in original_state_dict.items():
    # Ignore checkpoint entries the model has no slot for.
    if param_name not in expected_state_dict_keys:
        continue

    # Floating-point tensors are cast to the working dtype, except float8
    # tensors, which are kept as stored.
    is_param_float8_e4m3fn = is_torch_e4m3fn_available and param.dtype == torch.float8_e4m3fn
    if torch.is_floating_point(param) and not is_param_float8_e4m3fn:
        param = param.to(dtype)

    # device=0 / target_device=0: presumably the first CUDA device -- confirm.
    if not check_quantized_param(model, param_name):
        set_module_tensor_to_device(model, param_name, device=0, value=param)
    else:
        create_quantized_param(
            model, param, param_name, target_device=0, state_dict=original_state_dict, pre_quantized=True
        )

# Clean up
del original_state_dict
gc.collect()

# Print model size in MiB (bytes / 1024 / 1024; the divisor was previously
# mistyped as 1204).
print(compute_module_sizes(model)[""] / 1024 / 1024)

# Assemble the full pipeline around the custom transformer and turn on
# model CPU offload.
pipe = FluxPipeline.from_pretrained("black-forest-labs/flux.1-dev", transformer=model, torch_dtype=dtype)
pipe.enable_model_cpu_offload()
# Load the LoRA catalogue from the JSON file. The encoding is given
# explicitly so decoding does not depend on the platform's default locale.
with open('loras.json', 'r', encoding='utf-8') as f:
    loras = json.load(f)

# Upper bound used for the seed slider and for random seed generation.
MAX_SEED = 2**32-1
class calculateDuration:
    """Context manager that times its body and prints the elapsed seconds.

    On exit it stores `start_time`, `end_time` and `elapsed_time` on the
    instance and prints one line. If `activity_name` is non-empty the line
    includes it. Exceptions raised inside the body are not suppressed.
    """

    def __init__(self, activity_name=""):
        self.activity_name = activity_name

    def __enter__(self):
        self.start_time = time.time()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.end_time = time.time()
        self.elapsed_time = self.end_time - self.start_time
        # Pick the labelled or unlabelled report line, then emit it once.
        report = (
            f"Elapsed time for {self.activity_name}: {self.elapsed_time:.6f} seconds"
            if self.activity_name
            else f"Elapsed time: {self.elapsed_time:.6f} seconds"
        )
        print(report)
def update_selection(evt: gr.SelectData, width, height):
    """Handle a gallery selection.

    Looks up the LoRA entry at `evt.index` in `loras` and returns, in order:
    a prompt placeholder update, a Markdown line linking the LoRA repo, the
    selected index, and the width/height. Width and height are replaced by a
    fixed preset when the entry's "aspect" is "portrait" or "landscape";
    otherwise the incoming values are passed through unchanged.
    """
    entry = loras[evt.index]
    placeholder_text = f"Type a prompt for {entry['title']}"
    repo_id = entry["repo"]
    selection_markdown = f"### Selected: [{repo_id}](https://huggingface.co/{repo_id}) ✨"

    # Preset (width, height) per declared aspect; anything else keeps the
    # current slider values.
    aspect_presets = {"portrait": (768, 1024), "landscape": (1024, 768)}
    aspect = entry.get("aspect")
    if aspect in ("portrait", "landscape"):
        width, height = aspect_presets[aspect]

    return (
        gr.update(placeholder=placeholder_text),
        selection_markdown,
        evt.index,
        width,
        height,
    )
@spaces.GPU(duration=70)
def generate_image(prompt, trigger_word, steps, seed, cfg_scale, width, height, lora_scale, progress):
    """Run the pipeline once and return the first generated image.

    The text sent to the pipeline is `prompt` followed by a space and
    `trigger_word`. `seed` seeds a CUDA generator, and `lora_scale` is passed
    as the "scale" entry of `joint_attention_kwargs`. `progress` is accepted
    but not used inside this function.
    """
    pipe.to("cuda")
    rng = torch.Generator(device="cuda").manual_seed(seed)

    with calculateDuration("Generating image"):
        # Collect the call arguments first, then invoke the pipeline.
        call_kwargs = {
            "prompt": f"{prompt} {trigger_word}",
            "num_inference_steps": steps,
            "guidance_scale": cfg_scale,
            "width": width,
            "height": height,
            "generator": rng,
            "joint_attention_kwargs": {"scale": lora_scale},
        }
        result = pipe(**call_kwargs)
        image = result.images[0]

    return image
def run_lora(prompt, cfg_scale, steps, selected_index, randomize_seed, seed, width, height, lora_scale, progress=gr.Progress(track_tqdm=True)):
    """Load the selected LoRA, generate one image, then unload the LoRA.

    Args:
        prompt: User prompt text.
        cfg_scale: Guidance scale forwarded to `generate_image`.
        steps: Number of inference steps forwarded to `generate_image`.
        selected_index: Index into `loras` of the chosen entry, or None.
        randomize_seed: When truthy, `seed` is replaced by a random value
            in [0, MAX_SEED].
        seed: Seed used when `randomize_seed` is falsy.
        width: Output width forwarded to `generate_image`.
        height: Output height forwarded to `generate_image`.
        lora_scale: LoRA scale forwarded to `generate_image`.
        progress: Gradio progress tracker, forwarded to `generate_image`.

    Returns:
        A `(image, seed)` tuple, where `seed` is the seed actually used.

    Raises:
        gr.Error: If no LoRA has been selected (`selected_index` is None).
    """
    if selected_index is None:
        raise gr.Error("You must select a LoRA before proceeding.")

    selected_lora = loras[selected_index]
    lora_path = selected_lora["repo"]
    trigger_word = selected_lora["trigger_word"]

    # Load LoRA weights
    with calculateDuration(f"Loading LoRA weights for {selected_lora['title']}"):
        if "weights" in selected_lora:
            pipe.load_lora_weights(lora_path, weight_name=selected_lora["weights"])
        else:
            pipe.load_lora_weights(lora_path)

    try:
        # Set random seed for reproducibility
        with calculateDuration("Randomizing seed"):
            if randomize_seed:
                seed = random.randint(0, MAX_SEED)

        image = generate_image(prompt, trigger_word, steps, seed, cfg_scale, width, height, lora_scale, progress)
    finally:
        # Always restore the shared pipeline, even when generation fails;
        # otherwise the next request would load its LoRA on top of this one.
        pipe.to("cpu")
        pipe.unload_lora_weights()

    return image, seed


run_lora.zerogpu = True
# Stylesheet passed to the Blocks app below.
css = '''
#gen_btn{height: 100%}
#title{text-align: center}
#title h1{font-size: 3em; display:inline-flex; align-items:center}
#title img{width: 100px; margin-right: 0.5em}
#gallery .grid-wrap{height: 10vh}
'''

with gr.Blocks(theme=gr.themes.Soft(), css=css) as app:
    # Page header with logo.
    title = gr.HTML(
        """<h1><img src="https://huggingface.co/spaces/multimodalart/flux-lora-the-explorer/resolve/main/flux_lora.png" alt="LoRA"> SOONfactory </h1>""",
        elem_id="title",
    )
    # Short description of what the app is running.
    info_blob = gr.HTML(
        """<div id="info_blob"> Activist & Futurealist LoRa-stocked Img Manufactory (on Flux Merged)</div>"""
    )
    # Prompt-prefix hints, one per LoRA (rebinds `info_blob`; both blobs are
    # still created inside the layout).
    info_blob = gr.HTML(
        """<div id="info_blob">Prephrase prompts w/: 1.RCA style 2. HST style autochrome 3.TOK hybrid 4.2004 photo 5.HST style 6.LEN Vladimir Lenin 7.TOK portra 8.HST portrait 9.flmft 10.HST in Peterhof 11.photo 12.pficonics 13.wh3r3sw4ld0 14.retrofuturism 15.vintage cover </div>"""
    )

    # Holds the index of the LoRA currently picked in the gallery.
    selected_index = gr.State(None)

    # Prompt box and generate button.
    with gr.Row():
        with gr.Column(scale=3):
            prompt = gr.Textbox(label="Prompt", lines=1, placeholder="Select LoRa/Style & type prompt!")
        with gr.Column(scale=1, elem_id="gen_column"):
            generate_button = gr.Button("Generate", variant="primary", elem_id="gen_btn")

    # LoRA gallery on the left, generated image on the right.
    with gr.Row():
        with gr.Column(scale=3):
            selected_info = gr.Markdown("")
            gallery_items = [(item["image"], item["title"]) for item in loras]
            gallery = gr.Gallery(
                gallery_items,
                label="LoRA Inventory",
                allow_preview=False,
                columns=3,
                elem_id="gallery"
            )
        with gr.Column(scale=4):
            result = gr.Image(label="Generated Image")

    # Generation parameters.
    with gr.Row():
        with gr.Accordion("Advanced Settings", open=True):
            with gr.Column():
                with gr.Row():
                    cfg_scale = gr.Slider(label="CFG Scale", minimum=1, maximum=20, step=1, value=3)
                    steps = gr.Slider(label="Steps", minimum=1, maximum=50, step=1, value=6)
                with gr.Row():
                    width = gr.Slider(label="Width", minimum=256, maximum=1536, step=64, value=768)
                    height = gr.Slider(label="Height", minimum=256, maximum=1536, step=64, value=768)
                with gr.Row():
                    randomize_seed = gr.Checkbox(True, label="Randomize seed")
                    seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0, randomize=True)
                    lora_scale = gr.Slider(label="LoRA Scale", minimum=0, maximum=1, step=0.01, value=0.95)

    # Selecting a gallery item updates the prompt placeholder, the info line,
    # the stored index and the size sliders.
    gallery.select(
        update_selection,
        inputs=[width, height],
        outputs=[prompt, selected_info, selected_index, width, height]
    )

    # Clicking the button or submitting the prompt both run generation.
    gr.on(
        triggers=[generate_button.click, prompt.submit],
        fn=run_lora,
        inputs=[prompt, cfg_scale, steps, selected_index, randomize_seed, seed, width, height, lora_scale],
        outputs=[result, seed]
    )
# Enable the request queue and start the server exactly once. The former
# trailing `app.launch()` was redundant: it would only start a second launch
# (without `show_error`) after the first call returned.
app.queue(default_concurrency_limit=2).launch(show_error=True)