marriola committed (verified)
Commit 66f2d23 · 1 Parent(s): 6f2022f

Upload BD3LM

Files changed (3):
  1. config.json +1 -1
  2. configuration_bd3lm.py +1 -1
  3. modeling_bd3lm.py +97 -53
config.json CHANGED
@@ -3,7 +3,7 @@
   "architectures": [
     "BD3LM"
   ],
-  "attn_backend": "sdpa",
+  "attn_backend": "flex",
   "auto_map": {
     "AutoConfig": "configuration_bd3lm.BD3LMConfig",
     "AutoModelForMaskedLM": "modeling_bd3lm.BD3LM"
configuration_bd3lm.py CHANGED
@@ -15,7 +15,7 @@ class BD3LMConfig(transformers.PretrainedConfig):
     vocab_size: int = 50258,
     model_length: int = 1024,
     cross_attn: bool = True,
-    attn_backend: str = 'sdpa',
+    attn_backend: str = 'flex',
     hidden_dim: int = 768,
     cond_dim: int = 129,
     n_blocks: int = 12,
modeling_bd3lm.py CHANGED
@@ -5,13 +5,17 @@ import math
 import typing
 
 import einops
-import flash_attn
-import flash_attn.layers.rotary
+from functools import partial
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import transformers
 from transformers import modeling_outputs
+try:
+  from torch.nn.attention.flex_attention import flex_attention, create_block_mask
+  FLEX_ATTN_AVAILABLE = True
+except:
+  FLEX_ATTN_AVAILABLE = False
 
 from .configuration_bd3lm import BD3LMConfig
 
@@ -21,21 +25,55 @@ torch._C._jit_set_profiling_executor(False)
 torch._C._jit_override_can_fuse_on_cpu(True)
 torch._C._jit_override_can_fuse_on_gpu(True)
 
-def block_causal_mask(num_rows, block_size, mode='full', offset=0):
-  mask = block_size * torch.arange(
-    1, num_rows // block_size + 1).unsqueeze(1).tile(block_size).flatten().unsqueeze(1)
-  if mode == 'full':
-    mask = (mask >= mask.T + offset)
-  elif mode == 'diag':
-    mask = (mask + offset == mask.T)
-  elif mode == 'triu_diag':
-    mask = torch.zeros(num_rows, num_rows)
-    rows = torch.arange(0, num_rows)
-    group_indices = rows // (block_size)
-    column_indices = group_indices * (block_size) + block_size + offset
-    valid_rows = column_indices < num_rows
-    mask[rows[valid_rows].unsqueeze(1), column_indices[valid_rows].unsqueeze(1)] = 1
-  return mask.int()
+def block_diff_mask(b, h, q_idx, kv_idx, block_size=None, n=None):
+  """
+  Constructs the specialized block diffusion attention mask for training
+  composed of three masks:
+  - **Block Diagonal Mask (M_BD)**: Self-attention within noised blocks
+  - **Offset Block Causal Mask (M_OBC)**: Cross-attention for conditional context
+  - **Block Causal Mask (M_BC)**: Attention to update x0
+
+  Args:
+    b, h: Batch and head indices (ignored for mask logic).
+    q_idx, kv_idx: Query and Key indices.
+    n: Total sequence length.
+    block_size: Defines the block structure.
+
+  Returns:
+    A boolean attention mask.
+  """
+
+  # Indicate whether token belongs to xt or x0
+  x0_flag_q = (q_idx >= n)
+  x0_flag_kv = (kv_idx >= n)
+
+  # Compute block indices
+  block_q = torch.where(x0_flag_q == 1,
+                        (q_idx - n) // block_size,
+                        q_idx // block_size)
+  block_kv = torch.where(x0_flag_kv == 1,
+                         (kv_idx - n) // block_size,
+                         kv_idx // block_size)
+
+  # **1. Block Diagonal Mask (M_BD) **
+  block_diagonal = (block_q == block_kv) & (x0_flag_q == x0_flag_kv)
+
+  # **2. Offset Block-Causal Mask (M_OBC) **
+  offset_block_causal = (
+    (block_q > block_kv)
+    & (x0_flag_kv == 1)
+    & (x0_flag_q == 0)
+  )
+
+  # **3. Block-Causal Mask (M_BC) **
+  block_causal = (block_q >= block_kv) & (x0_flag_kv == 1) & (x0_flag_q == 1)
+
+  # **4. Combine Masks **
+  return block_diagonal | offset_block_causal | block_causal
+
+@torch.compile(fullgraph=True, mode="max-autotune-no-cudagraphs")
+def fused_flex_attention(q, k, v, mask=None):
+  return flex_attention(q, k, v, block_mask=mask)
 
 def bias_dropout_add_scale(
   x: torch.Tensor,
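
For intuition, the mask predicate can be evaluated densely on broadcast index grids, which is exactly what the SDPA branch of gen_mask further down does. A small sketch; the sizes n=4 and block_size=2 are chosen here only for illustration:

# Toy illustration; n and block_size are assumed values, and block_diff_mask is the
# function added above (importable from modeling_bd3lm).
import torch

n, block_size = 4, 2
q_idx = torch.arange(2 * n)[:, None]   # positions 0..n-1 are xt, n..2n-1 are x0
kv_idx = torch.arange(2 * n)[None, :]
dense = block_diff_mask(b=None, h=None, q_idx=q_idx, kv_idx=kv_idx,
                        block_size=block_size, n=n)
print(dense.int())  # (2n, 2n) pattern combining M_BD, M_OBC and M_BC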
@@ -132,12 +170,6 @@ def rotate_half(x):
 def apply_rotary_pos_emb_torchscript(qkv, cos, sin):
   return (qkv * cos) + (rotate_half(qkv) * sin)
 
-def apply_rotary_pos_emb(qkv, cos, sin):
-  cos = cos[0,:,0,0,:cos.shape[-1]//2]
-  sin = sin[0,:,0,0,:sin.shape[-1]//2]
-  return flash_attn.layers.rotary.apply_rotary_emb_qkv_(qkv, cos, sin)
-
-
 # function overload
 def modulate(x, shift, scale):
   return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
@@ -317,32 +349,33 @@ class DDiTBlock(nn.Module):
       h=self.n_heads)
     with torch.cuda.amp.autocast(enabled=False):
       cos, sin = rotary_cos_sin
-      if self.attn_backend == 'flash_attn':
-        qkv = apply_rotary_pos_emb(
-          qkv, cos.to(qkv.dtype), sin.to(qkv.dtype))
-      else:
-        qkv = apply_rotary_pos_emb_torchscript(
-          qkv, cos.to(qkv.dtype), sin.to(qkv.dtype))
+      qkv = apply_rotary_pos_emb_torchscript(
+        qkv, cos.to(qkv.dtype), sin.to(qkv.dtype))
     return qkv
 
-  def cross_attn(self, x, qkv, cross_attn_mask=None):
+  def cross_attn(self, x, qkv, mask=None):
     scale = qkv.shape[-1]
     qkv = qkv.transpose(1, 3)
-    attn_dropout = self.attn_dropout if self.training else 0.0
-    cross_attn_mask = cross_attn_mask.bool() if cross_attn_mask is not None else None
+    mask = mask.bool() if mask is not None else None
     x = F.scaled_dot_product_attention(
       query=qkv[:, :, 0],
       key=qkv[:, :, 1],
      value=qkv[:, :, 2],
-      attn_mask=cross_attn_mask,
-      dropout_p=attn_dropout,
+      attn_mask=mask,
      is_causal=False,
      scale=1 / math.sqrt(scale))
    x = x.transpose(1, 2)
    x = einops.rearrange(x, 'b s h d -> b s (h d)')
    return x
-
-  def forward(self, x, rotary_cos_sin, c, cross_attn_mask=None,
+
+  def cross_attn_flex(self, qkv, mask=None):
+    qkv = einops.rearrange(qkv, 'b s three h d -> b h three s d', h=self.n_heads)
+    x = fused_flex_attention(
+      qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2], mask=mask)
+    x = einops.rearrange(x, 'b h s d -> b s (h d)')
+    return x
+
+  def forward(self, x, rotary_cos_sin, c, mask=None,
               sample_mode=False, store_kv=False):
    bias_dropout_scale_fn = self._get_bias_dropout_scale()
 
@@ -354,17 +387,21 @@ class DDiTBlock(nn.Module):
     x = modulate_fused(self.norm1(x), shift_msa, scale_msa)
 
     # get qkvs
-    if cross_attn_mask is not None and not sample_mode:
+    if mask is not None and not sample_mode:
       qkv_x = self.get_qkv(x[:,:self.n], rotary_cos_sin)
       qkv_x0 = self.get_qkv(x[:,self.n:], rotary_cos_sin)
       qkv = torch.cat((qkv_x, qkv_x0), dim=1)
     else:
       qkv = self.get_qkv(x, rotary_cos_sin, store_kv=store_kv)
 
-    if cross_attn_mask is None and self.attn_backend == 'flash_attn':
+    if mask is None and self.attn_backend == 'flash_attn':
       x = regular_attention_multi_headed(qkv)
+    elif self.attn_backend == 'sdpa':
+      x = self.cross_attn(x, qkv, mask=mask)
+    elif self.attn_backend == 'flex' and FLEX_ATTN_AVAILABLE:
+      x = self.cross_attn_flex(qkv, mask=mask)
     else:
-      x = self.cross_attn(x, qkv, cross_attn_mask=cross_attn_mask)
+      raise ValueError('Unknown attention backend')
 
     x = bias_dropout_scale_fn(self.attn_out(x),
                               None,
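
The cross_attn_flex path selected above reshapes the packed qkv tensor from (batch, seq, 3, heads, head_dim) to (batch, heads, 3, seq, head_dim) before slicing out q, k and v. A quick shape check, with toy sizes that are assumptions for illustration:

# Shape sanity check for the rearrange used by cross_attn_flex (toy sizes assumed).
import torch
import einops

b, s, h, d = 2, 16, 4, 8
qkv = torch.randn(b, s, 3, h, d)                    # packed q, k, v as produced by get_qkv
qkv = einops.rearrange(qkv, 'b s three h d -> b h three s d', h=h)
q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]  # each (b, h, s, d), the layout flex_attention expects
print(q.shape)  # torch.Size([2, 4, 16, 8])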
@@ -457,15 +494,18 @@ class DITBackbone(nn.Module):
     else:
       return bias_dropout_add_scale_fused_inference
 
-  def gen_mask(self, seqlen, block_size):
-    self_attn_mask = block_causal_mask(seqlen, block_size, mode='diag')
-    x0_attn_mask = block_causal_mask(seqlen, block_size, mode='full')
-    cross_attn_mask = x0_attn_mask.clone()
-    cross_attn_mask.masked_fill_(self_attn_mask == 1, 0)
-
-    cross_attn_mask = torch.cat((self_attn_mask, cross_attn_mask), dim=1)
-    x0_attn_mask = torch.cat((torch.zeros_like(self_attn_mask), x0_attn_mask), dim=1)
-    self.cross_attn_mask = torch.cat((cross_attn_mask, x0_attn_mask), dim=0)
+  def gen_mask(self, seqlen, block_size, attn_backend='sdpa'):
+    """Generates attention mask"""
+    if attn_backend == 'sdpa':
+      self.block_diff_mask = block_diff_mask(
+        b=None, h=None, q_idx=torch.arange(seqlen*2)[:, None], kv_idx=torch.arange(seqlen*2)[None, :],
+        block_size=block_size, n=seqlen)
+    elif attn_backend == 'flex':
+      self.block_diff_mask = create_block_mask(
+        partial(block_diff_mask, block_size=block_size, n=seqlen),
+        B=None, H=None, Q_LEN=seqlen*2, KV_LEN=seqlen*2)
+    else:
+      raise ValueError('Unknown attention backend')
 
   def forward(self, indices, sigma, sample_mode=False,
               store_kv=False, output_hidden_states=False):
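
Both branches encode the same predicate: the SDPA path materializes a dense (2·seqlen)×(2·seqlen) boolean mask, while the flex path compiles it into a sparse BlockMask. A standalone comparison sketch, assuming a CUDA device, PyTorch ≥ 2.5, block_diff_mask in scope, and illustrative tensor sizes:

# Comparison sketch; sizes, device and flex-attention availability are assumptions.
import torch
import torch.nn.functional as F
from functools import partial
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

seqlen, block_size = 128, 16
q = torch.randn(1, 2, 2 * seqlen, 64, device="cuda")   # (batch, heads, 2*seqlen, head_dim)
k, v = torch.randn_like(q), torch.randn_like(q)

# SDPA path: dense boolean mask (True = attend), as built by gen_mask(..., attn_backend='sdpa').
dense = block_diff_mask(b=None, h=None,
                        q_idx=torch.arange(2 * seqlen)[:, None],
                        kv_idx=torch.arange(2 * seqlen)[None, :],
                        block_size=block_size, n=seqlen).to("cuda")
out_sdpa = F.scaled_dot_product_attention(q, k, v, attn_mask=dense)

# Flex path: the same predicate compiled into a sparse BlockMask.
block_mask = create_block_mask(partial(block_diff_mask, block_size=block_size, n=seqlen),
                               B=None, H=None, Q_LEN=2 * seqlen, KV_LEN=2 * seqlen,
                               device="cuda")
out_flex = flex_attention(q, k, v, block_mask=block_mask)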
@@ -478,13 +518,13 @@
     c = F.silu(self.sigma_map(sigma))
     if self.cross_attn:
       rotary_cos_sin = self.rotary_emb(x[:, :self.n])
-      cross_attn_mask = self.cross_attn_mask.to(x.device)
+      mask = self.mask.to(x.device)
       # use block-causal mask only during sampling
       if sample_mode:
-        cross_attn_mask = cross_attn_mask[
+        mask = mask[
           self.n:self.n+x.shape[1], self.n:self.n+x.shape[1]]
     else:
-      cross_attn_mask = None
+      mask = None
       rotary_cos_sin = self.rotary_emb(x)
 
     with torch.cuda.amp.autocast(dtype=self.precision):
@@ -492,7 +532,7 @@
       x = self.blocks[i](x,
                          rotary_cos_sin,
                          c,
-                         cross_attn_mask=cross_attn_mask,
+                         mask=mask,
                          sample_mode=sample_mode,
                          store_kv=store_kv)
       if output_hidden_states:
@@ -512,6 +552,7 @@ class BD3LM(transformers.PreTrainedModel):
     self,
     config: BD3LMConfig):
     super().__init__(config)
+    self.config = config
     self.backbone = DITBackbone(config)
     if config.var_min:
       self.register_buffer(
@@ -523,7 +564,7 @@ class BD3LM(transformers.PreTrainedModel):
 
   def reset_kv_cache(self):
     for block in self.backbone.blocks:
-      block.kv_cache = None
+      block.kv_cache = None
 
   def forward(
     self,
@@ -537,6 +578,9 @@ class BD3LM(transformers.PreTrainedModel):
       torch.Tensor, typing.Tuple,
       modeling_outputs.MaskedLMOutput]:
     """HF-compatible forward method."""
+    if sample_mode:
+      assert self.config.attn_backend == 'sdpa', 'Sampling only supported with SDPA'
+
     output_hidden_states = (
       output_hidden_states
       if output_hidden_states is not None
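
Because sampling now asserts the SDPA backend while the shipped config defaults to 'flex', a sampling caller can override the config attribute at load time. A hedged sketch relying on the standard from_pretrained kwarg override, again with a placeholder repo id:

# Sketch only; "user/bd3lm" is a placeholder repo id, not taken from this commit.
import transformers

model = transformers.AutoModelForMaskedLM.from_pretrained(
    "user/bd3lm",
    trust_remote_code=True,
    attn_backend="sdpa",  # overrides config.attn_backend so sample_mode passes the assert
)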
 