import spaces import gradio as gr import random import os import time import torch from diffusers import FluxPipeline DEVICE = "cuda" if torch.cuda.is_available() else "cpu" print(f"Using device: {DEVICE}") DEFAULT_HEIGHT = 1024 DEFAULT_WIDTH = 1024 DEFAULT_GUIDANCE_SCALE = 3.5 DEFAULT_NUM_INFERENCE_STEPS = 15 DEFAULT_MAX_SEQUENCE_LENGTH = 512 HF_TOKEN = os.environ.get("HF_ACCESS_TOKEN") # Cache for the pipeline CACHED_PIPE = None def load_bnb_4bit_pipeline(): """Load the 4-bit quantized pipeline""" global CACHED_PIPE if CACHED_PIPE is not None: return CACHED_PIPE print("Loading 4-bit BNB pipeline...") MODEL_ID = "derekl35/FLUX.1-dev-nf4" start_time = time.time() try: pipe = FluxPipeline.from_pretrained( MODEL_ID, torch_dtype=torch.bfloat16 ) pipe.enable_model_cpu_offload() end_time = time.time() mem_reserved = torch.cuda.memory_reserved(0)/1024**3 if DEVICE == "cuda" else 0 print(f"4-bit BNB pipeline loaded in {end_time - start_time:.2f}s. Memory reserved: {mem_reserved:.2f} GB") CACHED_PIPE = pipe return pipe except Exception as e: print(f"Error loading 4-bit BNB pipeline: {e}") raise @spaces.GPU(duration=240) def generate_image(prompt, progress=gr.Progress(track_tqdm=True)): """Generate image using 4-bit quantized model""" if not prompt: return None, "Please enter a prompt." progress(0.2, desc="Loading 4-bit quantized model...") try: # Load the 4-bit pipeline pipe = load_bnb_4bit_pipeline() # Set up generation parameters pipe_kwargs = { "prompt": prompt, "height": DEFAULT_HEIGHT, "width": DEFAULT_WIDTH, "guidance_scale": DEFAULT_GUIDANCE_SCALE, "num_inference_steps": DEFAULT_NUM_INFERENCE_STEPS, "max_sequence_length": DEFAULT_MAX_SEQUENCE_LENGTH, } # Generate seed seed = random.getrandbits(64) print(f"Using seed: {seed}") progress(0.5, desc="Generating image...") # Generate image gen_start_time = time.time() image = pipe(**pipe_kwargs, generator=torch.manual_seed(seed)).images[0] gen_end_time = time.time() print(f"Image generated in {gen_end_time - gen_start_time:.2f} seconds") mem_reserved = torch.cuda.memory_reserved(0)/1024**3 if DEVICE == "cuda" else 0 print(f"Memory reserved: {mem_reserved:.2f} GB") return image, f"Generation complete! (Seed: {seed})" except Exception as e: print(f"Error during generation: {e}") return None, f"Error: {e}" # Create Gradio interface with gr.Blocks(title="FLUXllama", theme=gr.themes.Soft()) as demo: gr.HTML( """
FLUX.1-dev 4-bit Quantized Version