# Utilities for sampling images from Stable Diffusion with classifier-free guidance.
import numpy as np
import torch
from PIL import Image

def view_images(images, num_rows=1, offset_ratio=0.02):
    """Tile a list (or 4D array) of HWC uint8 images into a single PIL grid."""
    if type(images) is list:
        num_empty = -len(images) % num_rows  # pad so the grid divides evenly into rows
    elif images.ndim == 4:
        num_empty = -images.shape[0] % num_rows
    else:
        images = [images]
        num_empty = 0
    # Pad with white images so every row has the same number of columns.
    empty_images = np.ones(images[0].shape, dtype=np.uint8) * 255
    images = [image.astype(np.uint8) for image in images] + [empty_images] * num_empty
    num_items = len(images)
    h, w, c = images[0].shape
    offset = int(h * offset_ratio)  # width of the white gutter between tiles
    num_cols = num_items // num_rows
    image_ = np.ones((h * num_rows + offset * (num_rows - 1),
                      w * num_cols + offset * (num_cols - 1), 3), dtype=np.uint8) * 255
    for i in range(num_rows):
        for j in range(num_cols):
            image_[i * (h + offset): i * (h + offset) + h,
                   j * (w + offset): j * (w + offset) + w] = images[i * num_cols + j]
    pil_img = Image.fromarray(image_)
    return pil_img

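# Minimal usage sketch for view_images: tile a handful of random HWC uint8
# arrays into a two-row grid. The shapes, values, and output filename here are
# made up purely for illustration.
def _demo_view_images():
    imgs = [np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8) for _ in range(4)]
    grid = view_images(imgs, num_rows=2)
    grid.save("grid.png")
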
def diffusion_step(model, latents, context, t, guidance_scale, low_resource=False):
    """One denoising step with classifier-free guidance."""
    if low_resource:
        # Run the unconditional and text-conditioned passes separately to halve peak memory.
        noise_pred_uncond = model.unet(latents, t, encoder_hidden_states=context[0])["sample"]
        noise_prediction_text = model.unet(latents, t, encoder_hidden_states=context[1])["sample"]
    else:
        # Batch both passes together: first half unconditional, second half text-conditioned.
        latents_input = torch.cat([latents] * 2)
        noise_pred = model.unet(latents_input, t, encoder_hidden_states=context)["sample"]
        noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2)
    # Classifier-free guidance: push the prediction away from the unconditional estimate.
    noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)
    latents = model.scheduler.step(noise_pred, t, latents)["prev_sample"]
    return latents

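# Sketch of the classifier-free guidance arithmetic used in diffusion_step, on
# dummy tensors instead of real UNet outputs: the guided prediction equals the
# unconditional one plus a scaled step toward the text-conditioned one.
def _demo_cfg_arithmetic():
    uncond = torch.zeros(1, 4, 8, 8)
    text = torch.ones(1, 4, 8, 8)
    guidance_scale = 7.5
    guided = uncond + guidance_scale * (text - uncond)
    assert torch.allclose(guided, torch.full((1, 4, 8, 8), 7.5))
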
def latent2image(vae, latents):
    """Decode VAE latents into a batch of HWC uint8 images."""
    latents = 1 / 0.18215 * latents  # undo Stable Diffusion's latent scaling factor
    image = vae.decode(latents)["sample"]
    image = (image / 2 + 0.5).clamp(0, 1)  # map from [-1, 1] to [0, 1]
    image = image.cpu().permute(0, 2, 3, 1).numpy()  # NCHW -> NHWC
    image = (image * 255).astype(np.uint8)
    return image

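# Hedged sketch of the encoding counterpart to latent2image, assuming a
# diffusers AutoencoderKL whose encode() returns an output carrying a
# latent_dist; this helper is not part of the original file. The 0.18215
# factor mirrors the division performed before decoding above.
def _image2latent(vae, image):
    # image: float tensor in [-1, 1], shape (N, 3, H, W), on the VAE's device
    posterior = vae.encode(image).latent_dist
    return posterior.sample() * 0.18215
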
def init_latent(latent, model, height, width, generator, batch_size):
    """Draw (or reuse) a starting latent and broadcast it across the batch."""
    if latent is None:
        latent = torch.randn(
            (1, model.unet.in_channels, height // 8, width // 8),
            generator=generator,
        )
    # Every batch item starts from the same noise, so differences come from the prompts.
    latents = latent.expand(batch_size, model.unet.in_channels, height // 8, width // 8).to(model.device)
    return latent, latents

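# Small torch-only sketch of why init_latent uses expand(): every row of the
# broadcast batch views the same underlying noise sample, so per-prompt
# differences in the output come from the conditioning, not the initial latent.
def _demo_shared_latent():
    g = torch.Generator().manual_seed(0)
    latent = torch.randn((1, 4, 64, 64), generator=g)
    latents = latent.expand(3, 4, 64, 64)
    assert torch.equal(latents[0], latents[2])
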
@torch.no_grad()  # inference only; avoids building autograd graphs across the sampling loop
def text2image_ldm_stable(
    model,
    prompt,
    num_inference_steps=50,
    guidance_scale=7.5,
    generator=None,
    latent=None,
    low_resource=False,
):
    """Sample images from a Stable Diffusion pipeline for a list of prompts."""
    height = width = 512
    batch_size = len(prompt)
    # Tokenize and encode the prompts.
    text_input = model.tokenizer(
        prompt,
        padding="max_length",
        max_length=model.tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    text_embeddings = model.text_encoder(text_input.input_ids.to(model.device))[0]
    # Encode empty prompts for the unconditional branch of classifier-free guidance.
    max_length = text_input.input_ids.shape[-1]
    uncond_input = model.tokenizer(
        [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
    )
    uncond_embeddings = model.text_encoder(uncond_input.input_ids.to(model.device))[0]
    context = [uncond_embeddings, text_embeddings]
    if not low_resource:
        context = torch.cat(context)
    latent, latents = init_latent(latent, model, height, width, generator, batch_size)
    # Iteratively denoise the latents.
    model.scheduler.set_timesteps(num_inference_steps)
    for t in model.scheduler.timesteps:
        latents = diffusion_step(model, latents, context, t, guidance_scale, low_resource)
    image = latent2image(model.vae, latents)
    image, _ = model.run_safety_checker(image=image, device=model.device, dtype=text_embeddings.dtype)
    return image

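# Hedged end-to-end sketch: driving text2image_ldm_stable with a diffusers
# StableDiffusionPipeline, which exposes the tokenizer, text_encoder, unet,
# vae, scheduler, and run_safety_checker attributes used above. The checkpoint
# name, seed, and prompt are assumptions for illustration, not fixed by this file.
def _demo_text2image():
    from diffusers import StableDiffusionPipeline

    model = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4").to("cuda")
    g = torch.Generator().manual_seed(8888)
    images = text2image_ldm_stable(
        model,
        ["a photograph of an astronaut riding a horse"],
        num_inference_steps=50,
        guidance_scale=7.5,
        generator=g,
    )
    view_images(images).save("astronaut.png")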