Spaces: Running on Zero
update
ldm/modules/attention.py +39 -38
CHANGED
@@ -26,40 +26,40 @@ from xformers.ops import MemoryEfficientAttentionFlashAttentionOp, MemoryEfficie
 # import apex
 # from apex.normalization import FusedRMSNorm as RMSNorm
 
-if version.parse(torch.__version__) >= version.parse("2.0.0"):
-    SDP_IS_AVAILABLE = True
-    # from torch.backends.cuda import SDPBackend, sdp_kernel
-    from torch.nn.attention import sdpa_kernel, SDPBackend
-
-    BACKEND_MAP = {
-        SDPBackend.MATH: {
-            "enable_math": True,
-            "enable_flash": False,
-            "enable_mem_efficient": False,
-        },
-        SDPBackend.FLASH_ATTENTION: {
-            "enable_math": False,
-            "enable_flash": True,
-            "enable_mem_efficient": False,
-        },
-        SDPBackend.EFFICIENT_ATTENTION: {
-            "enable_math": False,
-            "enable_flash": False,
-            "enable_mem_efficient": True,
-        },
-        None: {"enable_math": True, "enable_flash": True, "enable_mem_efficient": True},
-    }
-else:
-    from contextlib import nullcontext
-
-    SDP_IS_AVAILABLE = False
-    sdpa_kernel = nullcontext
-    BACKEND_MAP = {}
-    logpy.warn(
-        f"No SDP backend available, likely because you are running in pytorch "
-        f"versions < 2.0. In fact, you are using PyTorch {torch.__version__}. "
-        f"You might want to consider upgrading."
-    )
+# if version.parse(torch.__version__) >= version.parse("2.0.0"):
+#     SDP_IS_AVAILABLE = True
+#     # from torch.backends.cuda import SDPBackend, sdp_kernel
+#     from torch.nn.attention import sdpa_kernel, SDPBackend
+
+#     BACKEND_MAP = {
+#         SDPBackend.MATH: {
+#             "enable_math": True,
+#             "enable_flash": False,
+#             "enable_mem_efficient": False,
+#         },
+#         SDPBackend.FLASH_ATTENTION: {
+#             "enable_math": False,
+#             "enable_flash": True,
+#             "enable_mem_efficient": False,
+#         },
+#         SDPBackend.EFFICIENT_ATTENTION: {
+#             "enable_math": False,
+#             "enable_flash": False,
+#             "enable_mem_efficient": True,
+#         },
+#         None: {"enable_math": True, "enable_flash": True, "enable_mem_efficient": True},
+#     }
+# else:
+#     from contextlib import nullcontext
+
+SDP_IS_AVAILABLE = False
+# sdpa_kernel = nullcontext
+# BACKEND_MAP = {}
+# logpy.warn(
+#     f"No SDP backend available, likely because you are running in pytorch "
+#     f"versions < 2.0. In fact, you are using PyTorch {torch.__version__}. "
+#     f"You might want to consider upgrading."
+# )
 
 
 def exists(val):
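This hunk comments out the PyTorch version gate and its backend-selection table wholesale and hard-codes SDP_IS_AVAILABLE = False, so the module no longer imports torch.nn.attention. The commented lines also record an API migration: the deprecated torch.backends.cuda.sdp_kernel context manager takes one keyword flag per backend, while its replacement torch.nn.attention.sdpa_kernel takes a list of SDPBackend members. A minimal sketch of the two call styles, not part of the commit; it assumes PyTorch >= 2.3 for torch.nn.attention, and the tensor shapes are illustrative:

import torch
import torch.nn.functional as F

q = k = v = torch.randn(1, 8, 16, 64)  # (batch, heads, tokens, dim_head); illustrative

# Old signature (deprecated): one keyword flag per backend.
from torch.backends.cuda import sdp_kernel
with sdp_kernel(enable_math=True, enable_flash=False, enable_mem_efficient=False):
    out_old = F.scaled_dot_product_attention(q, k, v)

# New signature: a list of allowed SDPBackend members.
from torch.nn.attention import sdpa_kernel, SDPBackend
with sdpa_kernel([SDPBackend.MATH]):
    out_new = F.scaled_dot_product_attention(q, k, v)

Either context manager only constrains which kernel F.scaled_dot_product_attention may dispatch to; the attention call itself is unchanged.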
@@ -282,11 +282,12 @@ class CrossAttention(nn.Module):
         """
         ## new
         # with sdpa_kernel(**BACKEND_MAP[self.backend]):
-        with sdpa_kernel([self.backend]): # new signature
+        # with sdpa_kernel([self.backend]): # new signature
             # print("dispatching into backend", self.backend, "q/k/v shape: ", q.shape, k.shape, v.shape)
-            out = F.scaled_dot_product_attention(
-                q, k, v, attn_mask=mask
-            ) # scale is dim_head ** -0.5 per default
+
+        out = F.scaled_dot_product_attention(
+            q, k, v, attn_mask=mask
+        ) # scale is dim_head ** -0.5 per default
 
         del q, k, v
         out = rearrange(out, "b h n d -> b n (h d)", h=h)
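With the backend machinery gone, this hunk also comments out the sdpa_kernel context manager around the attention call, so F.scaled_dot_product_attention now runs under PyTorch's default backend dispatch and with its default scale of dim_head ** -0.5. A minimal sketch of the surrounding pattern, assuming the usual head split/merge in this file's CrossAttention (attention_core is a hypothetical stand-in for the forward body, not a function in the repo):

import torch
import torch.nn.functional as F
from einops import rearrange

def attention_core(q, k, v, h, mask=None):
    # (b, n, h*d) -> (b, h, n, d): heads become a batch dimension for SDPA
    q, k, v = (rearrange(t, "b n (h d) -> b h n d", h=h) for t in (q, k, v))
    out = F.scaled_dot_product_attention(
        q, k, v, attn_mask=mask
    )  # scale defaults to dim_head ** -0.5
    del q, k, v
    return rearrange(out, "b h n d -> b n (h d)", h=h)

# Usage: batch of 2, 16 tokens, 8 heads of dim 64.
q = k = v = torch.randn(2, 16, 8 * 64)
out = attention_core(q, k, v, h=8)  # -> (2, 16, 512)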