# File size: 5,899 Bytes
# 91fb4ef
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
import math
from typing import Optional, Union

import torch
from diffusers import CogVideoXDDIMScheduler, FlowMatchEulerDiscreteScheduler
from diffusers.training_utils import compute_loss_weighting_for_sd3


# Default values copied from https://github.com/huggingface/diffusers/blob/8957324363d8b239d82db4909fbf8c0875683e3d/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py#L47
def resolution_dependent_timestep_flow_shift(
    latents: torch.Tensor,
    sigmas: torch.Tensor,
    base_image_seq_len: int = 256,
    max_image_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
) -> torch.Tensor:
    r"""
    Shift `sigmas` by an amount that grows linearly with the size of `latents`.

    The sequence length is the product of all `latents` dimensions from index 2 onward. The shift value `mu` is
    taken from the straight line through `(base_image_seq_len, base_shift)` and `(max_image_seq_len, max_shift)`;
    it is not clamped, so sequence lengths outside that interval extrapolate along the same line.

    Args:
        latents: 4D or 5D tensor whose trailing dimensions determine the sequence length.
        sigmas: Sigma values to be shifted.
        base_image_seq_len: Sequence length at which the shift equals `base_shift`.
        max_image_seq_len: Sequence length at which the shift equals `max_shift`.
        base_shift: Shift applied at `base_image_seq_len`.
        max_shift: Shift applied at `max_image_seq_len`.

    Returns:
        The shifted sigmas, as produced by `default_flow_shift(sigmas, shift=mu)`.

    Raises:
        ValueError: If `latents` is neither 4D nor 5D.
    """
    if latents.ndim not in (4, 5):
        raise ValueError(f"Expected 4D or 5D tensor, got {latents.ndim}D tensor")

    # Dimensions 0 and 1 are skipped; everything after them counts towards the sequence length.
    image_or_video_sequence_length = math.prod(latents.shape[2:])

    # Linear interpolation: slope `m` and intercept `b` of the line through the two reference points.
    m = (max_shift - base_shift) / (max_image_seq_len - base_image_seq_len)
    b = base_shift - m * base_image_seq_len
    mu = m * image_or_video_sequence_length + b

    # `default_flow_shift` takes only (sigmas, shift). Passing `latents` as an extra positional argument bound
    # `sigmas` to the `shift` parameter and raised TypeError because `shift` was then given twice.
    # NOTE(review): `mu` is used directly as the shift factor; confirm whether `exp(mu)` was intended, as the
    # referenced diffusers scheduler appears to exponentiate it.
    return default_flow_shift(sigmas, shift=mu)


def default_flow_shift(sigmas: torch.Tensor, shift: float = 1.0) -> torch.Tensor:
    r"""
    Rescale `sigmas` as `shift * sigmas / (1 + (shift - 1) * sigmas)`.

    With the default `shift=1.0` the denominator evaluates to 1, so finite inputs come back with unchanged values.

    Args:
        sigmas: Values to rescale.
        shift: Shift factor applied in the formula above.

    Returns:
        The rescaled values, computed elementwise.
    """
    scaled = sigmas * shift
    normalizer = 1 + (shift - 1) * sigmas
    return scaled / normalizer


def compute_density_for_timestep_sampling(
    weighting_scheme: str,
    batch_size: int,
    logit_mean: float = None,
    logit_std: float = None,
    mode_scale: float = None,
    device: torch.device = torch.device("cpu"),
    generator: Optional[torch.Generator] = None,
) -> torch.Tensor:
    r"""
    Draw `batch_size` timestep-sampling values according to `weighting_scheme` (SD3 training).

    Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.

    SD3 paper reference: https://arxiv.org/abs/2403.03206v1.

    Args:
        weighting_scheme: `"logit_normal"` applies a sigmoid to normal samples, `"mode"` transforms uniform
            samples using `mode_scale`, and any other value yields plain uniform samples.
        batch_size: Number of values to draw.
        logit_mean: Mean of the normal distribution; only read for `"logit_normal"`.
        logit_std: Standard deviation of the normal distribution; only read for `"logit_normal"`.
        mode_scale: Scale of the cosine correction term; only read for `"mode"`.
        device: Device on which the samples are created.
        generator: Optional random generator forwarded to the sampling calls.

    Returns:
        A tensor of shape `(batch_size,)`.
    """
    sample_shape = (batch_size,)

    if weighting_scheme == "logit_normal":
        # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
        gaussian = torch.normal(
            mean=logit_mean, std=logit_std, size=sample_shape, device=device, generator=generator
        )
        return torch.sigmoid(gaussian)

    # Both remaining schemes start from the same uniform draw.
    uniform = torch.rand(size=sample_shape, device=device, generator=generator)
    if weighting_scheme == "mode":
        cos_squared = torch.cos(math.pi * uniform / 2) ** 2
        return 1 - uniform - mode_scale * (cos_squared - 1 + uniform)
    return uniform


def get_scheduler_alphas(scheduler: Union[CogVideoXDDIMScheduler, FlowMatchEulerDiscreteScheduler]) -> torch.Tensor:
    r"""
    Return the alphas associated with `scheduler`, if it has any.

    Args:
        scheduler: A `FlowMatchEulerDiscreteScheduler` or a `CogVideoXDDIMScheduler`.

    Returns:
        `None` for `FlowMatchEulerDiscreteScheduler`, or a clone of `scheduler.alphas_cumprod` for
        `CogVideoXDDIMScheduler`.

    Raises:
        ValueError: If `scheduler` is of neither supported type.
    """
    if isinstance(scheduler, FlowMatchEulerDiscreteScheduler):
        # No alphas are returned for the flow-matching scheduler.
        return None
    if isinstance(scheduler, CogVideoXDDIMScheduler):
        # Clone so the caller gets its own tensor rather than the scheduler's.
        return scheduler.alphas_cumprod.clone()
    raise ValueError(f"Unsupported scheduler type {type(scheduler)}")


def get_scheduler_sigmas(scheduler: Union[CogVideoXDDIMScheduler, FlowMatchEulerDiscreteScheduler]) -> torch.Tensor:
    r"""
    Return the sigmas associated with `scheduler`.

    Args:
        scheduler: A `FlowMatchEulerDiscreteScheduler` or a `CogVideoXDDIMScheduler`.

    Returns:
        A clone of `scheduler.sigmas` for `FlowMatchEulerDiscreteScheduler`. For `CogVideoXDDIMScheduler`, a
        float copy of `scheduler.timesteps` divided by `scheduler.config.num_train_timesteps`.

    Raises:
        ValueError: If `scheduler` is of neither supported type.
    """
    if isinstance(scheduler, FlowMatchEulerDiscreteScheduler):
        return scheduler.sigmas.clone()
    if isinstance(scheduler, CogVideoXDDIMScheduler):
        # Scale the timesteps by the number of training timesteps.
        timesteps = scheduler.timesteps.clone().float()
        total_steps = float(scheduler.config.num_train_timesteps)
        return timesteps / total_steps
    raise ValueError(f"Unsupported scheduler type {type(scheduler)}")


def prepare_sigmas(
    scheduler: Union[CogVideoXDDIMScheduler, FlowMatchEulerDiscreteScheduler],
    sigmas: torch.Tensor,
    batch_size: int,
    num_train_timesteps: int,
    flow_weighting_scheme: str = "none",
    flow_logit_mean: float = 0.0,
    flow_logit_std: float = 1.0,
    flow_mode_scale: float = 1.29,
    device: torch.device = torch.device("cpu"),
    generator: Optional[torch.Generator] = None,
) -> torch.Tensor:
    r"""
    Pick one entry of `sigmas` per batch element by sampling random timestep indices.

    A weight is drawn per batch element, multiplied by `num_train_timesteps` and truncated to an integer index
    into `sigmas`.

    Args:
        scheduler: Scheduler whose type selects how the weights are sampled.
        sigmas: Tensor of sigma values, indexed along its first dimension.
        batch_size: Number of sigmas to sample.
        num_train_timesteps: Factor that maps a weight to an index.
        flow_weighting_scheme: Sampling scheme forwarded to `compute_density_for_timestep_sampling`; only used
            for `FlowMatchEulerDiscreteScheduler`.
        flow_logit_mean: Forwarded as `logit_mean`; only used for `FlowMatchEulerDiscreteScheduler`.
        flow_logit_std: Forwarded as `logit_std`; only used for `FlowMatchEulerDiscreteScheduler`.
        flow_mode_scale: Forwarded as `mode_scale`; only used for `FlowMatchEulerDiscreteScheduler`.
        device: Device on which the random weights are created.
        generator: Optional random generator used for sampling.

    Returns:
        `sigmas` indexed by the sampled indices.

    Raises:
        ValueError: If `scheduler` is of neither supported type.
    """
    if isinstance(scheduler, FlowMatchEulerDiscreteScheduler):
        weights = compute_density_for_timestep_sampling(
            weighting_scheme=flow_weighting_scheme,
            batch_size=batch_size,
            logit_mean=flow_logit_mean,
            logit_std=flow_logit_std,
            mode_scale=flow_mode_scale,
            device=device,
            generator=generator,
        )
    elif isinstance(scheduler, CogVideoXDDIMScheduler):
        # TODO(aryan): Currently, only uniform sampling is supported. Add more sampling schemes.
        weights = torch.rand(size=(batch_size,), device=device, generator=generator)
    else:
        raise ValueError(f"Unsupported scheduler type {type(scheduler)}")

    indices = (weights * num_train_timesteps).long()
    # A weight of exactly 1.0 (e.g. the "mode" scheme when the uniform draw is 0, or a saturated sigmoid) maps to
    # index `num_train_timesteps`, which is out of range when `sigmas` has that many entries. Clamping to the
    # last valid index only affects lookups that would otherwise raise IndexError.
    indices = indices.clamp(max=sigmas.shape[0] - 1)
    return sigmas[indices]


def prepare_loss_weights(
    scheduler: Union[CogVideoXDDIMScheduler, FlowMatchEulerDiscreteScheduler],
    alphas: Optional[torch.Tensor] = None,
    sigmas: Optional[torch.Tensor] = None,
    flow_weighting_scheme: str = "none",
) -> torch.Tensor:
    r"""
    Compute loss weights for the given scheduler type.

    Args:
        scheduler: Scheduler whose type selects the weighting formula.
        alphas: Alpha values; only read for `CogVideoXDDIMScheduler`.
        sigmas: Sigma values; only read for `FlowMatchEulerDiscreteScheduler`.
        flow_weighting_scheme: Scheme forwarded to `compute_loss_weighting_for_sd3`; only used for
            `FlowMatchEulerDiscreteScheduler`.

    Returns:
        The result of `compute_loss_weighting_for_sd3` for `FlowMatchEulerDiscreteScheduler`, or
        `1 / (1 - alphas)` for `CogVideoXDDIMScheduler`.

    Raises:
        ValueError: If `scheduler` is of neither supported type.
    """
    if isinstance(scheduler, FlowMatchEulerDiscreteScheduler):
        return compute_loss_weighting_for_sd3(weighting_scheme=flow_weighting_scheme, sigmas=sigmas)

    if isinstance(scheduler, CogVideoXDDIMScheduler):
        # SNR is computed as (alphas / (1 - alphas)), but for some reason CogVideoX uses 1 / (1 - alphas).
        # TODO(aryan): Experiment if using alphas / (1 - alphas) gives better results.
        one_minus_alphas = 1 - alphas
        return 1 / one_minus_alphas

    raise ValueError(f"Unsupported scheduler type {type(scheduler)}")


def prepare_target(
    scheduler: Union[CogVideoXDDIMScheduler, FlowMatchEulerDiscreteScheduler],
    noise: torch.Tensor,
    latents: torch.Tensor,
) -> torch.Tensor:
    r"""
    Build the training target for the given scheduler type.

    Args:
        scheduler: Scheduler whose type selects the target.
        noise: Noise tensor; only read for `FlowMatchEulerDiscreteScheduler`.
        latents: Latents tensor.

    Returns:
        `noise - latents` for `FlowMatchEulerDiscreteScheduler`, or `latents` itself (not a copy) for
        `CogVideoXDDIMScheduler`.

    Raises:
        ValueError: If `scheduler` is of neither supported type.
    """
    if isinstance(scheduler, FlowMatchEulerDiscreteScheduler):
        return noise - latents
    if isinstance(scheduler, CogVideoXDDIMScheduler):
        return latents
    raise ValueError(f"Unsupported scheduler type {type(scheduler)}")