from dataclasses import dataclass, field

import clip
import numpy as np
import requests
import torch
import torchvision.transforms as transforms
import torchvision.utils as vutils
from diffusers import AutoencoderKL
from torch import Tensor
from tqdm import tqdm

from denoiser import Denoiser

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
to_pil = transforms.ToPILImage()


@dataclass
class DiffusionGenerator:
    model: Denoiser
    vae: AutoencoderKL
    device: torch.device
    model_dtype: torch.dtype = torch.float32

    @torch.no_grad()
    def generate(
        self,
        labels: Tensor,  # embeddings to condition on
        n_iter: int = 30,
        num_imgs: int = 16,
        class_guidance: float = 3,
        seed: int = 10,
        scale_factor: int = 8,  # latent scaling before decoding - should be ~ std of latent space
        img_size: int = 32,  # height/width of the latent
        sharp_f: float = 0.1,
        bright_f: float = 0.1,
        exponent: float = 1,
        seeds: Tensor | None = None,
        noise_levels: list[float] | None = None,
        use_ddpm_plus: bool = True,
    ):
"""Generate images via reverse diffusion. | |
if use_ddpm_plus=True uses Algorithm 2 DPM-Solver++(2M) here: https://arxiv.org/pdf/2211.01095.pdf | |
else use ddim with alpha = 1-sigma | |
""" | |
        if noise_levels is None:
            noise_levels = (1 - torch.pow(torch.arange(0, 1, 1 / n_iter), exponent)).tolist()
        noise_levels[0] = 0.99

        if use_ddpm_plus:
            lambdas = [np.log((1 - sigma) / sigma) for sigma in noise_levels]  # log-SNR
            hs = [lambdas[i] - lambdas[i - 1] for i in range(1, len(lambdas))]  # log-SNR step sizes
            rs = [hs[i - 1] / hs[i] for i in range(1, len(hs))]  # ratios of consecutive step sizes
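            # The 2M multistep update in the loop below combines the current and
            # previous x0 predictions as
            #     D_i = (1 + 1 / (2 * r_i)) * x0_i - (1 / (2 * r_i)) * x0_{i-1},
            # where r_i = h_{i-1} / h_i is the ratio of consecutive log-SNR steps.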

        x_t = self.initialize_image(seeds, num_imgs, img_size, seed)

        # append a zero ("null") embedding per image for classifier-free guidance:
        labels = torch.cat([labels, torch.zeros_like(labels)])
        self.model.eval()

        x0_pred_prev = None
        for i in tqdm(range(len(noise_levels) - 1)):
            curr_noise, next_noise = noise_levels[i], noise_levels[i + 1]

            x0_pred = self.pred_image(x_t, labels, curr_noise, class_guidance)

            if x0_pred_prev is None:
                # first step: no history yet, so take a plain DDIM step
                x_t = ((curr_noise - next_noise) * x0_pred + next_noise * x_t) / curr_noise
            else:
                if use_ddpm_plus:
                    # D is a combination of the two most recent x0 predictions:
                    D = (1 + 1 / (2 * rs[i - 1])) * x0_pred - (1 / (2 * rs[i - 1])) * x0_pred_prev
                else:
                    # ddim:
                    D = x0_pred
                x_t = ((curr_noise - next_noise) * D + next_noise * x_t) / curr_noise
            x0_pred_prev = x0_pred

        x0_pred = self.pred_image(x_t, labels, next_noise, class_guidance)

        # shifting latent channels works a bit like an image editor:
        x0_pred[:, 3, :, :] += sharp_f
        x0_pred[:, 0, :, :] += bright_f

        x0_pred_img = self.vae.decode((x0_pred * scale_factor).to(self.model_dtype))[0].cpu()
        return x0_pred_img, x0_pred

    def pred_image(self, noisy_image, labels, noise_level, class_guidance):
        num_imgs = noisy_image.size(0)
        noises = torch.full((2 * num_imgs, 1), noise_level)
        # one batch, two halves: conditional (first) and unconditional (second)
        x0_pred = self.model(
            torch.cat([noisy_image, noisy_image]),
            noises.to(self.device, self.model_dtype),
            labels.to(self.device, self.model_dtype),
        )
        x0_pred = self.apply_classifier_free_guidance(x0_pred, num_imgs, class_guidance)
        return x0_pred

    def initialize_image(self, seeds, num_imgs, img_size, seed):
        """Initialize the seed tensor."""
        if seeds is None:
            generator = torch.Generator(device=self.device)
            generator.manual_seed(seed)
            return torch.randn(
                num_imgs,
                4,
                img_size,
                img_size,
                dtype=self.model_dtype,
                device=self.device,
                generator=generator,
            )
        else:
            return seeds.to(self.device, self.model_dtype)

    def apply_classifier_free_guidance(self, x0_pred, num_imgs, class_guidance):
        """Apply classifier-free guidance to the predictions."""
        x0_pred_label, x0_pred_no_label = x0_pred[:num_imgs], x0_pred[num_imgs:]
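        # With g = class_guidance, the blend below is the usual CFG formula:
        #     g * cond + (1 - g) * uncond == uncond + g * (cond - uncond),
        # so g > 1 extrapolates away from the unconditional prediction.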
        return class_guidance * x0_pred_label + (1 - class_guidance) * x0_pred_no_label


@dataclass
class LTDConfig:
    vae_scale_factor: float = 8
    img_size: int = 32
    model_dtype: torch.dtype = torch.float32
    file_url: str | None = None  # e.g. "https://huggingface.co/apapiu/small_ldt/resolve/main/state_dict_378000.pth"
    local_filename: str = "state_dict_378000.pth"
    vae_name: str = "ByteDance/SDXL-Lightning"  # must resolve to an AutoencoderKL-compatible checkpoint
    clip_model_name: str = "ViT-L/14"
    denoiser: Denoiser = field(
        default_factory=lambda: Denoiser(
            image_size=32,
            noise_embed_dims=256,
            patch_size=2,
            embed_dim=256,
            dropout=0,
            n_layers=4,
        )
    )


def download_file(url, filename):
    """Stream a file from `url` to `filename` in 8 KB chunks."""
    with requests.get(url, stream=True) as r:
        r.raise_for_status()
        with open(filename, "wb") as f:
            for chunk in r.iter_content(chunk_size=8192):
                f.write(chunk)


def encode_text(label, model):
    """Encode a prompt (or list of prompts) into CLIP text embeddings."""
    text_tokens = clip.tokenize(label, truncate=True).to(device)
    text_encoding = model.encode_text(text_tokens)
    return text_encoding.cpu()


class DiffusionTransformer:
    def __init__(self, config: LTDConfig):
        denoiser = config.denoiser.to(config.model_dtype)

        if config.file_url is not None:
            print(f"Downloading model from {config.file_url}")
            download_file(config.file_url, config.local_filename)
            state_dict = torch.load(config.local_filename, map_location=torch.device("cpu"))
            denoiser.load_state_dict(state_dict)

        denoiser = denoiser.to(device)
        vae = AutoencoderKL.from_pretrained(config.vae_name, torch_dtype=config.model_dtype).to(device)

        self.clip_model, _preprocess = clip.load(config.clip_model_name)
        self.clip_model = self.clip_model.to(device)
        self.diffuser = DiffusionGenerator(denoiser, vae, device, config.model_dtype)

    def generate_image_from_text(
        self, prompt: str, class_guidance=6, seed=11, num_imgs=1, img_size=32, n_iter=15
    ):
        nrow = int(np.sqrt(num_imgs))

        cur_prompts = [prompt] * num_imgs
        labels = encode_text(cur_prompts, self.clip_model)

        out, out_latent = self.diffuser.generate(
            labels=labels,
            num_imgs=num_imgs,
            class_guidance=class_guidance,
            seed=seed,
            n_iter=n_iter,
            exponent=1,
            scale_factor=8,
            img_size=img_size,
            sharp_f=0,
            bright_f=0,
        )

        # map decoded images from [-1, 1] to [0, 1] and tile them into a grid:
        out = to_pil((vutils.make_grid((out + 1) / 2, nrow=nrow, padding=4)).float().clip(0, 1))
        return out
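

# Minimal usage sketch (illustrative). With the default LTDConfig, file_url is
# None, so no denoiser weights are downloaded and samples come from a randomly
# initialized model; point file_url at a trained checkpoint for meaningful
# output. Assumes vae_name and clip_model_name resolve to downloadable checkpoints.
if __name__ == "__main__":
    ldt = DiffusionTransformer(LTDConfig())
    grid = ldt.generate_image_from_text("a watercolor painting of a lighthouse", num_imgs=4)
    grid.save("sample_grid.png")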