Upload BD3LM
modeling_bd3lm.py  CHANGED  +8 -8
@@ -396,10 +396,10 @@ class DDiTBlock(nn.Module):
 
     if mask is None and self.attn_backend == 'flash_attn':
       x = regular_attention_multi_headed(qkv)
-    elif self.attn_backend == 'sdpa':
-      x = self.cross_attn(x, qkv, mask=mask)
     elif self.attn_backend == 'flex' and FLEX_ATTN_AVAILABLE:
       x = self.cross_attn_flex(qkv, mask=mask)
+    elif self.attn_backend == 'sdpa':
+      x = self.cross_attn(x, qkv, mask=mask)
     else:
       raise ValueError('Unknown attention backend')
 
@@ -485,7 +485,7 @@ class DITBackbone(nn.Module):
       config.vocab_size,
       config.cond_dim)
     if self.cross_attn:
-      self.gen_mask(config.model_length, self.block_size)
+      self.gen_mask(config.model_length, self.block_size, attn_backend=config.attn_backend)
     self.precision = torch.float32
 
   def _get_bias_dropout_scale(self):
@@ -496,14 +496,14 @@ class DITBackbone(nn.Module):
 
   def gen_mask(self, seqlen, block_size, attn_backend='sdpa'):
     """Genererates attention mask"""
-    if attn_backend == 'sdpa':
-      self.mask = block_diff_mask(
-        b=None, h=None, q_idx=torch.arange(seqlen*2)[:, None], kv_idx=torch.arange(seqlen*2)[None, :],
-        block_size=block_size, n=seqlen)
-    elif attn_backend == 'flex':
+    if attn_backend == 'flex' and FLEX_ATTN_AVAILABLE:
       self.mask = create_block_mask(
         partial(block_diff_mask, block_size=block_size, n=seqlen),
         B=None, H=None, Q_LEN=seqlen*2, KV_LEN=seqlen*2)
+    elif attn_backend == 'sdpa':
+      self.mask = block_diff_mask(
+        b=None, h=None, q_idx=torch.arange(seqlen*2)[:, None],
+        kv_idx=torch.arange(seqlen*2)[None, :], block_size=block_size, n=seqlen)
     else:
       raise ValueError('Unknown attention backend')
 
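A note on the sdpa branch: gen_mask evaluates block_diff_mask densely over index grids and stores a boolean mask for the sdpa backend, which self.cross_attn presumably applies via torch.nn.functional.scaled_dot_product_attention. The sketch below is a minimal illustration of that plumbing only; toy_block_mask is a hypothetical block-causal stand-in with the same call shape, not the repo's block_diff_mask, and the sizes are arbitrary.

import torch
import torch.nn.functional as F

def toy_block_mask(b, h, q_idx, kv_idx, block_size, n):
  # Hypothetical stand-in for block_diff_mask: a plain block-causal
  # predicate with the same argument shape, NOT the repo's definition.
  return (q_idx // block_size) >= (kv_idx // block_size)

seqlen, block_size, heads, dim = 8, 4, 2, 16
idx = torch.arange(seqlen * 2)
# Dense boolean mask; True marks query/key pairs that may attend.
mask = toy_block_mask(b=None, h=None,
                      q_idx=idx[:, None], kv_idx=idx[None, :],
                      block_size=block_size, n=seqlen)

q = torch.randn(1, heads, seqlen * 2, dim)
k = torch.randn(1, heads, seqlen * 2, dim)
v = torch.randn(1, heads, seqlen * 2, dim)
# SDPA accepts the boolean mask directly through attn_mask.
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
print(out.shape)  # torch.Size([1, 2, 16, 16])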
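On the flex branch: create_block_mask from torch.nn.attention.flex_attention compiles the same kind of index predicate into a sparse BlockMask object, which flex_attention then consumes instead of a dense tensor. A rough sketch under stated assumptions (PyTorch 2.5+ with FlexAttention available, a CUDA device, and the same hypothetical toy_block_mask stand-in as above, not the repo's block_diff_mask):

from functools import partial
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

def toy_block_mask(b, h, q_idx, kv_idx, block_size, n):
  # Same hypothetical stand-in predicate as above, not the repo's block_diff_mask.
  return (q_idx // block_size) >= (kv_idx // block_size)

if torch.cuda.is_available():  # FlexAttention kernels target CUDA
  seqlen, block_size = 128, 16
  block_mask = create_block_mask(
      partial(toy_block_mask, block_size=block_size, n=seqlen),
      B=None, H=None, Q_LEN=seqlen * 2, KV_LEN=seqlen * 2, device='cuda')
  q = torch.randn(1, 2, seqlen * 2, 64, device='cuda')
  k = torch.randn(1, 2, seqlen * 2, 64, device='cuda')
  v = torch.randn(1, 2, seqlen * 2, 64, device='cuda')
  out = flex_attention(q, k, v, block_mask=block_mask)  # shape [1, 2, 256, 64]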