Upload BD3LM
modeling_bd3lm.py  (CHANGED, +11 -15)
@@ -16,17 +16,15 @@ try:
   FLEX_ATTN_AVAILABLE = True
 except:
   FLEX_ATTN_AVAILABLE = False
-# Flags required to enable jit fusion kernels
-try:
-  torch._C._jit_set_profiling_mode(False)
-  torch._C._jit_set_profiling_executor(False)
-  torch._C._jit_override_can_fuse_on_cpu(True)
-  torch._C._jit_override_can_fuse_on_gpu(True)
-except:
-  pass
 
 from .configuration_bd3lm import BD3LMConfig
 
+# Flags required to enable jit fusion kernels
+torch._C._jit_set_profiling_mode(False)
+torch._C._jit_set_profiling_executor(False)
+torch._C._jit_override_can_fuse_on_cpu(True)
+torch._C._jit_override_can_fuse_on_gpu(True)
+
 def block_diff_mask(b, h, q_idx, kv_idx, block_size=None, n=None):
   """
   Constructs the specialized block diffusion attention mask for training
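This hunk drops the try/except wrapper around PyTorch's internal JIT-fusion switches and runs them unconditionally at import time, after the config import. The FLEX_ATTN_AVAILABLE flag in the surrounding context is set by a guarded import just above this hunk; the import line itself is not shown, so the sketch below is an assumption based on the standard torch.nn.attention.flex_attention module, which provides the flex_attention and create_block_mask used later in this diff.

```python
# Sketch only: the real import path lies outside the hunk, so this is an assumption.
try:
  from torch.nn.attention.flex_attention import create_block_mask, flex_attention
  FLEX_ATTN_AVAILABLE = True
except ImportError:
  FLEX_ATTN_AVAILABLE = False
```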
@@ -77,7 +75,6 @@ def block_diff_mask(b, h, q_idx, kv_idx, block_size=None, n=None):
 def fused_flex_attention(q, k, v, mask=None):
   return flex_attention(q, k, v, block_mask=mask)
 
-
 def bias_dropout_add_scale(
     x: torch.Tensor,
     bias: typing.Optional[torch.Tensor],
@@ -102,6 +99,7 @@ def get_bias_dropout_add_scale(training):
 
   return _bias_dropout_add
 
+
 # function overload
 def modulate(x: torch.Tensor,
              shift: torch.Tensor,
@@ -299,7 +297,7 @@ def regular_attention_multi_headed(qkv):
 
 class DDiTBlock(nn.Module):
   def __init__(self, n, block_size, dim, n_heads, cond_dim, mlp_ratio=4,
-               dropout=0.1,
+               dropout=0.1, attn_backend='sdpa'):
     super().__init__()
     self.n = n
     self.block_size = block_size
@@ -394,11 +392,9 @@ class DDiTBlock(nn.Module):
     else:
       qkv = self.get_qkv(x, rotary_cos_sin, store_kv=store_kv)
 
-    if
-      x = regular_attention_multi_headed(qkv)
-    elif self.attn_backend == 'flex' and FLEX_ATTN_AVAILABLE:
+    if self.attn_backend == 'flex' and FLEX_ATTN_AVAILABLE:
       x = self.cross_attn_flex(qkv, mask=mask)
-    elif self.attn_backend == 'sdpa':
+    elif self.attn_backend == 'sdpa' or not FLEX_ATTN_AVAILABLE:
       x = self.cross_attn(x, qkv, mask=mask)
     else:
       raise ValueError('Unknown attention backend')
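This hunk removes the branch that routed to regular_attention_multi_headed and makes SDPA the fallback whenever flex attention is unavailable, rather than falling through to the ValueError. A hypothetical standalone mirror of the new branch logic (the function and argument names below are illustrative, not part of the model code):

```python
# Illustrative mirror of the new dispatch semantics, not the model's own code.
def pick_backend(requested: str, flex_available: bool) -> str:
    if requested == 'flex' and flex_available:
        return 'flex'   # fused flex attention path
    elif requested == 'sdpa' or not flex_available:
        return 'sdpa'   # explicit SDPA, or silent fallback when flex is missing
    raise ValueError('Unknown attention backend')

assert pick_backend('flex', flex_available=False) == 'sdpa'  # new fallback behaviour
assert pick_backend('flex', flex_available=True) == 'flex'
```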
@@ -500,7 +496,7 @@ class DITBackbone(nn.Module):
       self.mask = create_block_mask(
         partial(block_diff_mask, block_size=block_size, n=seqlen),
         B=None, H=None, Q_LEN=seqlen*2, KV_LEN=seqlen*2)
-    elif attn_backend == 'sdpa':
+    elif attn_backend == 'sdpa' or not FLEX_ATTN_AVAILABLE:
       self.mask = block_diff_mask(
         b=None, h=None, q_idx=torch.arange(seqlen*2)[:, None],
         kv_idx=torch.arange(seqlen*2)[None, :], block_size=block_size, n=seqlen)
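In the backbone, both branches build the mask over a doubled sequence length (seqlen*2): the flex path wraps block_diff_mask in a BlockMask via create_block_mask, while the SDPA/fallback path evaluates the same predicate densely over broadcast index grids. A minimal sketch of that densification pattern, with a toy causal predicate standing in for block_diff_mask:

```python
# Sketch of the SDPA-path pattern: evaluate a mask predicate over broadcast index
# grids to get a dense boolean mask. `toy_mask` is a stand-in for block_diff_mask;
# only the calling convention is taken from the diff.
import torch

def toy_mask(b, h, q_idx, kv_idx, block_size=None, n=None):
    return q_idx >= kv_idx  # simple causal predicate for illustration

seqlen, block_size = 8, 4
dense = toy_mask(
    b=None, h=None,
    q_idx=torch.arange(seqlen * 2)[:, None],
    kv_idx=torch.arange(seqlen * 2)[None, :],
    block_size=block_size, n=seqlen,
)
print(dense.shape, dense.dtype)  # torch.Size([16, 16]) torch.bool
```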