import os
import warnings
from copy import deepcopy
from typing import Union

import diffusers
import matplotlib.pyplot as plt
import numpy as np
import torch
import tqdm
import transformers
from PIL import Image


def build_generator(
    device: torch.device,
    seed: int,
) -> torch.Generator:
    """
    Build a torch.Generator on the given device with a fixed seed.
    """
    generator = torch.Generator(device).manual_seed(seed)
    return generator


def load_stablediffusion_model(
    model_id: Union[str, os.PathLike],
    device: torch.device,
):
    """
    Load a complete StableDiffusion pipeline from a model id.

    The default scheduler is replaced with DPMSolverMultistepScheduler and the
    pipeline is moved to `device`, falling back to the CPU if that fails.
    """
    pipe = diffusers.DiffusionPipeline.from_pretrained(
        model_id,
        revision="fp16",
        torch_dtype=torch.float16,
        use_auth_token=True,
    )
    pipe.scheduler = diffusers.DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
    try:
        pipe = pipe.to(device)
    except Exception:
        warnings.warn(
            f'Could not move the model to device: {device}. Using CPU instead.'
        )
        pipe = pipe.to('cpu')

    return pipe


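# Example usage (a minimal sketch): reproducible sampling with the two helpers
# above. Assumes access to the gated "CompVis/stable-diffusion-v1-4" weights
# and a CUDA device; the prompt is purely illustrative.
#
#   device = torch.device('cuda')
#   pipe = load_stablediffusion_model("CompVis/stable-diffusion-v1-4", device)
#   generator = build_generator(device, seed=42)
#   image = pipe("an astronaut riding a horse", generator=generator).images[0]
#   image.save("astronaut.png")

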
def visualize_image_grid(
    imgs: list,
    rows: int,
    cols: int,
) -> Image.Image:
    """
    Paste a list of PIL images into a single rows x cols grid image.
    """
    assert len(imgs) == rows * cols, "Number of images must equal rows * cols."

    w, h = imgs[0].size
    grid = Image.new('RGB', size=(cols * w, rows * h))

    for i, img in enumerate(imgs):
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid


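# Example usage (a minimal sketch): tile a batch of pipeline outputs into a
# 2x2 grid. Assumes `pipe` is a pipeline returned by
# load_stablediffusion_model; the prompt is purely illustrative.
#
#   images = pipe(["a watercolor of a lighthouse"] * 4).images
#   grid = visualize_image_grid(images, rows=2, cols=2)
#   grid.save("grid.png")

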
def build_pipeline(
    autoencoder: Union[str, os.PathLike] = "CompVis/stable-diffusion-v1-4",
    tokenizer: Union[str, os.PathLike] = "openai/clip-vit-large-patch14",
    text_encoder: Union[str, os.PathLike] = "openai/clip-vit-large-patch14",
    unet: Union[str, os.PathLike] = "CompVis/stable-diffusion-v1-4",
    device: torch.device = torch.device('cuda'),
):
    """
    Create a pipeline for StableDiffusion by loading each component separately.

    Arguments:
        autoencoder: path or model id the autoencoder (VAE) is loaded from
        tokenizer: path or model id of the tokenizer
        text_encoder: path or model id of the text encoder
        unet: path or model id of the UNet
        device: device the VAE, text encoder and UNet are moved to

    Returns:
        A (vae, tokenizer, text_encoder, unet) tuple.
    """
    vae = diffusers.AutoencoderKL.from_pretrained(autoencoder, subfolder='vae')

    tokenizer = transformers.CLIPTokenizer.from_pretrained(tokenizer)
    text_encoder = transformers.CLIPTextModel.from_pretrained(text_encoder)

    unet = diffusers.UNet2DConditionModel.from_pretrained(unet, subfolder='unet')

    vae = vae.to(device)
    text_encoder = text_encoder.to(device)
    unet = unet.to(device)

    return vae, tokenizer, text_encoder, unet


def custom_stablediffusion_inference(
    vae,
    tokenizer,
    text_encoder,
    unet,
    noise_scheduler,
    prompt: Union[str, list],
    device: torch.device,
    num_inference_steps=100,
    image_size=(512, 512),
    guidance_scale=8,
    seed=42,
    return_image_step=5,
):
    """
    Run the StableDiffusion denoising loop manually, yielding an intermediate
    image every `return_image_step` steps and the final image at the end.
    """
    if isinstance(prompt, str):
        prompt = [prompt]

    batch_size = len(prompt)

    # Embed the conditioning prompt(s).
    text_input = tokenizer(
        prompt,
        padding='max_length',
        truncation=True,
        max_length=tokenizer.model_max_length,
        return_tensors='pt')

    with torch.no_grad():
        text_embeddings = text_encoder(text_input.input_ids.to(device))[0]

    # Embed an empty prompt of the same length for classifier-free guidance.
    max_length = text_input.input_ids.shape[-1]
    uncond_input = tokenizer(
        [""] * batch_size,
        padding='max_length',
        truncation=True,
        max_length=max_length,
        return_tensors='pt')

    with torch.no_grad():
        uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0]

    # Unconditional embeddings first, conditional embeddings second.
    text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

    # Start from random latents; the VAE downsamples images by a factor of 8.
    latents = torch.randn(
        (batch_size, unet.config.in_channels, image_size[0] // 8, image_size[1] // 8),
        generator=torch.manual_seed(seed) if seed is not None else None,
    )
    latents = latents.to(device)

    noise_scheduler.set_timesteps(num_inference_steps)
    latents = latents * noise_scheduler.init_noise_sigma

    def decode_latents(latents):
        # Scale the latents back and decode them into PIL images.
        with torch.no_grad():
            latents_copy = deepcopy(latents)
            image = vae.decode(1 / 0.18215 * latents_copy).sample
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
        images = (image * 255).round().astype("uint8")
        return [Image.fromarray(img) for img in images]

    for i, t in tqdm.tqdm(enumerate(noise_scheduler.timesteps), total=num_inference_steps):
        # Duplicate the latents so a single UNet forward pass produces both the
        # unconditional and the conditional noise prediction.
        latent_model_input = torch.cat([latents] * 2)
        latent_model_input = noise_scheduler.scale_model_input(latent_model_input, timestep=t)

        with torch.no_grad():
            noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

        # Classifier-free guidance.
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        # One denoising step.
        latents = noise_scheduler.step(noise_pred, t, latents).prev_sample

        if i % return_image_step == 0:
            yield decode_latents(latents)[0]

    # Decode and yield the fully denoised image.
    yield decode_latents(latents)[0]


if __name__ == "__main__":
    device = torch.device("cpu")
    model_id = "stabilityai/stable-diffusion-2-1"
    tokenizer_id = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K"

    noise_scheduler = diffusers.DPMSolverMultistepScheduler.from_pretrained(model_id, subfolder="scheduler")
    prompt = (
        "A Hyperrealistic photograph of Italian architectural modern home in Italy, lens flares, "
        "cinematic, hdri, matte painting, concept art, celestial, soft render, highly detailed, "
        "octane render, architectural HD, HQ, 4k, 8k"
    )

    vae, tokenizer, text_encoder, unet = build_pipeline(
        autoencoder=model_id,
        tokenizer=tokenizer_id,
        text_encoder=tokenizer_id,
        unet=model_id,
        device=device,
    )
    image_iter = custom_stablediffusion_inference(
        vae, tokenizer, text_encoder, unet, noise_scheduler,
        prompt=prompt, device=device, seed=None,
    )
    for i, image in enumerate(image_iter):
        image.save(f"step_{i}.png")