monsoon-nlp commited on
Commit
4040262
·
verified ·
1 Parent(s): 532cc30
Files changed (1) hide show
  1. modeling_bd3lm.py +1 -0
modeling_bd3lm.py CHANGED
@@ -71,6 +71,7 @@ def block_diff_mask(b, h, q_idx, kv_idx, block_size=None, n=None):
71
  # **4. Combine Masks **
72
  return block_diagonal | offset_block_causal | block_causal
73
 
 
74
  def fused_flex_attention(q, k, v, mask=None):
75
  return flex_attention(q, k, v, block_mask=mask)
76
 
 
71
  # **4. Combine Masks **
72
  return block_diagonal | offset_block_causal | block_causal
73
 
74
+ @torch.compile(fullgraph=True, mode="max-autotune-no-cudagraphs")
75
  def fused_flex_attention(q, k, v, mask=None):
76
  return flex_attention(q, k, v, block_mask=mask)
77