# Daam2.0 / Scripts / run_sdxl_creaprompt.py
# Author: ApacheOne (uploaded in "Upload 4 files", commit 355af0a, verified)
import os
import re

import torch
from daam import trace, set_seed
from diffusers import StableDiffusionXLPipeline
from matplotlib import pyplot as plt
# Verify GPU availability: everything below assumes a CUDA device, so fail
# fast with a clear message instead of a confusing error deep inside diffusers.
if not torch.cuda.is_available():
    raise RuntimeError("CUDA is not available. Please ensure a GPU is available and PyTorch is installed with CUDA support.")

# Create the output directory for the generated image and heat maps
# (exist_ok=True: re-running the script reuses the existing folder).
output_dir = 'sdxl-creaprompt'
os.makedirs(output_dir, exist_ok=True)
# Model setup
model_url = 'https://huggingface.co/ApacheOne/local-checkpoints/blob/main/SDXL(PONY)/creapromptLightning_creapromtHypersdxlV1.safetensors'
device = 'cuda'  # Target GPU device
# Model CPU offload keeps sub-models on the CPU and moves each one onto the GPU
# only while it runs (see the diffusers memory guide). Calling pipe.to('cuda')
# afterwards undoes that and diffusers warns against it, so the two placements
# are mutually exclusive: set this to False to keep the whole pipeline on the
# GPU instead (faster, but needs more VRAM).
use_cpu_offload = True

# Load the pipeline from a single .safetensors checkpoint file.
pipe = StableDiffusionXLPipeline.from_single_file(
    model_url,
    torch_dtype=torch.float16,  # Half precision for faster GPU inference
    use_safetensors=True,  # Ensure safetensors format
    # NOTE(review): `variant` selects weight files in multi-folder repos; it
    # likely has no effect for a single-file checkpoint. Confirm this against
    # the installed diffusers version.
    variant='fp16',
)

# GPU placement / memory optimizations
if use_cpu_offload:
    pipe.enable_model_cpu_offload()  # Sub-models are moved to the GPU on demand
else:
    pipe = pipe.to(device)
pipe.enable_vae_slicing()  # Decode the VAE in slices to reduce peak memory
# Prompt and generation settings
prompt = 'realism eohwx woman, wearing dark black low wasit jeans,white shoes and red crop top, hands by side, ,full body shot,Lake Tahoe,(masterpiece best quality ultra-detailed best shadow amazing realistic picture)'
gen = set_seed(42)  # Reproducible seed; the returned generator is passed to the pipeline

# Generate the image inside a DAAM trace so its attention hooks are active
# during the diffusion run; gradients are not needed for inference.
with torch.no_grad():
    with trace(pipe) as tc:
        out = pipe(
            prompt,
            # Low step count. NOTE(review): the checkpoint name suggests a
            # distilled Lightning/Hyper-SD model. Check the model card for the
            # recommended steps/guidance; the usual "30-50 steps" advice is
            # for non-distilled SDXL.
            num_inference_steps=6,
            generator=gen,
            # DAAM's per-step callback, invoked on every step (callback_steps=1).
            # NOTE(review): `callback`/`callback_steps` are deprecated in
            # recent diffusers in favour of `callback_on_step_end`; kept
            # because this is the hook DAAM exposes.
            callback=tc.time_callback,
            callback_steps=1,
        )
# Save the generated image
image = out.images[0]
generated_image_path = os.path.join(output_dir, 'generated_image.png')
image.save(generated_image_path)

# Generate and save one heat-map overlay per distinct word of the prompt.
heat_map = tc.compute_global_heat_map()
# prompt.split() would keep punctuation attached to words ('(masterpiece',
# 'jeans,white', ',full') and would process repeated words ('best') twice,
# overwriting the same file. Instead, extract word-like runs (hyphenated words
# such as 'ultra-detailed' stay whole) and de-duplicate them in prompt order.
# The resulting words contain only letters, digits, '_' and '-', so they are
# also safe to use in file names.
heatmap_words = list(dict.fromkeys(re.findall(r"[\w-]+", prompt)))
for word in heatmap_words:
    try:
        word_heat_map = heat_map.compute_word_heat_map(word)
    except ValueError as err:
        # NOTE(review): DAAM appears to raise ValueError when a word's tokens
        # cannot be found in the (possibly truncated) prompt. Skip that word
        # rather than aborting before the experiment below is saved.
        print(f"Skipping heatmap for '{word}': {err}")
        continue

    # Create the heatmap overlay plot and save it as a PNG.
    fig = plt.figure()
    try:
        word_heat_map.plot_overlay(image)
        plt.title(f"Heatmap for '{word}'")
        heatmap_path = os.path.join(output_dir, f'heatmap_{word}.png')
        plt.savefig(heatmap_path, bbox_inches='tight')
    finally:
        plt.close(fig)  # Always free the figure, even if plotting fails
# Persist the DAAM trace as an experiment folder, then print a short summary
# of where everything was written.
experiment_dir = 'sdxl-creaprompt-experiment-gpu'
exp = tc.to_experiment(experiment_dir)
exp.save()  # Writes into experiment_dir

summary_lines = [
    f"Generation complete! Images saved in '{output_dir}' folder:",
    f"- Generated image: {generated_image_path}",
    f"- Heatmaps: {output_dir}/heatmap_<word>.png",
    f"Experiment saved in '{experiment_dir}'.",
]
print("\n".join(summary_lines))