"""This file contains the definition of the sampling function."""

from typing import List, Optional, Text, Tuple

import torch
import tqdm

from .factorization import combine_factorized_tokens
from .masking import get_masking_ratio


@torch.no_grad()
def sample(
    model,
    vqgan_model,
    num_samples: int = 10,
    labels: Optional[torch.Tensor] = None,
    softmax_temperature: float = 1.0,
    randomize_temperature: float = 4.5,
    mask_schedule_strategy: Text = "linear",
    num_steps: int = 12,
    guidance_scale: float = 3.0,
    mask_token: int = 1024,
    patch_size: int = 16,
    guidance_annealing: Text = "none",
    use_sampling_annealing: bool = False,
    scale_pow: float = 4.0,
    codebook_size: int = 1024,
    codebook_splits: int = 1,
    use_tqdm: bool = False,
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
"""Sample from the model. |
|
|
|
Args: |
|
model -> torch.nn.Module: The model to sample from. |
|
vqgan_model -> torch.nn.Module: The VQGAN model. |
|
num_samples -> int: The number of samples to generate. |
|
labels -> Optional[torch.Tensor]: The labels to use for the generation. |
|
softmax_temperature -> float: The temperature for the softmax. |
|
randomize_temperature -> float: The temperature for the randomization. |
|
mask_schedule_strategy -> Text: The strategy for the mask schedule. |
|
num_steps -> int: The number of steps to use for the sampling. |
|
guidance_scale -> float: The scale for the guidance. |
|
mask_token -> int: The token to use for the masking. |
|
patch_size -> int: The size of the patches. |
|
guidance_annealing -> Text: The annealing strategy for the guidance. |
|
use_sampling_annealing -> bool: Whether to use the sampling annealing. |
|
scale_pow -> float: The power for the scaling. |
|
codebook_size -> int: The size of the codebook. |
|
codebook_splits -> int: The number of splits for the codebook. |
|
|
|
Returns: |
|
Tuple[torch.Tensor, List[torch.Tensor]]: The generated samples and the tokens at each step. |
|
""" |
|
    device = model.device

    model.eval()
    vqgan_model.eval()

    if labels is None:
        # Fixed set of class ids plus one random id, tiled to cover
        # num_samples (assumes num_samples is a multiple of 10).
        labels = [
            1,
            7,
            282,
            604,
            724,
            179,
            751,
            404,
            850,
            torch.randint(0, 999, size=(1,)).item(),
        ] * (num_samples // 10)
        labels = torch.LongTensor(labels).to(device)
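
    # Every sample is flagged for label dropping; the classifier-free guidance
    # pass below feeds ~drop_labels (conditional half) and drop_labels
    # (unconditional half) through the model in a single batch.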
    drop_labels = torch.ones(num_samples, dtype=torch.bool, device=device)
    spatial_size = int(patch_size**2)
    num_splits = int(codebook_splits)

    # Start from a fully masked grid of factorized tokens.
    masked_tokens = torch.full(
        (num_samples, spatial_size, num_splits), mask_token, device=device
    )
    num_maskable = spatial_size * num_splits
    mask = masked_tokens == mask_token

    l_full_tokens = []
    gumbel = torch.distributions.Gumbel(loc=0.0, scale=1.0)
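
    # MaskGIT-style iterative decoding: at each step the model predicts all
    # masked tokens in parallel, only the most confident predictions are kept,
    # and the remainder are re-masked for the next round.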

    if use_tqdm:
        step_iterable = tqdm.tqdm(range(num_steps), desc="Sampling steps", position=1)
    else:
        step_iterable = range(num_steps)

    for i in step_iterable:
        progress = (i + 1) / num_steps
        if guidance_scale != 0.0:
            # Classifier-free guidance: run the conditional and unconditional
            # forward passes in a single batch and extrapolate the logits.
            logits = model(
                torch.cat([masked_tokens.clone(), masked_tokens.clone()], dim=0),
                torch.cat([labels, labels], dim=0),
                torch.cat([~drop_labels, drop_labels], dim=0),
            )
            logits_with_class, logits_without_class = torch.chunk(logits, 2, dim=0)

            if guidance_annealing == "none":
                scale_step = 1.0
            elif guidance_annealing == "linear":
                scale_step = i / num_steps
            elif guidance_annealing == "cosine":
                scale_pow = torch.ones((1,), device=device) * scale_pow
                scale_step = (
                    1 - torch.cos(((i / num_steps) ** scale_pow) * torch.pi)
                ) / 2
            scale = guidance_scale * scale_step
            logits = logits_with_class + scale * (
                logits_with_class - logits_without_class
            )
        else:
            logits = model(masked_tokens.clone(), labels, ~drop_labels)

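        # Optional temperature annealing: softmax_temperature decays from
        # roughly 1.3 at the first step to 0.5 at the last, so early steps
        # explore and later steps commit.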
        if use_sampling_annealing:
            softmax_temperature = 0.5 + 0.8 * (1 - progress)
        probabilities = torch.softmax(logits / softmax_temperature, dim=-1)
        distribution = torch.distributions.Categorical(probabilities)
        predicted_tokens = distribution.sample()

        num_masked = torch.sum(mask, dim=(1, 2))[0]

        # Keep previously committed tokens; only masked positions take the
        # freshly sampled predictions.
        predicted_tokens = torch.where(mask, predicted_tokens, masked_tokens)

        # Per-token confidence of the sampled ids; unmasked positions are
        # pinned to +inf so they can never be re-masked.
        confidence = torch.gather(
            probabilities, -1, predicted_tokens.unsqueeze(-1)
        ).squeeze(-1)
        confidence = torch.where(mask, confidence, torch.inf)

        # Perturb the log-confidences with Gumbel noise whose scale decays to
        # zero as sampling progresses.
        noise = (
            gumbel.sample(predicted_tokens.size())
            * randomize_temperature
            * (1 - progress)
        )
        confidence = torch.log(confidence) + noise.to(device)

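        # get_masking_ratio maps progress in (0, 1] to the fraction of token
        # positions that should remain masked after this step (for a "linear"
        # schedule presumably 1 - progress; the exact schedules live in
        # .masking).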
        mask_ratio = get_masking_ratio(progress, mode=mask_schedule_strategy).to(device)

        # Re-mask the least confident positions, always keeping at least one
        # position masked and unmasking at least one position per step.
        mask_len = torch.floor(mask_ratio * num_maskable)
        num_tokens_to_mask = torch.clamp(
            mask_len, torch.ones_like(num_masked), num_masked - 1
        ).long()
        sorted_confidence = torch.sort(confidence.view(num_samples, -1), dim=-1).values
        threshold = sorted_confidence[:, num_tokens_to_mask - 1]

        should_mask = confidence <= threshold.unsqueeze(-1).unsqueeze(-1)
        masked_tokens = torch.where(should_mask, mask_token, predicted_tokens)
        mask = masked_tokens == mask_token
        l_full_tokens.append(predicted_tokens)

    # Fold the factorized splits back into full codebook indices and decode
    # them to pixels with the VQGAN decoder.
    predicted_tokens = combine_factorized_tokens(
        predicted_tokens, codebook_size, codebook_splits
    )
    generated_image = vqgan_model.decode_tokens(predicted_tokens)
    return generated_image, l_full_tokens
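

# Example usage (a minimal sketch, not part of the module): `load_transformer`
# and `load_vqgan` are hypothetical loaders standing in for however the
# checkpoints are actually restored. Any `model` exposing `.device` and the
# three-argument call used above, plus any `vqgan_model` with `.decode_tokens`,
# fits this function.
#
#     model = load_transformer("transformer.ckpt")  # hypothetical loader
#     vqgan_model = load_vqgan("vqgan.ckpt")        # hypothetical loader
#     images, per_step_tokens = sample(
#         model,
#         vqgan_model,
#         num_samples=10,
#         guidance_scale=3.0,
#         num_steps=12,
#         use_tqdm=True,
#     )
#     # images holds the decoded samples; per_step_tokens holds one tensor of
#     # predicted tokens per sampling step.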