import torch
import torch.nn.functional as F
from tqdm import tqdm
from diffusers.utils.torch_utils import randn_tensor


@torch.no_grad()
def sample_stage_1(model,
                   prompt_embeds,
                   negative_prompt_embeds,
                   views,
                   num_inference_steps=100,
                   guidance_scale=7.0,
                   reduction='mean',
                   generator=None):
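    '''
    Denoise a 64x64 image with DeepFloyd IF's stage 1 model so that it
    satisfies every view in `views` under its corresponding prompt.

    Each element of `views` is expected to expose `view(image)` and
    `inverse_view(image)` (this interface is inferred from how the views
    are used below). `reduction` controls how per-view noise estimates are
    combined: 'mean' averages them, 'alternate' cycles through them, one
    view per timestep.
    '''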
    # Params
    num_images_per_prompt = 1
    device = model.device
    height = model.unet.config.sample_size
    width = model.unet.config.sample_size
    batch_size = 1    # TODO: Support larger batch sizes, maybe

    num_prompts = prompt_embeds.shape[0]
    assert num_prompts == len(views), \
        "Number of prompts must match number of views!"

    # For CFG
    prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

    # Setup timesteps
    model.scheduler.set_timesteps(int(num_inference_steps), device=device)
    timesteps = model.scheduler.timesteps

    # Make intermediate_images (already placed on `device`)
    noisy_images = model.prepare_intermediate_images(
        batch_size * num_images_per_prompt,
        model.unet.config.in_channels,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
    )
    for i, t in enumerate(tqdm(timesteps)):
        # Apply views to noisy_image
        viewed_noisy_images = []
        for view_fn in views:
            viewed_noisy_images.append(view_fn.view(noisy_images[0]))
        viewed_noisy_images = torch.stack(viewed_noisy_images)

        # Duplicate inputs for CFG
        # Model input is: [ neg_0, neg_1, ..., pos_0, pos_1, ... ]
        model_input = torch.cat([viewed_noisy_images] * 2)
        model_input = model.scheduler.scale_model_input(model_input, t)
        # Predict noise estimate
        noise_pred = model.unet(
            model_input,
            t,
            encoder_hidden_states=prompt_embeds,
            cross_attention_kwargs=None,
            return_dict=False,
        )[0]

        # Extract uncond (neg) and cond noise estimates
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        # Invert the unconditional (negative) estimates
        inverted_preds = []
        for pred, view in zip(noise_pred_uncond, views):
            inverted_pred = view.inverse_view(pred)
            inverted_preds.append(inverted_pred)
        noise_pred_uncond = torch.stack(inverted_preds)

        # Invert the conditional estimates
        inverted_preds = []
        for pred, view in zip(noise_pred_text, views):
            inverted_pred = view.inverse_view(pred)
            inverted_preds.append(inverted_pred)
        noise_pred_text = torch.stack(inverted_preds)
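        # The IF UNet predicts both the noise and a learned variance,
        # concatenated along the channel dimension (6 channels for a
        # 3-channel image), so each estimate is split in two below.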
        # Split into noise estimate and variance estimate
        noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1], dim=1)
        noise_pred_text, predicted_variance = noise_pred_text.split(model_input.shape[1], dim=1)

        # Classifier-free guidance
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
        # Reduce predicted noise and variances across views
        noise_pred = noise_pred.view(-1, num_prompts, 3, height, width)
        predicted_variance = predicted_variance.view(-1, num_prompts, 3, height, width)
        if reduction == 'mean':
            noise_pred = noise_pred.mean(1)
            predicted_variance = predicted_variance.mean(1)
        elif reduction == 'alternate':
            noise_pred = noise_pred[:, i % num_prompts]
            predicted_variance = predicted_variance[:, i % num_prompts]
        else:
            raise ValueError('Reduction must be either `mean` or `alternate`')
        noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)
        # Compute the previous noisy sample x_t -> x_t-1
        noisy_images = model.scheduler.step(
            noise_pred, t, noisy_images, generator=generator, return_dict=False
        )[0]

    # Return denoised images
    return noisy_images


@torch.no_grad()
def sample_stage_2(model,
                   image,
                   prompt_embeds,
                   negative_prompt_embeds,
                   views,
                   num_inference_steps=100,
                   guidance_scale=7.0,
                   reduction='mean',
                   noise_level=50,
                   generator=None):
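    '''
    Upscale the stage 1 result to 256x256 with DeepFloyd IF's stage 2
    super-resolution model, again reconciling noise estimates across all
    views. `image` is the 64x64 output of `sample_stage_1`, and
    `noise_level` sets how much noise is added to the upscaled conditioning
    image (noise conditioning augmentation). `views` and `reduction` behave
    as in `sample_stage_1`.
    '''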
    # Params
    batch_size = 1    # TODO: Support larger batch sizes, maybe
    num_prompts = prompt_embeds.shape[0]
    height = model.unet.config.sample_size
    width = model.unet.config.sample_size
    device = model.device
    num_images_per_prompt = 1

    # For CFG
    prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

    # Get timesteps
    model.scheduler.set_timesteps(int(num_inference_steps), device=device)
    timesteps = model.scheduler.timesteps

    # The stage 2 UNet takes the noisy image concatenated with the upscaled
    # conditioning image, so it has twice as many input channels
    num_channels = model.unet.config.in_channels // 2
    noisy_images = model.prepare_intermediate_images(
        batch_size * num_images_per_prompt,
        num_channels,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
    )

    # Prepare upscaled image and noise level
    image = model.preprocess_image(image, num_images_per_prompt, device)
    upscaled = F.interpolate(image, (height, width), mode='bilinear', align_corners=True)
    noise_level = torch.tensor([noise_level] * upscaled.shape[0], device=upscaled.device)
    noise = randn_tensor(upscaled.shape, generator=generator, device=upscaled.device, dtype=upscaled.dtype)
    upscaled = model.image_noising_scheduler.add_noise(upscaled, noise, timesteps=noise_level)

    # Condition on noise level, once per view and once per CFG half
    noise_level = torch.cat([noise_level] * num_prompts * 2)
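    # Stage 2 of DeepFloyd IF is trained with noise conditioning augmentation:
    # the model is told, via `class_labels` in the UNet call below, how much
    # noise was added to its low-resolution conditioning image.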
    # Denoising Loop
    for i, t in enumerate(tqdm(timesteps)):
        # Cat noisy image with upscaled conditioning image
        model_input = torch.cat([noisy_images, upscaled], dim=1)

        # Apply views to noisy_image
        viewed_inputs = []
        for view_fn in views:
            viewed_inputs.append(view_fn.view(model_input[0]))
        viewed_inputs = torch.stack(viewed_inputs)

        # Duplicate inputs for CFG
        # Model input is: [ neg_0, neg_1, ..., pos_0, pos_1, ... ]
        model_input = torch.cat([viewed_inputs] * 2)
        model_input = model.scheduler.scale_model_input(model_input, t)
        # Predict the noise residual
        noise_pred = model.unet(
            model_input,
            t,
            encoder_hidden_states=prompt_embeds,
            class_labels=noise_level,
            cross_attention_kwargs=None,
            return_dict=False,
        )[0]

        # Extract uncond (neg) and cond noise estimates
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)

        # Invert the unconditional (negative) estimates
        # TODO: pretty sure you can combine these into one loop
        inverted_preds = []
        for pred, view in zip(noise_pred_uncond, views):
            inverted_pred = view.inverse_view(pred)
            inverted_preds.append(inverted_pred)
        noise_pred_uncond = torch.stack(inverted_preds)

        # Invert the conditional estimates
        inverted_preds = []
        for pred, view in zip(noise_pred_text, views):
            inverted_pred = view.inverse_view(pred)
            inverted_preds.append(inverted_pred)
        noise_pred_text = torch.stack(inverted_preds)
        # Split predicted noise and predicted variances (as in stage 1, the
        # UNet also outputs a learned variance; the noisy image accounts for
        # only half of the input channels here)
        noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1] // 2, dim=1)
        noise_pred_text, predicted_variance = noise_pred_text.split(model_input.shape[1] // 2, dim=1)

        # Classifier-free guidance
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        # Combine noise estimates (and variance estimates) across views
        noise_pred = noise_pred.view(-1, num_prompts, 3, height, width)
        predicted_variance = predicted_variance.view(-1, num_prompts, 3, height, width)
        if reduction == 'mean':
            noise_pred = noise_pred.mean(1)
            predicted_variance = predicted_variance.mean(1)
        elif reduction == 'alternate':
            noise_pred = noise_pred[:, i % num_prompts]
            predicted_variance = predicted_variance[:, i % num_prompts]
        else:
            raise ValueError('Reduction must be either `mean` or `alternate`')
        noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)
        # Compute the previous noisy sample x_t -> x_t-1
        noisy_images = model.scheduler.step(
            noise_pred, t, noisy_images, generator=generator, return_dict=False
        )[0]

    # Return denoised images
    return noisy_images
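

if __name__ == '__main__':
    # Minimal usage sketch, assuming the DeepFloyd IF pipelines from
    # `diffusers` and hypothetical view objects. `IdentityView` is a
    # stand-in for illustration only; a real view would transform the image
    # (e.g. rotate or flip it) in `view` and undo that in `inverse_view`.
    from diffusers import IFPipeline, IFSuperResolutionPipeline

    class IdentityView:
        def view(self, im):
            return im

        def inverse_view(self, im):
            return im

    stage_1 = IFPipeline.from_pretrained(
        'DeepFloyd/IF-I-M-v1.0', variant='fp16', torch_dtype=torch.float16
    ).to('cuda')
    stage_2 = IFSuperResolutionPipeline.from_pretrained(
        'DeepFloyd/IF-II-M-v1.0', variant='fp16', torch_dtype=torch.float16
    ).to('cuda')

    # One prompt per view; embeddings have shape [num_prompts, seq_len, dim]
    prompts = ['a painting of a dog', 'a painting of a cat']
    prompt_embeds, negative_prompt_embeds = stage_1.encode_prompt(prompts)

    views = [IdentityView(), IdentityView()]
    image_64 = sample_stage_1(stage_1, prompt_embeds, negative_prompt_embeds, views)
    image_256 = sample_stage_2(stage_2, image_64, prompt_embeds, negative_prompt_embeds, views)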