NIRVANALAN commited on
Commit
76e9ff7
·
1 Parent(s): e0ba37d
dit/dit_i23d.py CHANGED
@@ -16,7 +16,7 @@ try:
16
  from apex.normalization import FusedRMSNorm as RMSNorm
17
  except:
18
  from torch.nn import LayerNorm
19
- from dit.norm import RMSNorm
20
 
21
  # from vit.vit_triplane import XYZPosEmbed
22
 
 
16
  from apex.normalization import FusedRMSNorm as RMSNorm
17
  except:
18
  from torch.nn import LayerNorm
19
+ from diffusers.models.normalization import RMSNorm
20
 
21
  # from vit.vit_triplane import XYZPosEmbed
22
 
dit/dit_models_xformers.py CHANGED
@@ -29,8 +29,7 @@ try:
29
  from apex.normalization import FusedLayerNorm as LayerNorm
30
  except:
31
  from torch.nn import LayerNorm
32
- # from torch.nn import RMSNorm # requires torch2.4
33
- from dit.norm import RMSNorm
34
 
35
  # from torch.nn import LayerNorm
36
  # from xformers import triton
 
29
  from apex.normalization import FusedLayerNorm as LayerNorm
30
  except:
31
  from torch.nn import LayerNorm
32
+ from diffusers.models.normalization import RMSNorm
 
33
 
34
  # from torch.nn import LayerNorm
35
  # from xformers import triton
dit/norm.py DELETED
@@ -1,18 +0,0 @@
1
- import torch
2
-
3
- def rms_norm(x, weight=None, eps=1e-05):
4
- output = x / torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
5
- return output * weight if weight is not None else output
6
-
7
- class RMSNorm(torch.nn.Module):
8
-
9
- def __init__(self, normalized_shape, eps=1e-05, elementwise_affine=True, dtype=None, device=None):
10
- super().__init__()
11
- self.eps = eps
12
- if elementwise_affine:
13
- self.weight = torch.nn.Parameter(torch.ones(normalized_shape, dtype=dtype, device=device))
14
- else:
15
- self.register_parameter('weight', None)
16
-
17
- def forward(self, x):
18
- return rms_norm(x.float(), self.weight, self.eps).to(dtype=x.dtype)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
requirements.txt CHANGED
@@ -31,4 +31,5 @@ safetensors
31
  matplotlib
32
  git+https://github.com/nupurkmr9/vision-aided-gan
33
  PyMCubes
34
- trimesh
 
 
31
  matplotlib
32
  git+https://github.com/nupurkmr9/vision-aided-gan
33
  PyMCubes
34
+ trimesh
35
+ diffusers