import copy
import inspect
import math

import torch
import torch.nn.functional as F
from diffusers import DDIMScheduler, DPMSolverMultistepScheduler
from einops import repeat
from tqdm import tqdm


class DPMSolver:
    def __init__(self, num_timesteps):
        self.dpm_solver = DPMSolverMultistepScheduler(
            beta_schedule="linear",
            prediction_type="sample",
            # algorithm_type="sde-dpmsolver++",
            thresholding=False,
        )
        self.dpm_solver.set_timesteps(num_timesteps)
    @torch.no_grad()
    def pred_noise(self, model, x, t, lowres_img, dtype):
        # Run the denoiser in the requested dtype, then cast its prediction
        # back to float32 for the scheduler step.
        pred_noise = model(x.to(dtype), t.to(dtype), lowres_img=lowres_img.to(dtype))
        pred_noise = pred_noise.to(dtype=torch.float32)
        return pred_noise
    def prepare_extra_step_kwargs(self, generator, eta):
        # Prepare extra kwargs for the scheduler step, since not all schedulers have the same signature.
        # eta (η) is only used with the DDIMScheduler; it will be ignored by other schedulers.
        # eta corresponds to η in the DDIM paper (https://arxiv.org/abs/2010.02502) and should be in [0, 1].
        accepts_eta = "eta" in set(inspect.signature(self.dpm_solver.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta
        # Check if the scheduler accepts a generator.
        accepts_generator = "generator" in set(inspect.signature(self.dpm_solver.step).parameters.keys())
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs
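    # Note (version-dependent assumption, not part of the original comments):
    # DPMSolverMultistepScheduler.step() typically accepts `generator` but not
    # `eta`, so with this scheduler only `generator` ends up being forwarded;
    # `eta` would only take effect if the scheduler were swapped for DDIMScheduler.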
    def get_views(self, panorama_height, panorama_width, window_size=1024, stride=800):
        # Here we define the mappings F_i (see Eq. 7 in the MultiDiffusion paper, https://arxiv.org/abs/2302.08113).
        # If the panorama's height/width is <= window_size, the number of blocks along that axis is 1.
        num_blocks_height = math.ceil((panorama_height - window_size) / stride) + 1 if panorama_height > window_size else 1
        num_blocks_width = math.ceil((panorama_width - window_size) / stride) + 1 if panorama_width > window_size else 1
        total_num_blocks = int(num_blocks_height * num_blocks_width)
        views = []
        for i in range(total_num_blocks):
            h_start = int((i // num_blocks_width) * stride)
            h_end = h_start + window_size
            if h_end > panorama_height and num_blocks_height > 1:
                # Clamp the last row of windows to the bottom image border.
                h_end = panorama_height
                h_start = panorama_height - window_size
            w_start = int((i % num_blocks_width) * stride)
            w_end = w_start + window_size
            if w_end > panorama_width and num_blocks_width > 1:
                # Clamp the last column of windows to the right image border.
                w_end = panorama_width
                w_start = panorama_width - window_size
            views.append((h_start, h_end, w_start, w_end))
        return views
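    # Example (illustrative only): for a 2048x2048 panorama with the default
    # window_size=1024 and stride=800, get_views returns a 3x3 grid of nine
    # overlapping crops, starting with (0, 1024, 0, 1024), (0, 1024, 800, 1824),
    # (0, 1024, 1024, 2048) in the first row and ending with the bottom-right
    # crop (1024, 2048, 1024, 2048).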
    def generate_panorama(self, height, width, device, dtype, num_inference_steps,
                          unet, lowres_img, view_batch_size=15, eta=0, seed=0):
        # Define the panorama grid and initialize the views for synthesis,
        # batching the views so several crops are denoised per forward pass.
        views = self.get_views(height, width)
        views_batch = [views[i : i + view_batch_size] for i in range(0, len(views), view_batch_size)]
        # Keep an independent copy of the scheduler state per view batch.
        views_scheduler_status = [copy.deepcopy(self.dpm_solver.__dict__) for _ in views_batch]
        shape = (1, 3, height, width)
        count = torch.zeros(*shape, device=device)
        value = torch.zeros(*shape, device=device)
        generator = torch.Generator(device=device)
        if seed is not None:
            generator = generator.manual_seed(seed)
        img = torch.randn(*shape, device=device, generator=generator)
        up_lowres_img = F.interpolate(lowres_img, (shape[2], shape[3]), mode="bilinear")
        # Prepare extra step kwargs. TODO: this logic should ideally be moved out of the pipeline.
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
        # Denoising loop. Each denoising step also includes refinement of the latents
        # with respect to the views.
        timesteps = self.dpm_solver.timesteps
        num_warmup_steps = len(timesteps) - num_inference_steps * self.dpm_solver.order
        for i, time in tqdm(enumerate(timesteps), total=len(timesteps)):
            count.zero_()
            value.zero_()
            # Here we iterate through the different spatial crops of the latents and denoise them.
            # The denoised (latent) crops are then averaged to produce the final latent for the
            # current timestep via MultiDiffusion. Please see Sec. 4.1 in the MultiDiffusion
            # paper for more details: https://arxiv.org/abs/2302.08113
            for j, batch_view in enumerate(views_batch):
                vb_size = len(batch_view)
                # Get the latents corresponding to the current view coordinates.
                img_for_view = torch.cat(
                    [
                        img[:, :, h_start:h_end, w_start:w_end]
                        for h_start, h_end, w_start, w_end in batch_view
                    ]
                )
                lowres_img_for_view = torch.cat(
                    [
                        up_lowres_img[:, :, h_start:h_end, w_start:w_end]
                        for h_start, h_end, w_start, w_end in batch_view
                    ]
                )
                # Restore this batch's scheduler status before stepping.
                self.dpm_solver.__dict__.update(views_scheduler_status[j])
                t = torch.tensor([time] * img_for_view.shape[0], device=device)
                pred_noise = self.pred_noise(
                    unet, img_for_view, t, lowres_img_for_view, dtype
                )
                img_denoised_batch = self.dpm_solver.step(pred_noise, time, img_for_view, **extra_step_kwargs).prev_sample
                # Save this batch's scheduler status after the sample step.
                views_scheduler_status[j] = copy.deepcopy(self.dpm_solver.__dict__)
                # Scatter the denoised crops back into the full-resolution accumulators.
                for img_view_denoised, (h_start, h_end, w_start, w_end) in zip(
                    img_denoised_batch.chunk(vb_size), batch_view
                ):
                    value[:, :, h_start:h_end, w_start:w_end] += img_view_denoised
                    count[:, :, h_start:h_end, w_start:w_end] += 1
            # Take the MultiDiffusion step (Eq. 5 in the MultiDiffusion paper,
            # https://arxiv.org/abs/2302.08113): average the overlapping denoised crops per pixel.
            img = torch.where(count > 0, value / count, value)
        return img
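

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original pipeline). It shows how the helper is
# typically driven end to end; `DummyDenoiser` is a stand-in for the real
# super-resolution UNet (an assumption here), matching only the call signature
# `model(x, t, lowres_img=...)` that `pred_noise` expects and the
# prediction_type="sample" convention configured in DPMSolver.__init__.
if __name__ == "__main__":
    import torch.nn as nn

    class DummyDenoiser(nn.Module):
        """Placeholder model that 'predicts' an all-zero clean sample."""

        def forward(self, x, t, lowres_img=None):
            return torch.zeros_like(x)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    num_steps = 20
    solver = DPMSolver(num_timesteps=num_steps)

    # Small low-resolution input; generate_panorama upsamples it to the target
    # panorama size before the tiled (MultiDiffusion) denoising loop.
    lowres = torch.randn(1, 3, 256, 256, device=device)
    panorama = solver.generate_panorama(
        height=1024,
        width=1024,
        device=device,
        dtype=torch.float32,
        num_inference_steps=num_steps,
        unet=DummyDenoiser().to(device),
        lowres_img=lowres,
        view_batch_size=4,
        seed=0,
    )
    print(panorama.shape)  # torch.Size([1, 3, 1024, 1024])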