marriola committed
Commit 5cc765d · verified · 1 Parent(s): 4e01d79

Upload BD3LM

Files changed (1)
  1. modeling_bd3lm.py +11 -15
modeling_bd3lm.py CHANGED
@@ -16,17 +16,15 @@ try:
   FLEX_ATTN_AVAILABLE = True
 except:
   FLEX_ATTN_AVAILABLE = False
-# Flags required to enable jit fusion kernels
-try:
-  torch._C._jit_set_profiling_mode(False)
-  torch._C._jit_set_profiling_executor(False)
-  torch._C._jit_override_can_fuse_on_cpu(True)
-  torch._C._jit_override_fcan_fuse_on_gpu(True)
-except:
-  pass
 
 from .configuration_bd3lm import BD3LMConfig
 
+# Flags required to enable jit fusion kernels
+torch._C._jit_set_profiling_mode(False)
+torch._C._jit_set_profiling_executor(False)
+torch._C._jit_override_can_fuse_on_cpu(True)
+torch._C._jit_override_can_fuse_on_gpu(True)
+
 def block_diff_mask(b, h, q_idx, kv_idx, block_size=None, n=None):
   """
   Constructs the specialized block diffusion attention mask for training
@@ -77,7 +75,6 @@ def block_diff_mask(b, h, q_idx, kv_idx, block_size=None, n=None):
 def fused_flex_attention(q, k, v, mask=None):
   return flex_attention(q, k, v, block_mask=mask)
 
-
 def bias_dropout_add_scale(
     x: torch.Tensor,
     bias: typing.Optional[torch.Tensor],
@@ -102,6 +99,7 @@ def get_bias_dropout_add_scale(training):
 
   return _bias_dropout_add
 
+
 # function overload
 def modulate(x: torch.Tensor,
              shift: torch.Tensor,
@@ -299,7 +297,7 @@ def regular_attention_multi_headed(qkv):
 
 class DDiTBlock(nn.Module):
   def __init__(self, n, block_size, dim, n_heads, cond_dim, mlp_ratio=4,
-               dropout=0.1, max_seqlen=1024, attn_backend='flash_attn'):
+               dropout=0.1, attn_backend='sdpa'):
     super().__init__()
     self.n = n
     self.block_size = block_size
@@ -394,11 +392,9 @@ class DDiTBlock(nn.Module):
     else:
      qkv = self.get_qkv(x, rotary_cos_sin, store_kv=store_kv)
 
-    if mask is None and self.attn_backend == 'flash_attn':
-      x = regular_attention_multi_headed(qkv)
-    elif self.attn_backend == 'flex' and FLEX_ATTN_AVAILABLE:
+    if self.attn_backend == 'flex' and FLEX_ATTN_AVAILABLE:
      x = self.cross_attn_flex(qkv, mask=mask)
-    elif self.attn_backend == 'sdpa':
+    elif self.attn_backend == 'sdpa' or not FLEX_ATTN_AVAILABLE:
      x = self.cross_attn(x, qkv, mask=mask)
     else:
      raise ValueError('Unknown attention backend')
@@ -500,7 +496,7 @@ class DITBackbone(nn.Module):
      self.mask = create_block_mask(
        partial(block_diff_mask, block_size=block_size, n=seqlen),
        B=None, H=None, Q_LEN=seqlen*2, KV_LEN=seqlen*2)
-    elif attn_backend == 'sdpa':
+    elif attn_backend == 'sdpa' or not FLEX_ATTN_AVAILABLE:
      self.mask = block_diff_mask(
        b=None, h=None, q_idx=torch.arange(seqlen*2)[:, None],
        kv_idx=torch.arange(seqlen*2)[None, :], block_size=block_size, n=seqlen)
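
The forward-pass change in DDiTBlock drops the 'flash_attn' branch and makes SDPA the fallback whenever flex attention cannot be imported. Below is a minimal, standalone sketch of that selection pattern, not the model's own code; the helper name attend and the tensor shapes are illustrative assumptions.

import torch
import torch.nn.functional as F

# Mirror the module-level availability probe used in modeling_bd3lm.
try:
    from torch.nn.attention.flex_attention import flex_attention
    FLEX_ATTN_AVAILABLE = True
except ImportError:
    FLEX_ATTN_AVAILABLE = False

def attend(q, k, v, mask=None, backend='sdpa'):
    # The mask must match the backend: a BlockMask for 'flex',
    # a boolean/additive tensor (or None) for 'sdpa'.
    if backend == 'flex' and FLEX_ATTN_AVAILABLE:
        return flex_attention(q, k, v, block_mask=mask)
    elif backend == 'sdpa' or not FLEX_ATTN_AVAILABLE:
        return F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
    raise ValueError('Unknown attention backend')

# Usage: (batch, heads, seq, head_dim) tensors, no mask.
q = k = v = torch.randn(1, 2, 16, 32)
out = attend(q, k, v)  # takes the SDPA path

With the 'flash_attn' branch gone, requesting attn_backend='flash_attn' falls through to the ValueError when flex attention is importable, and silently uses SDPA when it is not.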
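
In DITBackbone, the SDPA branch materializes the same block_diff_mask predicate as a dense tensor by broadcasting index grids rather than wrapping it in create_block_mask. The sketch below illustrates that construction under stated assumptions: a flat local import of modeling_bd3lm (inside the repo the import is relative), illustrative sizes, and the expectation that the predicate broadcasts to a boolean mask with True marking allowed positions; whether cross_attn consumes it exactly this way is not shown in this diff.

import torch
import torch.nn.functional as F
from modeling_bd3lm import block_diff_mask  # assumed flat import for illustration

seqlen, block_size, n_heads, head_dim = 8, 4, 2, 16  # illustrative sizes

# Broadcasting (2*seqlen, 1) query indices against (1, 2*seqlen) key/value
# indices should yield a dense (2*seqlen, 2*seqlen) boolean mask.
mask = block_diff_mask(
    b=None, h=None,
    q_idx=torch.arange(seqlen * 2)[:, None],
    kv_idx=torch.arange(seqlen * 2)[None, :],
    block_size=block_size, n=seqlen)

# A boolean (L, S) mask broadcasts over batch and heads in SDPA.
q = k = v = torch.randn(1, n_heads, seqlen * 2, head_dim)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
print(out.shape)  # torch.Size([1, 2, 16, 16])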