"""This file contains the definition of the sampling function."""

from typing import Optional, Tuple, List, Text
import tqdm

import torch

from .masking import get_masking_ratio
from .factorization import combine_factorized_tokens


@torch.no_grad()
def sample(
    model,
    vqgan_model,
    num_samples: int = 10,
    labels: Optional[torch.Tensor] = None,
    softmax_temperature: float = 1.0,
    randomize_temperature: float = 4.5,
    mask_schedule_strategy: Text = "linear",
    num_steps: int = 12,
    guidance_scale: float = 3.0,
    mask_token: int = 1024,
    patch_size: int = 16,
    guidance_annealing: Text = "none",
    use_sampling_annealing: bool = False,
    scale_pow: float = 4.0,
    codebook_size: int = 1024,
    codebook_splits: int = 1,
    use_tqdm: bool = False,
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
    """Sample from the model.

    Args:
        model -> torch.nn.Module: The model to sample from.
        vqgan_model -> torch.nn.Module: The VQGAN model.
        num_samples -> int: The number of samples to generate.
        labels -> Optional[torch.Tensor]: The class labels to condition on. If None, a default set of ImageNet classes is used.
        softmax_temperature -> float: The temperature for the softmax over token logits.
        randomize_temperature -> float: The temperature of the Gumbel noise added to token confidences.
        mask_schedule_strategy -> Text: The strategy for the masking schedule.
        num_steps -> int: The number of sampling steps.
        guidance_scale -> float: The scale for classifier-free guidance.
        mask_token -> int: The token id used for masking.
        patch_size -> int: The side length of the latent token grid (patch_size**2 token positions).
        guidance_annealing -> Text: The annealing strategy for the guidance scale.
        use_sampling_annealing -> bool: Whether to anneal the softmax temperature over the sampling steps.
        scale_pow -> float: The exponent of the cosine guidance-annealing schedule.
        codebook_size -> int: The size of the codebook.
        codebook_splits -> int: The number of splits for the codebook.
        use_tqdm -> bool: Whether to display a tqdm progress bar over the sampling steps.

    Returns:
        Tuple[torch.Tensor, List[torch.Tensor]]: The generated samples and the tokens at each step.
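
    Example (illustrative sketch; assumes pre-trained `model` and `vqgan_model`
    instances exposing the interfaces used in this function):

        images, per_step_tokens = sample(model, vqgan_model, num_samples=10)
        # images: decoded samples from vqgan_model.decode_tokens(...)
        # per_step_tokens: one predicted-token grid per sampling step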
    """
    device = model.device

    model.eval()
    vqgan_model.eval()

    if labels is None:
        # Default classes: goldfish, chicken, tiger cat, hourglass, ship, dog,
        # race car, airliner, teddy bear, plus one random class.
        # Assumes num_samples is a positive multiple of 10.
        labels = [
            1,
            7,
            282,
            604,
            724,
            179,
            751,
            404,
            850,
            int(torch.randint(0, 1000, size=(1,))),
        ] * (num_samples // 10)
        labels = torch.LongTensor(labels).to(device)

    # drop_labels == True marks rows whose class label is dropped, i.e. the
    # unconditional branch used for classifier-free guidance.
    drop_labels = torch.ones(num_samples, dtype=torch.bool, device=device)
    spatial_size = int(patch_size**2)
    num_splits = int(codebook_splits)

    # Start from a fully masked token grid of shape
    # (num_samples, spatial_size, num_splits).
    masked_tokens = torch.full(
        (num_samples, spatial_size, num_splits), mask_token, device=device
    )
    num_maskable = spatial_size * num_splits
    mask = masked_tokens == mask_token

    l_full_tokens = []
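    # Standard Gumbel(0, 1) noise source, used below to randomize the token
    # confidences when deciding which predictions to keep.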
    gumbel = torch.distributions.Gumbel(loc=0.0, scale=1.0)

    if use_tqdm:
        step_iterable = tqdm.tqdm(range(num_steps), desc="Sampling steps", position=1)
    else:
        step_iterable = range(num_steps)

    for i in step_iterable:
        progress = (i + 1) / num_steps
        if guidance_scale != 0.0:
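            # Run the conditional and unconditional branches in a single
            # batched forward pass: the first half keeps the labels, the
            # second half drops them.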
            logits = model(
                torch.cat([masked_tokens.clone(), masked_tokens.clone()], dim=0),
                torch.cat([labels, labels], dim=0),
                torch.cat([~drop_labels, drop_labels], dim=0),
            )
            # Classifier-free guidance
            logits_with_class, logits_without_class = torch.chunk(logits, 2, dim=0)
            if guidance_annealing == "none":
                scale_step = 1.0
            elif guidance_annealing == "linear":
                scale_step = i / num_steps
            elif guidance_annealing == "cosine":
                # Power-cosine schedule: scale_step ramps from 0 to 1.
                scale_pow = torch.ones((1,), device=device) * scale_pow
                scale_step = (
                    1 - torch.cos(((i / num_steps) ** scale_pow) * torch.pi)
                ) / 2
            else:
                raise ValueError(
                    f"Unknown guidance_annealing: {guidance_annealing}"
                )
            scale = guidance_scale * scale_step
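            # Guided logits: cond + scale * (cond - uncond), i.e. extrapolate
            # away from the unconditional prediction.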
            logits = logits_with_class + scale * (
                logits_with_class - logits_without_class
            )
        else:
            logits = model(masked_tokens.clone(), labels, ~drop_labels)

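        # Optionally anneal the softmax temperature linearly from ~1.3 down
        # to 0.5 as sampling progresses.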
        if use_sampling_annealing:
            softmax_temperature = 0.5 + 0.8 * (1 - progress)
        probabilities = torch.softmax(logits / softmax_temperature, dim=-1)
        distribution = torch.distributions.Categorical(probabilities)
        predicted_tokens = distribution.sample()

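        # All samples start fully masked and are re-masked with the same
        # count, so the first sample's masked count stands in for the batch.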
        num_masked = torch.sum(mask, dim=(1, 2))[0]

        predicted_tokens = torch.where(mask, predicted_tokens, masked_tokens)

        confidence = torch.gather(
            probabilities, -1, predicted_tokens.unsqueeze(-1)
        ).squeeze(-1)
        # Ignore existing tokens by overwriting the confidence.
        confidence = torch.where(mask, confidence, torch.inf)

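        # Gumbel noise randomizes which tokens survive; its influence decays
        # linearly to zero over the sampling steps.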
        noise = (
            gumbel.sample(predicted_tokens.size())
            * randomize_temperature
            * (1 - progress)
        )
        confidence = torch.log(confidence) + noise.to(device)

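        # Fraction of tokens that should remain masked after this step,
        # according to the chosen schedule.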
        mask_ratio = get_masking_ratio(progress, mode=mask_schedule_strategy).to(device)

        # Re-mask at least 1 and at most num_masked - 1 tokens, so every step
        # keeps at least one new token without finishing early.
        mask_len = torch.floor(mask_ratio * num_maskable)
        num_tokens_to_mask = torch.clamp(
            mask_len, torch.ones_like(num_masked), num_masked - 1
        ).long()
        sorted_confidence = torch.sort(confidence.view(num_samples, -1), dim=-1).values
        threshold = sorted_confidence[:, num_tokens_to_mask - 1]

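        # Re-mask the lowest-confidence positions; all other positions keep
        # their newly predicted tokens.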
        should_mask = confidence <= threshold.unsqueeze(-1).unsqueeze(-1)
        masked_tokens = torch.where(should_mask, mask_token, predicted_tokens)
        mask = masked_tokens == mask_token
        l_full_tokens.append(predicted_tokens)

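    # Merge the per-split factorized tokens back into full codebook indices
    # before decoding.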
    predicted_tokens = combine_factorized_tokens(
        predicted_tokens, codebook_size, codebook_splits
    )

    generated_image = vqgan_model.decode_tokens(predicted_tokens)
    return generated_image, l_full_tokens