lehduong committed on
Commit ac179f8 · verified · 1 Parent(s): 3cd52b2

Delete models/denoiser/nextdit/layers.py with huggingface_hub

Files changed (1)
  1. models/denoiser/nextdit/layers.py +0 -132
models/denoiser/nextdit/layers.py DELETED
@@ -1,132 +0,0 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import numpy as np
-from typing import Callable, Optional
-
-import warnings
-
-import torch
-import torch.nn as nn
-
-try:
-    from apex.normalization import FusedRMSNorm as RMSNorm
-except ImportError:
-    warnings.warn("Cannot import apex RMSNorm, switch to vanilla implementation")
-
-
-class RMSNorm(torch.nn.Module):
-    def __init__(self, dim: int, eps: float = 1e-6):
-        """
-        Initialize the RMSNorm normalization layer.
-        Args:
-            dim (int): The dimension of the input tensor.
-            eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
-        Attributes:
-            eps (float): A small value added to the denominator for numerical stability.
-            weight (nn.Parameter): Learnable scaling parameter.
-        """
-        super().__init__()
-        self.eps = eps
-        self.weight = nn.Parameter(torch.ones(dim))
-
-    def _norm(self, x):
-        """
-        Apply the RMSNorm normalization to the input tensor.
-        Args:
-            x (torch.Tensor): The input tensor.
-        Returns:
-            torch.Tensor: The normalized tensor.
-        """
-        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
-
-    def forward(self, x):
-        """
-        Forward pass through the RMSNorm layer.
-        Args:
-            x (torch.Tensor): The input tensor.
-        Returns:
-            torch.Tensor: The output tensor after applying RMSNorm.
-        """
-        output = self._norm(x.float()).type_as(x)
-        return output * self.weight
-
-
-def modulate(x, scale):
-    return x * (1 + scale.unsqueeze(1))
-
-class LLamaFeedForward(nn.Module):
-    """
-    Corresponds to the FeedForward layer in Next DiT.
-    """
-    def __init__(
-        self,
-        dim: int,
-        hidden_dim: int,
-        multiple_of: int,
-        ffn_dim_multiplier: Optional[float] = None,
-        zeros_initialize: bool = True,
-        dtype: torch.dtype = torch.float32,
-    ):
-        super().__init__()
-        self.dim = dim
-        self.hidden_dim = hidden_dim
-        self.multiple_of = multiple_of
-        self.ffn_dim_multiplier = ffn_dim_multiplier
-        self.zeros_initialize = zeros_initialize
-        self.dtype = dtype
-
-        # Compute hidden_dim based on the given formula
-        hidden_dim_calculated = int(2 * self.hidden_dim / 3)
-        if self.ffn_dim_multiplier is not None:
-            hidden_dim_calculated = int(self.ffn_dim_multiplier * hidden_dim_calculated)
-        hidden_dim_calculated = self.multiple_of * ((hidden_dim_calculated + self.multiple_of - 1) // self.multiple_of)
-
-        # Define linear layers
-        self.w1 = nn.Linear(self.dim, hidden_dim_calculated, bias=False)
-        self.w2 = nn.Linear(hidden_dim_calculated, self.dim, bias=False)
-        self.w3 = nn.Linear(self.dim, hidden_dim_calculated, bias=False)
-
-        # Initialize weights
-        if self.zeros_initialize:
-            nn.init.zeros_(self.w2.weight)
-        else:
-            nn.init.xavier_uniform_(self.w2.weight)
-        nn.init.xavier_uniform_(self.w1.weight)
-        nn.init.xavier_uniform_(self.w3.weight)
-
-    def _forward_silu_gating(self, x1, x3):
-        return F.silu(x1) * x3
-
-    def forward(self, x):
-        return self.w2(self._forward_silu_gating(self.w1(x), self.w3(x)))
-
-class FinalLayer(nn.Module):
-    """
-    The final layer of Next-DiT.
-    """
-    def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
-        super().__init__()
-        self.hidden_size = hidden_size
-        self.patch_size = patch_size
-        self.out_channels = out_channels
-
-        # LayerNorm without learnable parameters (elementwise_affine=False)
-        self.norm_final = nn.LayerNorm(self.hidden_size, eps=1e-6, elementwise_affine=False)
-        self.linear = nn.Linear(self.hidden_size, np.prod(self.patch_size) * self.out_channels, bias=True)
-        nn.init.zeros_(self.linear.weight)
-        nn.init.zeros_(self.linear.bias)
-
-        self.adaLN_modulation = nn.Sequential(
-            nn.SiLU(),
-            nn.Linear(self.hidden_size, self.hidden_size),
-        )
-        # Initialize the last layer with zeros
-        nn.init.zeros_(self.adaLN_modulation[1].weight)
-        nn.init.zeros_(self.adaLN_modulation[1].bias)
-
-    def forward(self, x, c):
-        scale = self.adaLN_modulation(c)
-        x = modulate(self.norm_final(x), scale)
-        x = self.linear(x)
-        return x
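For reference, here is a minimal sketch of how the deleted layers could be exercised. The dimensions and variable names are illustrative assumptions, not values taken from the repository's configs, and the sketch assumes `LLamaFeedForward` and `FinalLayer` are still importable (e.g. from a local copy of this file).

```python
import torch

# Hypothetical, illustrative dimensions; the real NextDiT configs may differ.
batch, seq_len, hidden_size = 2, 16, 64
patch_size, out_channels = 2, 4

# Assumes the classes from the deleted layers.py above are in scope.
ffn = LLamaFeedForward(dim=hidden_size, hidden_dim=4 * hidden_size, multiple_of=32)
final = FinalLayer(hidden_size=hidden_size, patch_size=patch_size, out_channels=out_channels)

x = torch.randn(batch, seq_len, hidden_size)  # token features
c = torch.randn(batch, hidden_size)           # conditioning vector, e.g. a timestep embedding

h = ffn(x)         # SwiGLU-style feed-forward; shape (batch, seq_len, hidden_size)
out = final(h, c)  # adaLN-modulated projection; last dim is np.prod(patch_size) * out_channels

# Both outputs are all zeros at initialization, because w2 and the final
# linear layer are zero-initialized by default.
print(h.shape, out.shape)
```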