revert
Browse files- modeling_bd3lm.py +1 -0
modeling_bd3lm.py
CHANGED
@@ -71,6 +71,7 @@ def block_diff_mask(b, h, q_idx, kv_idx, block_size=None, n=None):
|
|
71 |
# **4. Combine Masks **
|
72 |
return block_diagonal | offset_block_causal | block_causal
|
73 |
|
|
|
74 |
def fused_flex_attention(q, k, v, mask=None):
|
75 |
return flex_attention(q, k, v, block_mask=mask)
|
76 |
|
|
|
71 |
# **4. Combine Masks **
|
72 |
return block_diagonal | offset_block_causal | block_causal
|
73 |
|
74 |
+
@torch.compile(fullgraph=True, mode="max-autotune-no-cudagraphs")
|
75 |
def fused_flex_attention(q, k, v, mask=None):
|
76 |
return flex_attention(q, k, v, block_mask=mask)
|
77 |
|