Update flash_attention.py
flash_attention.py (+3 -3)
@@ -66,7 +66,7 @@ class FlashAttention(nn.Module):
                 max_s,
                 self.dropout_p if self.training else 0.0,
                 self.softmax_scale,
-                causal
+                causal
             )
             output = rearrange(output, "(b s) ... -> b s ...", b=batch_size)
         else:
@@ -82,7 +82,7 @@ class FlashAttention(nn.Module):
                 max_s,
                 self.dropout_p if self.training else 0.0,
                 self.softmax_scale,
-                causal
+                causal
             )
             output = rearrange(
                 pad_input(
@@ -102,7 +102,7 @@ class FlashAttention(nn.Module):
                 max_s,
                 self.dropout_p if self.training else 0.0,
                 self.softmax_scale,
-                causal
+                causal
             )

         return output, None
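For context: the argument threaded through all three call sites above is the `causal` flag, passed positionally after `softmax_scale` into the attention kernel call. Below is a minimal pure-PyTorch sketch of the math that flag controls; it is not the fused flash-attn kernel, and `attention_ref` is a hypothetical helper name that merely mirrors the packed-qkv layout and the argument order (`dropout_p`, `softmax_scale`, `causal`) visible in the diff.

import math
import torch

def attention_ref(qkv, causal=False, softmax_scale=None, dropout_p=0.0):
    # Reference (unfused) attention over packed qkv of shape
    # (batch, seqlen, 3, nheads, headdim) -- a sketch, not flash-attn's API.
    q, k, v = qkv.unbind(dim=2)
    scale = softmax_scale if softmax_scale is not None else 1.0 / math.sqrt(q.shape[-1])
    # scores: (batch, nheads, seqlen_q, seqlen_k)
    scores = torch.einsum("bthd,bshd->bhts", q, k) * scale
    if causal:
        # causal=True: query position i may only attend to key positions j <= i,
        # so mask out the strictly upper-triangular (future) entries.
        mask = torch.triu(
            torch.ones(scores.shape[-2:], dtype=torch.bool, device=scores.device),
            diagonal=1,
        )
        scores = scores.masked_fill(mask, float("-inf"))
    attn = torch.softmax(scores, dim=-1)
    attn = torch.dropout(attn, dropout_p, train=dropout_p > 0.0)
    return torch.einsum("bhts,bshd->bthd", attn, v)

qkv = torch.randn(2, 16, 3, 4, 32)
out_causal = attention_ref(qkv, causal=True)   # autoregressive attention
out_full = attention_ref(qkv, causal=False)    # bidirectional attention

The flag thus flips the same call between bidirectional and autoregressive attention, which is why it has to reach every one of the three branches (padded, unpadded, and the pad_input path) that the diff touches.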