from __future__ import annotations

import torch
import torch.nn as nn
from monai.utils import optional_import
from torch.cuda.amp import autocast

tqdm, has_tqdm = optional_import("tqdm", name="tqdm")


class Sampler:
    """Run a guided denoising loop and decode the result with an autoencoder."""

    def __init__(self) -> None:
        super().__init__()

    @torch.no_grad()
    def sampling_fn(
        self,
        noise: torch.Tensor,
        autoencoder_model: nn.Module,
        diffusion_model: nn.Module,
        scheduler: nn.Module,
        prompt_embeds: torch.Tensor,
        guidance_scale: float = 7.0,
        scale_factor: float = 0.3,
    ) -> torch.Tensor:
        """Denoise ``noise`` over ``scheduler.timesteps`` and decode it.

        At every timestep the current tensor is duplicated along dim 0, passed
        through ``diffusion_model`` together with ``prompt_embeds``, and the
        two halves of the output are blended with ``guidance_scale``. The
        blended prediction is handed to ``scheduler.step`` to obtain the next
        tensor. After the last step the tensor is divided by ``scale_factor``
        and decoded with ``autoencoder_model.decode_stage_2_outputs``.

        Gradients are disabled for the whole call.

        Args:
            noise: Starting tensor for the denoising loop.
            autoencoder_model: Model exposing ``decode_stage_2_outputs``.
            diffusion_model: Callable taking the doubled input plus
                ``timesteps`` and ``context`` keyword arguments.
            scheduler: Object exposing ``timesteps`` (iterable) and
                ``step(model_output, t, sample)``; the first element of the
                value returned by ``step`` is used as the next sample.
            prompt_embeds: Passed unchanged as ``context`` to the diffusion
                model. Presumably stacks the unconditional and the
                text-conditioned embeddings along dim 0, in that order, to
                match the doubled input -- TODO confirm against the caller.
            guidance_scale: Weight applied to the difference between the
                second and the first half of the model output.
            scale_factor: Divisor applied to the final tensor before decoding.

        Returns:
            The output of ``autoencoder_model.decode_stage_2_outputs``.
        """
        # Show a progress bar when tqdm is installed, otherwise iterate plainly.
        progress_bar = tqdm(scheduler.timesteps) if has_tqdm else iter(scheduler.timesteps)

        for t in progress_bar:
            # One forward pass yields both predictions: the input is doubled
            # along dim 0 and the output is split back into two halves below.
            noise_input = torch.cat([noise] * 2)

            # Build the integer timestep directly on the target device instead
            # of going through the legacy float ``torch.Tensor`` constructor.
            timesteps = torch.tensor([int(t)], dtype=torch.long, device=noise.device)

            model_output = diffusion_model(noise_input, timesteps=timesteps, context=prompt_embeds)
            noise_pred_uncond, noise_pred_text = model_output.chunk(2)

            # Move from the first-half prediction towards the second-half
            # prediction, scaled by ``guidance_scale``.
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # Only the first element returned by the scheduler is carried on.
            noise, _ = scheduler.step(noise_pred, t, noise)

        # NOTE(review): ``torch.cuda.amp.autocast`` is deprecated in recent
        # PyTorch releases in favour of ``torch.amp.autocast("cuda")``; it is
        # kept as-is here so older PyTorch versions keep working.
        with autocast():
            sample = autoencoder_model.decode_stage_2_outputs(noise / scale_factor)

        return sample