from typing import Any, Dict, Optional, Tuple, Union

import diffusers
import torch
import torch.utils.checkpoint
from diffusers import LTXVideoTransformer3DModel
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.utils.import_utils import is_torch_version


def patch_transformer_forward() -> None:
    """Replace ``LTXVideoTransformer3DModel.forward`` with a version that supports per-token timesteps."""
    _perform_ltx_transformer_forward_patch()


def patch_apply_rotary_emb_for_tp_compatibility() -> None:
    """Replace the LTX ``apply_rotary_emb`` helper with a DTensor-compatible implementation."""
    _perform_ltx_apply_rotary_emb_tensor_parallel_compatibility_patch()

def _perform_ltx_transformer_forward_patch() -> None:
    LTXVideoTransformer3DModel.forward = _patched_LTXVideoTransformer3Dforward


def _perform_ltx_apply_rotary_emb_tensor_parallel_compatibility_patch() -> None:
    def apply_rotary_emb(x, freqs):
        cos, sin = freqs
        # ======== THIS IS CHANGED FROM THE ORIGINAL IMPLEMENTATION ========
        # The change is made due to unsupported DTensor operation aten.ops.unbind
        # FIXME: Once aten.ops.unbind support lands, this will no longer be required
        # x_real, x_imag = x.unflatten(2, (-1, 2)).unbind(-1)  # [B, S, H, D // 2]
        x_real, x_imag = x.unflatten(2, (-1, 2)).chunk(2, dim=-1)  # [B, S, H, D // 2]
        # ==================================================================
        x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(2)
        out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
        return out

    diffusers.models.transformers.transformer_ltx.apply_rotary_emb = apply_rotary_emb
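
    # A minimal sketch of why ``chunk`` is a drop-in replacement for ``unbind`` here
    # (illustrative only; the toy tensor below is an assumption, not real model data):
    #
    #     t = torch.arange(8.0).reshape(1, 1, 8)   # last dim holds (real, imag) pairs
    #     pairs = t.unflatten(2, (-1, 2))          # shape (1, 1, 4, 2)
    #     a, b = pairs.unbind(-1)                  # each (1, 1, 4)
    #     c, d = pairs.chunk(2, dim=-1)            # each (1, 1, 4, 1), extra trailing dim
    #     assert torch.equal(
    #         torch.stack([-b, a], dim=-1).flatten(2),
    #         torch.stack([-d, c], dim=-1).flatten(2),
    #     )
    #
    # The trailing size-1 dimension produced by ``chunk`` is absorbed by the final
    # ``flatten(2)``, so the rotated output matches the original implementation exactly.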


def _patched_LTXVideoTransformer3Dforward(
    self,
    hidden_states: torch.Tensor,
    encoder_hidden_states: torch.Tensor,
    timestep: torch.LongTensor,
    encoder_attention_mask: torch.Tensor,
    num_frames: int,
    height: int,
    width: int,
    rope_interpolation_scale: Optional[Tuple[float, float, float]] = None,
    return_dict: bool = True,
    *args,
    **kwargs,
) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
    image_rotary_emb = self.rope(hidden_states, num_frames, height, width, rope_interpolation_scale)

    # convert encoder_attention_mask to a bias the same way we do for attention_mask
    if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
        encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
        encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

    batch_size = hidden_states.size(0)

    # ===== This is modified compared to Diffusers =====
    # This is done because the Diffusers pipeline will pass in a 1D tensor for timestep
    if timestep.ndim == 1:
        timestep = timestep.view(-1, 1, 1).expand(-1, *hidden_states.shape[1:-1], -1)
    # ==================================================
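    # e.g. a 1D timestep of shape (batch_size,) becomes (batch_size, sequence_length, 1):
    # one value per latent token, matching the per-token embedding layout used below.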

    temb, embedded_timestep = self.time_embed(
        timestep.flatten(),
        batch_size=batch_size,
        hidden_dtype=hidden_states.dtype,
    )

    # ===== This is modified compared to Diffusers =====
    # Original implementation:
    #   temb = temb.view(batch_size, -1, temb.size(-1))
    #   embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.size(-1))
    # This is done to make it possible to use a per-token timestep embedding.
    temb = temb.view(batch_size, *hidden_states.shape[1:-1], temb.size(-1))
    embedded_timestep = embedded_timestep.view(batch_size, *hidden_states.shape[1:-1], embedded_timestep.size(-1))
    # ==================================================
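    # temb / embedded_timestep now hold one embedding per token, shaped
    # (batch_size, *hidden_states.shape[1:-1], emb_dim), so the modulation applied inside the
    # transformer blocks and at norm_out below can vary per token rather than per sample.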

    hidden_states = self.proj_in(hidden_states)

    encoder_hidden_states = self.caption_projection(encoder_hidden_states)
    encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.size(-1))

    for block in self.transformer_blocks:
        if torch.is_grad_enabled() and self.gradient_checkpointing:

            def create_custom_forward(module, return_dict=None):
                def custom_forward(*inputs):
                    if return_dict is not None:
                        return module(*inputs, return_dict=return_dict)
                    else:
                        return module(*inputs)

                return custom_forward

            ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
            hidden_states = torch.utils.checkpoint.checkpoint(
                create_custom_forward(block),
                hidden_states,
                encoder_hidden_states,
                temb,
                image_rotary_emb,
                encoder_attention_mask,
                **ckpt_kwargs,
            )
        else:
            hidden_states = block(
                hidden_states=hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                temb=temb,
                image_rotary_emb=image_rotary_emb,
                encoder_attention_mask=encoder_attention_mask,
            )

    scale_shift_values = self.scale_shift_table[None, None] + embedded_timestep[:, :, None]
    shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]

    hidden_states = self.norm_out(hidden_states)
    hidden_states = hidden_states * (1 + scale) + shift
    output = self.proj_out(hidden_states)

    if not return_dict:
        return (output,)
    return Transformer2DModelOutput(sample=output)
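

# ===========================================================================
# Usage sketch (illustrative, not part of this module's API). The pipeline class
# and checkpoint name below are assumptions based on the public LTX-Video release;
# adjust them to your setup:
#
#     import torch
#     from diffusers import LTXPipeline
#
#     patch_transformer_forward()
#     patch_apply_rotary_emb_for_tp_compatibility()
#
#     pipe = LTXPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16)
#     # The patched forward additionally accepts a per-token timestep of shape
#     # (batch_size, sequence_length) in place of the usual 1D timestep.
# ===========================================================================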