import logging

import numpy as np
import torch
from tqdm import tqdm

from config import Config, default_config
from ddpm import DDPMSampler

WIDTH = 512
HEIGHT = 512
LATENTS_WIDTH = WIDTH // 8
LATENTS_HEIGHT = HEIGHT // 8

logging.basicConfig(level=logging.INFO)

def generate(
    prompt,
    uncond_prompt=None,
    input_image=None,
    config: Config = default_config,
):
    with torch.no_grad():
        validate_strength(config.diffusion.strength)
        generator = initialize_generator(config.seed, config.device.device)
        # Encode the prompt (and, for classifier-free guidance, the negative prompt) with CLIP.
        context = encode_prompt(prompt, uncond_prompt, config.diffusion.do_cfg, config.tokenizer, config.models["clip"], config.device.device)
        # Start from pure noise (text-to-image) or from a noised input image (image-to-image).
        latents = initialize_latents(input_image, config.diffusion.strength, generator, config.models, config.device.device, config.diffusion.sampler_name, config.diffusion.n_inference_steps)
        # Iteratively denoise the latents and decode them back to pixel space.
        images = run_diffusion(latents, context, config.diffusion.do_cfg, config.diffusion.cfg_scale, config.models, config.device.device, config.diffusion.sampler_name, config.diffusion.n_inference_steps, generator)
        return postprocess_images(images)

def validate_strength(strength):
    if not 0 < strength <= 1:
        raise ValueError("Strength must be between 0 and 1")

def initialize_generator(seed, device):
    generator = torch.Generator(device=device)
    if seed is None:
        generator.seed()
    else:
        generator.manual_seed(seed)
    return generator

def encode_prompt(prompt, uncond_prompt, do_cfg, tokenizer, clip, device):
    clip.to(device)
    if do_cfg:
        # Classifier-free guidance: encode the conditional and unconditional prompts
        # separately; run_diffusion later combines the two predictions.
        cond_tokens = tokenizer.batch_encode_plus([prompt], padding="max_length", max_length=77).input_ids
        cond_tokens = torch.tensor(cond_tokens, dtype=torch.long, device=device)
        cond_context = clip(cond_tokens)
        # Fall back to an empty negative prompt when none is provided.
        uncond_tokens = tokenizer.batch_encode_plus([uncond_prompt or ""], padding="max_length", max_length=77).input_ids
        uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=device)
        uncond_context = clip(uncond_tokens)
        context = torch.cat([cond_context, uncond_context])
    else:
        tokens = tokenizer.batch_encode_plus([prompt], padding="max_length", max_length=77).input_ids
        tokens = torch.tensor(tokens, dtype=torch.long, device=device)
        context = clip(tokens)
    return context

def initialize_latents(input_image, strength, generator, models, device, sampler_name, n_inference_steps):
    latents_shape = (1, 4, LATENTS_HEIGHT, LATENTS_WIDTH)
    if input_image is None:
        # Text-to-image: start from pure Gaussian noise.
        latents = torch.randn(latents_shape, generator=generator, device=device)
    else:
        # Image-to-image: start from the encoded input image.
        latents = encode_image(input_image, models, device)
        # Blend in noise according to strength: higher strength means a noisier start,
        # so the result deviates more from the input image.
        noise = torch.randn(latents_shape, generator=generator, device=device)
        latents = (1 - strength) * latents + strength * noise
    return latents

def preprocess_image(input_image):
    # Resize to the model resolution, scale pixel values to [-1, 1], and move to NCHW layout.
    input_image_tensor = input_image.resize((WIDTH, HEIGHT))
    input_image_tensor = np.array(input_image_tensor)
    input_image_tensor = torch.tensor(input_image_tensor, dtype=torch.float32)
    input_image_tensor = rescale(input_image_tensor, (0, 255), (-1, 1))
    input_image_tensor = input_image_tensor.unsqueeze(0)
    input_image_tensor = input_image_tensor.permute(0, 3, 1, 2)
    return input_image_tensor

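# `encode_image` is called by `initialize_latents` but is not defined in this file.
# The sketch below is an assumption about that missing helper: it presumes
# models["encoder"] is a VAE-style encoder that takes the preprocessed image tensor
# plus a noise tensor and returns a (1, 4, LATENTS_HEIGHT, LATENTS_WIDTH) latent.
# Adjust it to match the actual encoder interface.
def encode_image(input_image, models, device, generator=None):
    encoder = models["encoder"]
    encoder.to(device)
    # Resize/rescale the input image and move it to the target device.
    image_tensor = preprocess_image(input_image).to(device)
    # Noise for the encoder's sampling step (assumed interface).
    encoder_noise = torch.randn((1, 4, LATENTS_HEIGHT, LATENTS_WIDTH), generator=generator, device=device)
    return encoder(image_tensor, encoder_noise)
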
def get_sampler(sampler_name, generator, n_inference_steps):
    if sampler_name == "ddpm":
        sampler = DDPMSampler(generator)
        sampler.set_inference_timesteps(n_inference_steps)
    else:
        raise ValueError(f"Unknown sampler value {sampler_name}.")
    return sampler

def run_diffusion(latents, context, do_cfg, cfg_scale, models, device, sampler_name, n_inference_steps, generator):
    diffusion = models["diffusion"]
    diffusion.to(device)
    sampler = get_sampler(sampler_name, generator, n_inference_steps)
    timesteps = tqdm(sampler.timesteps)
    for timestep in timesteps:
        time_embedding = get_time_embedding(timestep).to(device)
        # With classifier-free guidance the context holds both the conditional and
        # unconditional embeddings, so the latents are duplicated to match.
        model_input = latents.repeat(2, 1, 1, 1) if do_cfg else latents
        model_output = diffusion(model_input, context, time_embedding)
        if do_cfg:
            output_cond, output_uncond = model_output.chunk(2)
            model_output = cfg_scale * (output_cond - output_uncond) + output_uncond
        # Remove the predicted noise for this timestep.
        latents = sampler.step(timestep, latents, model_output)
    decoder = models["decoder"]
    decoder.to(device)
    images = decoder(latents)
    return images

def postprocess_images(images):
    images = rescale(images, (-1, 1), (0, 255), clamp=True)
    images = images.permute(0, 2, 3, 1)  # NCHW -> NHWC
    images = images.to("cpu", torch.uint8).numpy()
    return images[0]

def rescale(x, old_range, new_range, clamp=False):
    old_min, old_max = old_range
    new_min, new_max = new_range
    x -= old_min
    x *= (new_max - new_min) / (old_max - old_min)
    x += new_min
    if clamp:
        x = x.clamp(new_min, new_max)
    return x

def get_time_embedding(timestep):
    # Sinusoidal timestep embedding: 160 frequencies -> 320-dim (cos, sin) vector.
    freqs = torch.pow(10000, -torch.arange(start=0, end=160, dtype=torch.float32) / 160)
    x = torch.tensor([timestep], dtype=torch.float32)[:, None] * freqs[None]
    return torch.cat([torch.cos(x), torch.sin(x)], dim=-1)

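# Example usage (a sketch, assuming `default_config` already bundles the tokenizer,
# the "clip"/"diffusion"/"decoder" models, and the device/diffusion settings that
# `generate` reads; the prompt and file name are illustrative only):
#
#     from PIL import Image
#
#     output_image = generate(
#         prompt="a photograph of an astronaut riding a horse",
#         uncond_prompt="",
#     )
#     Image.fromarray(output_image).save("output.png")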