"""Generate an SDXL image and save DAAM per-word attention heat maps.

Requires a CUDA GPU. Writes the generated image plus one heat-map overlay
per prompt word into ``output_dir``, then saves a DAAM experiment under
``experiment_name``.
"""
import os
import string

import torch
from daam import set_seed, trace
from diffusers import DiffusionPipeline
from matplotlib import pyplot as plt

# Verify GPU availability before spending time loading the model.
if not torch.cuda.is_available():
    raise RuntimeError("CUDA is not available. Please ensure a GPU is available and PyTorch is installed with CUDA support.")

# Output locations.
output_dir = 'sdxl'
os.makedirs(output_dir, exist_ok=True)  # Create 'sdxl' folder if it doesn't exist
# NOTE(review): the name mentions "cat" although the prompt does not; kept
# unchanged so the on-disk experiment location stays the same.
experiment_name = 'sdxl-cat-experiment-gpu'

# Model setup
model_id = 'stabilityai/stable-diffusion-xl-base-1.0'
device = 'cuda'
# Set to True to trade speed for lower VRAM usage. Model CPU offload moves
# components to the GPU on demand, so it must NOT be combined with an explicit
# `pipe.to('cuda')`: that move defeats the offload (diffusers warns about it).
enable_cpu_offload = False

# Load the pipeline in float16 (faster inference, lower memory on GPU).
pipe = DiffusionPipeline.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    use_safetensors=True,  # Safetensors weights
    variant='fp16',  # FP16 weight variant
)

if enable_cpu_offload:
    pipe.enable_model_cpu_offload()
else:
    pipe = pipe.to(device)
pipe.enable_vae_slicing()  # Slice VAE operations to reduce memory usage

# Prompt and generation settings
prompt = 'A human holding his hand up'
gen = set_seed(42)  # Seeded generator, passed to the pipeline for reproducibility

heatmap_paths = []
with torch.no_grad():
    with trace(pipe) as tc:
        # NOTE(review): newer diffusers releases deprecate `callback`/`callback_steps`
        # in favor of `callback_on_step_end`; confirm against the installed version.
        out = pipe(
            prompt,
            num_inference_steps=15,  # Reduced for speed; increase to 30-50 for better quality
            generator=gen,
            callback=tc.time_callback,
            callback_steps=1,
        )
        image = out.images[0]

        # Save the generated image.
        generated_image_path = os.path.join(output_dir, 'generated_image.png')
        image.save(generated_image_path)

        # Heat maps and the experiment are derived while the tracer is still
        # active, matching DAAM's documented usage pattern.
        heat_map = tc.compute_global_heat_map()

        # Strip surrounding punctuation so e.g. "hand," is looked up as "hand",
        # drop tokens that become empty, and de-duplicate (order-preserving) so
        # each heat-map file is written exactly once.
        stripped = (token.strip(string.punctuation) for token in prompt.split())
        words = list(dict.fromkeys(w for w in stripped if w))

        for word in words:
            word_heat_map = heat_map.compute_word_heat_map(word)

            # Create the heat-map overlay plot on a fresh figure.
            fig = plt.figure()
            word_heat_map.plot_overlay(image)
            plt.title(f"Heatmap for '{word}'")

            heatmap_path = os.path.join(output_dir, f'heatmap_{word}.png')
            plt.savefig(heatmap_path, bbox_inches='tight')
            plt.close(fig)  # Close the figure to free memory
            heatmap_paths.append(heatmap_path)

        # Save the DAAM experiment.
        exp = tc.to_experiment(experiment_name)
        exp.save()

print(f"Generation complete! Images saved in '{output_dir}' folder:")
print(f"- Generated image: {generated_image_path}")
print("- Heatmaps:")
for path in heatmap_paths:
    print(f"  {path}")
print(f"Experiment saved in '{experiment_name}'.")