from daam import trace, set_seed
from diffusers import StableDiffusionXLPipeline, DPMSolverMultistepScheduler, AutoencoderKL
from matplotlib import pyplot as plt
import torch
import os

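# Fail fast when no CUDA device is present; the SDXL inference below assumes a GPU.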
if not torch.cuda.is_available():
    raise RuntimeError("CUDA is not available. Please ensure a GPU is available and PyTorch is installed with CUDA support.")

output_dir = 'sdxl-creaprompt'
os.makedirs(output_dir, exist_ok=True)
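
# Single-file SDXL checkpoint and VAE weights hosted on the Hugging Face Hub.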
model_url = 'https://huggingface.co/ApacheOne/local-checkpoints/blob/main/SDXL(PONY)/creapromptLightning_creapromtHypersdxlV1.safetensors'
vae_url = 'https://huggingface.co/ApacheOne/local-checkpoints/blob/main/SDXL(PONY)/VAES/_bothyper.safetensors'
device = 'cuda'
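
# Load the standalone VAE from its safetensors file in half precision.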
vae = AutoencoderKL.from_single_file(
    vae_url,
    torch_dtype=torch.float16,
    use_safetensors=True
)
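
# Build the SDXL pipeline from the single-file checkpoint, reusing the VAE loaded above.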
pipe = StableDiffusionXLPipeline.from_single_file(
    model_url,
    torch_dtype=torch.float16,
    use_safetensors=True,
    variant='fp16',
    vae=vae
)
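
# Swap in the DPM-Solver multistep scheduler, with Karras sigmas disabled.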
pipe.scheduler = DPMSolverMultistepScheduler.from_config(
    pipe.scheduler.config,
    use_karras_sigmas=False
)
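
# Memory savers: model CPU offload moves submodules to the GPU only while they run,
# and VAE slicing reduces peak memory during decoding.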
pipe.enable_model_cpu_offload()
pipe.enable_vae_slicing()
# enable_model_cpu_offload() handles device placement itself, so the pipeline is
# intentionally not moved to the GPU with pipe.to(device) here.
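
# Prompt and seed; set_seed() returns a seeded torch.Generator for reproducible sampling.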
prompt = '(masterpiece best quality ultra-detailed best shadow amazing realistic picture) realistic woman, full body, white background'
gen = set_seed(42)
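
# Generate under DAAM's trace() so cross-attention heat maps are recorded during sampling.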
with torch.no_grad():
    with trace(pipe) as tc:
        out = pipe(
            prompt,
            num_inference_steps=9,
            generator=gen,
            callback=tc.time_callback,
            callback_steps=1,
            guidance_scale=1.1,
            height=1024,
            width=1024
        )

        generated_image_path = os.path.join(output_dir, 'generated_image.png')
        out.images[0].save(generated_image_path)
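
        # Aggregate the traced attention into a global heat map and save one overlay per prompt word.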
        heat_map = tc.compute_global_heat_map()
        for word in prompt.split():
            word_heat_map = heat_map.compute_word_heat_map(word)

            fig = plt.figure()
            word_heat_map.plot_overlay(out.images[0])
            plt.title(f"Heatmap for '{word}'")

            heatmap_path = os.path.join(output_dir, f'heatmap_{word}.png')
            plt.savefig(heatmap_path, bbox_inches='tight')
            plt.close(fig)

        exp = tc.to_experiment('sdxl-creaprompt-experiment-gpu')
        exp.save()
print(f"Generation complete! Images saved in '{output_dir}' folder:")
print(f"- Generated image: {generated_image_path}")
print(f"- Heatmaps: {output_dir}/heatmap_<word>.png")
print("Experiment saved in 'sdxl-creaprompt-experiment-gpu'.") |