yslan committed on
Commit
2843133
·
1 Parent(s): 7c949c9
Files changed (1) hide show
  1. ldm/modules/attention.py +39 -38
ldm/modules/attention.py CHANGED
@@ -26,40 +26,40 @@ from xformers.ops import MemoryEfficientAttentionFlashAttentionOp, MemoryEfficie
26
  # import apex
27
  # from apex.normalization import FusedRMSNorm as RMSNorm
28
 
29
- if version.parse(torch.__version__) >= version.parse("2.0.0"):
30
- SDP_IS_AVAILABLE = True
31
- # from torch.backends.cuda import SDPBackend, sdp_kernel
32
- from torch.nn.attention import sdpa_kernel, SDPBackend
33
-
34
- BACKEND_MAP = {
35
- SDPBackend.MATH: {
36
- "enable_math": True,
37
- "enable_flash": False,
38
- "enable_mem_efficient": False,
39
- },
40
- SDPBackend.FLASH_ATTENTION: {
41
- "enable_math": False,
42
- "enable_flash": True,
43
- "enable_mem_efficient": False,
44
- },
45
- SDPBackend.EFFICIENT_ATTENTION: {
46
- "enable_math": False,
47
- "enable_flash": False,
48
- "enable_mem_efficient": True,
49
- },
50
- None: {"enable_math": True, "enable_flash": True, "enable_mem_efficient": True},
51
- }
52
- else:
53
- from contextlib import nullcontext
54
-
55
- SDP_IS_AVAILABLE = False
56
- sdpa_kernel = nullcontext
57
- BACKEND_MAP = {}
58
- logpy.warn(
59
- f"No SDP backend available, likely because you are running in pytorch "
60
- f"versions < 2.0. In fact, you are using PyTorch {torch.__version__}. "
61
- f"You might want to consider upgrading."
62
- )
63
 
64
 
65
  def exists(val):
@@ -282,11 +282,12 @@ class CrossAttention(nn.Module):
282
  """
283
  ## new
284
  # with sdpa_kernel(**BACKEND_MAP[self.backend]):
285
- with sdpa_kernel([self.backend]): # new signature
286
  # print("dispatching into backend", self.backend, "q/k/v shape: ", q.shape, k.shape, v.shape)
287
- out = F.scaled_dot_product_attention(
288
- q, k, v, attn_mask=mask
289
- ) # scale is dim_head ** -0.5 per default
 
290
 
291
  del q, k, v
292
  out = rearrange(out, "b h n d -> b n (h d)", h=h)
 
26
  # import apex
27
  # from apex.normalization import FusedRMSNorm as RMSNorm
28
 
29
+ # if version.parse(torch.__version__) >= version.parse("2.0.0"):
30
+ # SDP_IS_AVAILABLE = True
31
+ # # from torch.backends.cuda import SDPBackend, sdp_kernel
32
+ # from torch.nn.attention import sdpa_kernel, SDPBackend
33
+
34
+ # BACKEND_MAP = {
35
+ # SDPBackend.MATH: {
36
+ # "enable_math": True,
37
+ # "enable_flash": False,
38
+ # "enable_mem_efficient": False,
39
+ # },
40
+ # SDPBackend.FLASH_ATTENTION: {
41
+ # "enable_math": False,
42
+ # "enable_flash": True,
43
+ # "enable_mem_efficient": False,
44
+ # },
45
+ # SDPBackend.EFFICIENT_ATTENTION: {
46
+ # "enable_math": False,
47
+ # "enable_flash": False,
48
+ # "enable_mem_efficient": True,
49
+ # },
50
+ # None: {"enable_math": True, "enable_flash": True, "enable_mem_efficient": True},
51
+ # }
52
+ # else:
53
+ # from contextlib import nullcontext
54
+
55
+ SDP_IS_AVAILABLE = False
56
+ # sdpa_kernel = nullcontext
57
+ # BACKEND_MAP = {}
58
+ # logpy.warn(
59
+ # f"No SDP backend available, likely because you are running in pytorch "
60
+ # f"versions < 2.0. In fact, you are using PyTorch {torch.__version__}. "
61
+ # f"You might want to consider upgrading."
62
+ # )
63
 
64
 
65
  def exists(val):
 
282
  """
283
  ## new
284
  # with sdpa_kernel(**BACKEND_MAP[self.backend]):
285
+ # with sdpa_kernel([self.backend]): # new signature
286
  # print("dispatching into backend", self.backend, "q/k/v shape: ", q.shape, k.shape, v.shape)
287
+
288
+ out = F.scaled_dot_product_attention(
289
+ q, k, v, attn_mask=mask
290
+ ) # scale is dim_head ** -0.5 per default
291
 
292
  del q, k, v
293
  out = rearrange(out, "b h n d -> b n (h d)", h=h)