"""Gradio demo for URAE (Ultra-Resolution Adaptation with Ease) on FLUX.1.

Loads a FLUX.1 pipeline built from the local ``pipeline_flux`` /
``transformer_flux`` modules, registers the URAE LoRA adapters from
``Huage001/URAE`` and serves a text-to-image UI.
"""

import random

import gradio as gr
import numpy as np

# NOTE(review): ZeroGPU expects `spaces` to be imported before torch / CUDA code;
# keep it above torch and the local pipeline modules.
import spaces
import torch

from pipeline_flux import FluxPipeline
from transformer_flux import FluxTransformer2DModel

# FLUX.1 variant; also selects which 4K adapter `infer` activates.
flux_model = "schnell"
bfl_repo = f"black-forest-labs/FLUX.1-{flux_model}"
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.bfloat16

# The transformer comes from the local module and is attached after the rest of
# the pipeline is loaded with `transformer=None`.
transformer = FluxTransformer2DModel.from_pretrained(
    bfl_repo, subfolder="transformer", torch_dtype=dtype
)
pipe = FluxPipeline.from_pretrained(bfl_repo, transformer=None, torch_dtype=dtype)
pipe.transformer = transformer

# Disable dynamic shifting and use a fixed time shift instead.
# NOTE(review): `time_shift` is not a stock diffusers scheduler option; presumably
# it is read by the local pipeline — confirm in pipeline_flux.
pipe.scheduler.config.use_dynamic_shifting = False
pipe.scheduler.config.time_shift = 10
# pipe.enable_model_cpu_offload()
pipe = pipe.to(device)

# Register all URAE adapters up front; `infer` switches between them with
# `set_adapters`. Only "2k" and f"4k_{flux_model}" are ever activated.
URAE_REPO = "Huage001/URAE"
for _weight_name, _adapter_name in (
    ("urae_2k_adapter.safetensors", "2k"),
    ("urae_4k_adapter_lora_conversion_dev.safetensors", "4k_dev"),
    ("urae_4k_adapter_lora_conversion_schnell.safetensors", "4k_schnell"),
):
    pipe.load_lora_weights(
        URAE_REPO, weight_name=_weight_name, adapter_name=_adapter_name
    )

MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 4096
# When True, `infer` is wrapped with `spaces.GPU` (ZeroGPU) below.
USE_ZERO_GPU = True

# UI model choice -> LoRA adapter name registered above.
_ADAPTER_BY_MODEL = {"2k": "2k", "4k": f"4k_{flux_model}"}


def infer(
    prompt,
    seed,
    randomize_seed,
    width,
    height,
    num_inference_steps,
    model="2k",
    progress=gr.Progress(track_tqdm=True),
):
    """Generate one image from *prompt* using the URAE adapter for *model*.

    Args:
        prompt: Text prompt.
        seed: Generator seed; ignored when *randomize_seed* is true.
        randomize_seed: Draw a fresh seed in ``[0, MAX_SEED]`` instead.
        width: Output width in pixels.
        height: Output height in pixels.
        num_inference_steps: Number of denoising steps.
        model: ``"2k"`` or ``"4k"``; selects the LoRA adapter to activate.
        progress: Gradio progress tracker; ``track_tqdm=True`` lets Gradio
            mirror tqdm bars in the UI. Not referenced directly.

    Returns:
        ``(image, seed)`` — the seed actually used, so the UI can show it.

    Raises:
        ValueError: If *model* is not one of the supported choices.
    """
    print("Using model:", model)
    adapter = _ADAPTER_BY_MODEL.get(model)
    if adapter is None:
        # Reject unknown values rather than silently reusing whichever adapter
        # the previous call left active.
        raise ValueError(
            f"Unknown model {model!r}; expected one of {sorted(_ADAPTER_BY_MODEL)}"
        )
    # Decode the VAE in tiles to keep memory bounded at these large resolutions.
    pipe.vae.enable_tiling(True)
    pipe.set_adapters(adapter)

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # Slider values may arrive as floats; manual_seed and the pipeline need ints.
    seed = int(seed)
    generator = torch.Generator().manual_seed(seed)

    image = pipe(
        prompt=prompt,
        guidance_scale=0,  # guidance disabled
        num_inference_steps=int(num_inference_steps),
        width=int(width),
        height=int(height),
        max_sequence_length=256,
        # NOTE(review): URAE-specific options of the local FluxPipeline
        # (presumably NTK-style positional scaling and resolution-proportional
        # attention) — confirm in pipeline_flux.
        ntk_factor=10,
        proportional_attention=True,
        generator=generator,
    ).images[0]
    return image, seed


if USE_ZERO_GPU:
    # Request up to 360 s of ZeroGPU time per call.
    infer = spaces.GPU(infer, duration=360)

examples = [
    "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
    "An astronaut riding a green horse",
    "A delicious ceviche cheesecake slice",
]

# CSS comments must use /* ... */ — '#' is not comment syntax in CSS.
css = """
#maincontainer {
    display: flex;
}
#col1 {
    margin: 0 auto;
    max-width: 50%;
}
#col2 {
    margin: 0 auto;
    /* max-width: 40px; */
}
"""

head = """> ***U*ltra-*R*esolution *A*daptation with *E*ase**

<div style="text-align: center; display: flex; justify-content: left; gap: 5px;">
<a href="https://arxiv.org/abs/2503.16322"><img src="https://img.shields.io/badge/arXiv-2503.16322-A42C25.svg" alt="arXiv"></a>
<a href="https://huggingface.co/Huage001/URAE"><img src="https://img.shields.io/badge/🤗_HuggingFace-Model-ffbd45.svg" alt="HuggingFace"></a>
<a href="https://huggingface.co/spaces/Yuanshi/URAE"><img src="https://img.shields.io/badge/🤗_HuggingFace-Space-ffbd45.svg" alt="HuggingFace"></a>
<a href="https://huggingface.co/spaces/Yuanshi/URAE_dev"><img src="https://img.shields.io/badge/🤗_HuggingFace-Space-ffbd45.svg" alt="HuggingFace"></a>
</div>
"""

with gr.Blocks(css=css) as demo:
    gr.Markdown(f"# URAE (FLUX.1 {flux_model}) \n" + head)
    with gr.Row(elem_id="maincontainer"):
        with gr.Column(elem_id="col1"):
            gr.Markdown("### Prompt:")
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )
            gr.Examples(examples=examples, inputs=[prompt])
            run_button = gr.Button("Generate", scale=1, variant="primary")

            gr.Markdown("### Setting:")
            # The model selector is currently disabled, so `infer` always runs
            # with its default model="2k".
            # model = gr.Radio(
            #     label="Model",
            #     choices=[
            #         ("2K model", "2k"),
            #         ("4K model (beta)", "4k"),
            #     ],
            #     value="2k",
            # )
            with gr.Row():
                width = gr.Slider(
                    label="Width",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=2048,  # Replace with defaults that work for your model
                )
                height = gr.Slider(
                    label="Height",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=2048,  # Replace with defaults that work for your model
                )
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
            num_inference_steps = gr.Slider(
                label="Number of inference steps",
                minimum=1,
                maximum=50,
                step=1,
                value=4,  # Replace with defaults that work for your model
            )
        with gr.Column(elem_id="col2"):
            result = gr.Image(label="Result", show_label=False)

    # `infer` returns the seed it used, which is written back to the seed slider.
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=[
            prompt,
            # model,
            seed,
            randomize_seed,
            width,
            height,
            num_inference_steps,
        ],
        outputs=[result, seed],
    )

if __name__ == "__main__":
    demo.launch()