Update transformer/transformer.py
Browse files
transformer/transformer.py
CHANGED
@@ -100,7 +100,7 @@ class RMSNorm(torch.nn.Module):
|
|
100 |
|
101 |
|
102 |
# Modified from: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/pixart_transformer_2d.py
|
103 |
-
class
|
104 |
_supports_gradient_checkpointing = True
|
105 |
_no_split_modules = ["BasicTransformerBlock", "PatchEmbed"]
|
106 |
|
|
|
100 |
|
101 |
|
102 |
# Modified from: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/pixart_transformer_2d.py
|
103 |
+
class NitroDiTModel(ModelMixin, ConfigMixin):
|
104 |
_supports_gradient_checkpointing = True
|
105 |
_no_split_modules = ["BasicTransformerBlock", "PatchEmbed"]
|
106 |
|