NIRVANALAN committed
Commit
f1d83ba
Parent: cf99ccb

update dep

Files changed (4)
  1. app.py +7 -7
  2. dit/dit_models_xformers.py +6 -2
  3. dit/norm.py +18 -0
  4. ldm/modules/attention.py +5 -4
app.py CHANGED
@@ -32,18 +32,18 @@ import numpy as np
 import torch as th
 import torch.distributed as dist
 
-def install_dependency():
-    # install apex
-    subprocess.run(
-        f'FORCE_CUDA=1 {sys.executable} -m pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" git+https://github.com/NVIDIA/apex.git@master',
-        shell=True,
-    )
+# def install_dependency():
+#     # install apex
+#     subprocess.run(
+#         f'FORCE_CUDA=1 {sys.executable} -m pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" git+https://github.com/NVIDIA/apex.git@master',
+#         shell=True,
+#     )
 
 th.backends.cuda.matmul.allow_tf32 = True
 th.backends.cudnn.allow_tf32 = True
 th.backends.cudnn.enabled = True
 
-install_dependency()
+# install_dependency()
 
 from guided_diffusion import dist_util, logger
 from guided_diffusion.script_util import (
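The net effect of this hunk is that apex is no longer force-installed at app startup; it becomes an optional dependency that the modules below probe for. A minimal way to check availability at runtime (a sketch, not code from this commit):

```python
import importlib.util

# Sketch, not part of the commit: probe for apex the same way the
# try/except imports in the other files effectively do.
HAS_APEX = importlib.util.find_spec("apex") is not None
print(f"apex available: {HAS_APEX}")  # False on hosts without the CUDA build
```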
dit/dit_models_xformers.py CHANGED
@@ -24,8 +24,12 @@ from pdb import set_trace as st
 from ldm.modules.attention import CrossAttention
 from vit.vision_transformer import MemEffAttention as Attention
 # import apex
-from apex.normalization import FusedRMSNorm as RMSNorm
-from apex.normalization import FusedLayerNorm as LayerNorm
+try:
+    from apex.normalization import FusedRMSNorm as RMSNorm
+    from apex.normalization import FusedLayerNorm as LayerNorm
+except ImportError:
+    from torch.nn import LayerNorm
+    from .norm import RMSNorm
 
 # from torch.nn import LayerNorm
 # from xformers import triton
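Since both implementations take `normalized_shape` as their first positional argument, call sites need no changes when the fallback is taken. A minimal usage sketch (assuming the repo root is on `sys.path`; not code from this commit):

```python
import torch

try:
    from apex.normalization import FusedRMSNorm as RMSNorm  # fused CUDA kernel
except ImportError:
    from dit.norm import RMSNorm  # pure-PyTorch fallback added in this commit

norm = RMSNorm(1024)          # same constructor call either way
x = torch.randn(2, 16, 1024)
print(norm(x).shape)          # torch.Size([2, 16, 1024])
```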
dit/norm.py ADDED
@@ -0,0 +1,18 @@
+import torch
+
+def rms_norm(x, weight=None, eps=1e-05):
+    output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
+    return output * weight if weight is not None else output
+
+class RMSNorm(torch.nn.Module):
+
+    def __init__(self, normalized_shape, eps=1e-05, weight=True, dtype=None, device=None):
+        super().__init__()
+        self.eps = eps
+        if weight:
+            self.weight = torch.nn.Parameter(torch.ones(normalized_shape, dtype=dtype, device=device))
+        else:
+            self.register_parameter('weight', None)
+
+    def forward(self, x):
+        return rms_norm(x.float(), self.weight, self.eps).to(dtype=x.dtype)
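A quick sanity check of this fallback (a sketch, assuming the repo root is on `sys.path`): with the default all-ones weight, each row of the output should have root-mean-square close to 1, which is what `x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)` guarantees.

```python
import torch
from dit.norm import RMSNorm  # module path as added in this commit

x = torch.randn(4, 1024)
norm = RMSNorm(1024)

y = norm(x)
rms = y.pow(2).mean(-1).sqrt()
print(torch.allclose(rms, torch.ones(4), atol=1e-3))  # True: rows have unit RMS
```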
ldm/modules/attention.py CHANGED
@@ -7,16 +7,17 @@ from einops import rearrange, repeat
 from pdb import set_trace as st
 
 from ldm.modules.diffusionmodules.util import checkpoint
-from apex.normalization import FusedLayerNorm as LayerNorm
+# from apex.normalization import FusedLayerNorm as LayerNorm
 
 
 # CrossAttn precision handling
 import os
 _ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32")
 from xformers.ops import MemoryEfficientAttentionFlashAttentionOp
-# from xformers.ops import RMSNorm, fmha, rope_padded
-# import apex
-from apex.normalization import FusedRMSNorm as RMSNorm
+try:
+    from apex.normalization import FusedRMSNorm as RMSNorm
+except ImportError:
+    from dit.norm import RMSNorm
 
 
 def exists(val):
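With both fallbacks in place, apex becomes a pure optimization rather than a hard dependency. A hedged equivalence check (a sketch, not from the repo; it only exercises the fused path on a machine where apex and a GPU are both available, and both norms default to an all-ones weight):

```python
import torch

try:
    from apex.normalization import FusedRMSNorm
    from dit.norm import RMSNorm

    assert torch.cuda.is_available(), "fused kernel needs a GPU"
    x = torch.randn(2, 8, 512, device="cuda")
    fused = FusedRMSNorm(512).cuda()
    fallback = RMSNorm(512).cuda()
    # Same eps default (1e-5) and weight init, so outputs should agree
    # up to kernel numerics.
    print(torch.allclose(fused(x), fallback(x), atol=1e-4))  # expected: True
except ImportError:
    print("apex not installed; only the pure-PyTorch RMSNorm path exists")
```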