import torch class DDPMScheduler: def __init__( self, max_timesteps: int = 1000, beta_1: int = 0.0001, beta_2: int = 0.02 ) -> None: self.beta_1 = beta_1 self.beta_2 = beta_2 self.max_timesteps = max_timesteps self._init_params() def _init_params(self, timesteps: int | None = None): self.beta = torch.linspace(self.beta_1, self.beta_2, timesteps or self.max_timesteps) self.sqrt_beta = torch.sqrt(self.beta) self.alpha = (1 - self.beta) self.sqrt_alpha = torch.sqrt(self.alpha) self.alpha_hat = torch.cumprod(1 - self.beta, dim=0) self.sqrt_alpha_hat = torch.sqrt(self.alpha_hat) self.sqrt_one_minus_alpha = torch.sqrt(1 - self.alpha) self.sqrt_one_minus_alpha_hat = torch.sqrt(1 - self.alpha_hat) def noising( self, x_0: torch.Tensor, t: torch.Tensor ): if t.device != x_0.device: t = t.to(x_0.device) noise = torch.randn_like(x_0, device=x_0.device) new_x = self.sqrt_alpha_hat.to(x_0.device)[t][:, None, None, None] * x_0 new_noise = self.sqrt_one_minus_alpha_hat.to(x_0.device)[t][:, None, None, None] * noise return new_x + new_noise, noise @torch.no_grad() def sampling_t( self, x_t: torch.Tensor, model, labels: torch.Tensor, timesteps: int, t: int, n_samples: int = 16, cfg_scale: int = 3, ): time = torch.full((n_samples,), fill_value=t, device=model.device) pred_noise = model(x_t, time, labels) if cfg_scale > 0 and labels is not None: uncond_pred_noise = model(x_t, time, None) pred_noise = torch.lerp(uncond_pred_noise, pred_noise, cfg_scale) alpha = self.alpha.to(model.device)[time][:, None, None, None] sqrt_alpha = self.sqrt_alpha.to(model.device)[time][:, None, None, None] somah = self.sqrt_one_minus_alpha_hat.to(model.device)[time][:, None, None, None] sqrt_beta = self.sqrt_beta.to(model.device)[time][:, None, None, None] if t > 1: noise = torch.randn_like(x_t, device=model.device) else: noise = torch.zeros_like(x_t, device=model.device) x_t_new = 1 / sqrt_alpha * (x_t - (1-alpha) / somah * pred_noise) + sqrt_beta * noise return x_t_new.clamp(-1, 1) @torch.no_grad() def sampling( self, model, n_samples: int = 16, in_channels: int = 3, dim: int = 32, timesteps: int = 1000, cfg_scale: int = 3, labels=None, *args, **kwargs ): if labels is not None: n_samples = labels.shape[0] model.eval() x_t = torch.randn( n_samples, in_channels, dim, dim, device=model.device ) step_ratios = self.max_timesteps // timesteps all_timesteps = torch.flip(torch.arange(0, timesteps) * step_ratios, dims=(0,)) for t in all_timesteps: x_t = self.sampling_t(x_t=x_t, model=model, labels=labels, t=t, timesteps=timesteps, n_samples=n_samples, cfg_scale=cfg_scale) model.train() x_t = (x_t.clamp(-1, 1) + 1) / 2 * 255. # range [0,255] return x_t.type(torch.uint8) @torch.no_grad() def sampling_demo( self, model, n_samples: int = 16, in_channels: int = 3, dim: int = 32, timesteps: int = 1000, cfg_scale: int = 3, labels=None, *args, **kwargs ): if labels is not None: n_samples = labels.shape[0] x_t = torch.randn( n_samples, in_channels, dim, dim, device=model.device ) model.eval() step_ratios = self.max_timesteps // timesteps all_timesteps = torch.flip(torch.arange(0, timesteps) * step_ratios, dims=(0,)) for t in all_timesteps: x_t = self.sampling_t(x_t=x_t, model=model, labels=labels, t=t, timesteps=timesteps, n_samples=n_samples, cfg_scale=cfg_scale) yield ((x_t.clamp(-1, 1) + 1) / 2 * 255).type(torch.uint8) class DDIMScheduler(DDPMScheduler): def __init__( self, max_timesteps: int = 1000, beta_1: int = 0.0001, beta_2: int = 0.02 ) -> None: super().__init__(beta_1=beta_1, beta_2=beta_2, max_timesteps=max_timesteps) self._init_params() def _init_params(self, timesteps: int | None = None): self.beta = torch.linspace(self.beta_1, self.beta_2, timesteps or self.max_timesteps) self.sqrt_beta = torch.sqrt(self.beta) self.alpha = (1 - self.beta) self.sqrt_alpha = torch.sqrt(self.alpha) self.alpha_hat = torch.cumprod(1 - self.beta, dim=0) self.sqrt_alpha_hat = torch.sqrt(self.alpha_hat) self.sqrt_one_minus_alpha = torch.sqrt(1 - self.alpha) self.sqrt_one_minus_alpha_hat = torch.sqrt(1 - self.alpha_hat) self.alpha_hat_prev = torch.cat([torch.tensor([1.]), self.alpha_hat], dim=0)[:-1] self.variance = (1 - self.alpha_hat_prev) / (1 - self.alpha_hat) * \ (1 - self.alpha_hat / self.alpha_hat_prev) @torch.no_grad() def sampling_t( self, x_t: torch.Tensor, model, t: int, timesteps: int, labels: torch.Tensor | None = None, n_samples: int = 16, eta: float = 0.0, *args, **kwargs ): time = torch.full((n_samples,), fill_value=t, device=model.device) time_prev = time - self.max_timesteps // timesteps pred_noise = model(x_t, time, labels) sqrt_one_minus_alpha_hat = self.sqrt_one_minus_alpha_hat.to(model.device)[time][:, None, None, None] sqrt_alpha_hat = self.sqrt_alpha_hat.to(model.device)[time][:, None, None, None] alpha_hat_prev = self.alpha_hat[time_prev] if time_prev[0] >= 0 else torch.ones_like(time_prev) alpha_hat_prev = alpha_hat_prev.to(model.device)[:, None, None, None] sqrt_alpha_hat_prev = torch.sqrt(alpha_hat_prev) posterior_std = torch.sqrt(self.variance)[time][:, None, None, None] * eta if t > 0: noise = torch.randn_like(x_t, device=model.device) else: noise = torch.zeros_like(x_t, device=model.device) x_0_pred = (x_t - sqrt_one_minus_alpha_hat * pred_noise) / sqrt_alpha_hat x_0_pred = x_0_pred.clamp(-1, 1) x_t_direction = torch.sqrt(1. - alpha_hat_prev - posterior_std**2) * pred_noise random_noise = posterior_std * noise x_t_1 = sqrt_alpha_hat_prev * x_0_pred + x_t_direction + random_noise return x_t_1 if __name__ == "__main__": dct = DDIMScheduler().__dict__ for k in dct.keys(): if isinstance(dct[k], torch.Tensor): print(k, dct[k].shape) else: print(k, dct[k])