marriola committed
Commit d1a5e7a · verified · 1 Parent(s): c0c9f3c

Upload BD3LM

Files changed (1)
  1. modeling_bd3lm.py +8 -8
modeling_bd3lm.py CHANGED
@@ -396,10 +396,10 @@ class DDiTBlock(nn.Module):
 
     if mask is None and self.attn_backend == 'flash_attn':
       x = regular_attention_multi_headed(qkv)
-    elif self.attn_backend == 'sdpa':
-      x = self.cross_attn(x, qkv, mask=mask)
     elif self.attn_backend == 'flex' and FLEX_ATTN_AVAILABLE:
       x = self.cross_attn_flex(qkv, mask=mask)
+    elif self.attn_backend == 'sdpa':
+      x = self.cross_attn(x, qkv, mask=mask)
     else:
       raise ValueError('Unknown attention backend')
 
@@ -485,7 +485,7 @@ class DITBackbone(nn.Module):
       config.vocab_size,
       config.cond_dim)
     if self.cross_attn:
-      self.gen_mask(config.model_length, self.block_size)
+      self.gen_mask(config.model_length, self.block_size, attn_backend=config.attn_backend)
     self.precision = torch.float32
 
   def _get_bias_dropout_scale(self):
@@ -496,14 +496,14 @@
 
   def gen_mask(self, seqlen, block_size, attn_backend='sdpa'):
     """Genererates attention mask"""
-    if attn_backend == 'sdpa':
-      self.mask = block_diff_mask(
-        b=None, h=None, q_idx=torch.arange(seqlen*2)[:, None], kv_idx=torch.arange(seqlen*2)[None, :],
-        block_size=block_size, n=seqlen)
-    elif attn_backend == 'flex':
+    if attn_backend == 'flex' and FLEX_ATTN_AVAILABLE:
       self.mask = create_block_mask(
         partial(block_diff_mask, block_size=block_size, n=seqlen),
         B=None, H=None, Q_LEN=seqlen*2, KV_LEN=seqlen*2)
+    elif attn_backend == 'sdpa':
+      self.mask = block_diff_mask(
+        b=None, h=None, q_idx=torch.arange(seqlen*2)[:, None],
+        kv_idx=torch.arange(seqlen*2)[None, :], block_size=block_size, n=seqlen)
     else:
       raise ValueError('Unknown attention backend')
 
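For context on what the commit changes: the attention-backend dispatch now tries the FlexAttention branch before SDPA, and DITBackbone forwards config.attn_backend into gen_mask instead of always building the default SDPA mask. Below is a minimal standalone sketch of that dispatch order; the causal predicate is a hypothetical stand-in for block_diff_mask, not the actual BD3LM mask.

  import torch

  try:
    from torch.nn.attention.flex_attention import create_block_mask
    FLEX_ATTN_AVAILABLE = True
  except ImportError:
    FLEX_ATTN_AVAILABLE = False

  def causal(b, h, q_idx, kv_idx):
    # Hypothetical stand-in for block_diff_mask: True where attention is allowed.
    return q_idx >= kv_idx

  def gen_mask(seqlen, attn_backend='sdpa'):
    # Mirrors the post-commit branch order: flex first (if available), then sdpa.
    if attn_backend == 'flex' and FLEX_ATTN_AVAILABLE:
      # FlexAttention consumes a BlockMask compiled from the predicate.
      return create_block_mask(causal, B=None, H=None,
                               Q_LEN=seqlen * 2, KV_LEN=seqlen * 2)
    elif attn_backend == 'sdpa':
      # SDPA takes a dense boolean mask built by broadcasting index grids.
      return causal(None, None,
                    torch.arange(seqlen * 2)[:, None],
                    torch.arange(seqlen * 2)[None, :])
    else:
      raise ValueError('Unknown attention backend')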