marriola committed (verified)
Commit 080d807 · 1 Parent(s): 7e1f98c

Upload BD3LM

Files changed (3):
  1. config.json (+2 -0)
  2. configuration_bd3lm.py (+4 -0)
  3. modeling_bd3lm.py (+86 -40)
config.json CHANGED

@@ -1,5 +1,6 @@
 {
   "_name_or_path": "kuleshov-group/bd3lm-owt-block_size4",
+  "adaln": true,
   "architectures": [
     "BD3LM"
   ],
@@ -9,6 +10,7 @@
     "AutoModelForMaskedLM": "modeling_bd3lm.BD3LM"
   },
   "block_size": 4,
+  "causal": false,
   "cond_dim": 128,
   "cross_attn": true,
   "dropout": 0.1,
configuration_bd3lm.py CHANGED

@@ -15,7 +15,9 @@ class BD3LMConfig(transformers.PretrainedConfig):
     vocab_size: int = 50258,
     model_length: int = 1024,
     cross_attn: bool = True,
+    adaln: bool = True,
     attn_backend: str = 'flex',
+    causal: bool = False,
     hidden_dim: int = 768,
     cond_dim: int = 129,
     n_blocks: int = 12,
@@ -29,7 +31,9 @@ class BD3LMConfig(transformers.PretrainedConfig):
     super().__init__(**kwargs)
     self.block_size = block_size
     self.cross_attn = cross_attn
+    self.adaln = adaln
     self.attn_backend = attn_backend
+    self.causal = causal
     self.vocab_size = vocab_size
     self.model_length = model_length
     self.hidden_dim = hidden_dim
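
The same switches can also be set when building a config programmatically. A small sketch, assuming configuration_bd3lm.py from this repo is importable locally; only parameters visible in this diff are passed, everything else keeps its default:

# Sketch: construct the config with the new switches set explicitly.
from configuration_bd3lm import BD3LMConfig

cfg = BD3LMConfig(
    block_size=4,
    cross_attn=True,
    adaln=True,          # build adaLN_modulation layers and the sigma embedder
    causal=False,        # forwarded to the attention call as is_causal
    attn_backend='sdpa',
)
assert cfg.adaln is True and cfg.causal is False
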
modeling_bd3lm.py CHANGED

@@ -296,14 +296,16 @@ def regular_attention_multi_headed(qkv):
 
 
 class DDiTBlock(nn.Module):
-  def __init__(self, n, block_size, dim, n_heads, cond_dim, mlp_ratio=4,
-               dropout=0.1, attn_backend='sdpa'):
+  def __init__(self, n, block_size, dim, n_heads, cond_dim, causal=False,
+               mlp_ratio=4, dropout=0.1, adaln=True, attn_backend='sdpa'):
     super().__init__()
     self.n = n
     self.block_size = block_size
     self.n_heads = n_heads
     self.attn_backend = attn_backend
     self.kv_cache = None
+    self.cache_idx = 0
+    self.causal = causal
 
     self.norm1 = LayerNorm(dim)
     self.attn_qkv = nn.Linear(dim, 3 * dim, bias=False)
@@ -317,10 +319,11 @@ class DDiTBlock(nn.Module):
                             nn.Linear(mlp_ratio * dim, dim, bias=True))
     self.dropout2 = nn.Dropout(dropout)
     self.dropout = dropout
-
-    self.adaLN_modulation = nn.Linear(cond_dim, 6 * dim, bias=True)
-    self.adaLN_modulation.weight.data.zero_()
-    self.adaLN_modulation.bias.data.zero_()
+    self.adaln = adaln
+    if self.adaln:
+      self.adaLN_modulation = nn.Linear(cond_dim, 6 * dim, bias=True)
+      self.adaLN_modulation.weight.data.zero_()
+      self.adaLN_modulation.bias.data.zero_()
 
   def _get_bias_dropout_scale(self):
     if self.training:
@@ -331,13 +334,18 @@ class DDiTBlock(nn.Module):
   def get_qkv(self, x, rotary_cos_sin, store_kv=False):
     # compute qkv (potentially use cache)
     if self.kv_cache is not None:
-      new_qkv = self.attn_qkv(x[:, -self.block_size:])
-      qkv = torch.cat((self.kv_cache, new_qkv), dim=1)
+      new_qkv = self.attn_qkv(x)
+      self.kv_cache[:, self.cache_idx:self.cache_idx+self.block_size] = new_qkv
+      qkv = self.kv_cache[:, :self.cache_idx+self.block_size].clone()
     else:
       qkv = self.attn_qkv(x)
     # store kv cache in a sliding window (can't exceed context len)
     if store_kv:
-      self.kv_cache = qkv[:, -(self.n-self.block_size):]
+      self.cache_idx += self.block_size
+      if self.cache_idx >= self.n:
+        # left-shift the cache
+        self.cache_idx = self.n - self.block_size
+        self.kv_cache[:, :-self.block_size] = self.kv_cache[:, self.block_size:].clone()
 
     qkv = einops.rearrange(
       qkv,
@@ -359,7 +367,7 @@ class DDiTBlock(nn.Module):
       key=qkv[:, :, 1],
      value=qkv[:, :, 2],
       attn_mask=mask,
-      is_causal=False,
+      is_causal=self.causal,
       scale=1 / math.sqrt(scale))
     x = x.transpose(1, 2)
     x = einops.rearrange(x, 'b s h d -> b s (h d)')
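
The rewritten get_qkv replaces the old concatenate-and-truncate cache with a preallocated buffer of n positions and a write pointer (cache_idx) that advances by block_size per decoded block, shifting the window left once the buffer is full. A standalone toy sketch of that bookkeeping (names are illustrative, not the module's API):

# Toy illustration of the sliding-window cache bookkeeping used in get_qkv.
# `n` is the cache capacity, `block_size` the number of tokens written per step.
import torch

n, block_size, dim = 8, 4, 2
kv_cache = torch.zeros(1, n, dim)
cache_idx = 0

for step in range(4):
    new_qkv = torch.full((1, block_size, dim), float(step + 1))
    kv_cache[:, cache_idx:cache_idx + block_size] = new_qkv
    qkv = kv_cache[:, :cache_idx + block_size]   # everything cached so far
    cache_idx += block_size
    if cache_idx >= n:
        # window is full: drop the oldest block, keep writing at the tail
        cache_idx = n - block_size
        kv_cache[:, :-block_size] = kv_cache[:, block_size:].clone()
    print(step, qkv.shape, kv_cache[0, :, 0].tolist())
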
@@ -376,12 +384,16 @@ class DDiTBlock(nn.Module):
               sample_mode=False, store_kv=False):
     bias_dropout_scale_fn = self._get_bias_dropout_scale()
 
-    (shift_msa, scale_msa, gate_msa, shift_mlp,
-     scale_mlp, gate_mlp) = self.adaLN_modulation(c)[:, None].chunk(6, dim=2)
+    if self.adaln:
+      (shift_msa, scale_msa, gate_msa, shift_mlp,
+       scale_mlp, gate_mlp) = self.adaLN_modulation(c)[:, None].chunk(6, dim=2)
 
     # attention operation
     x_skip = x
-    x = modulate_fused(self.norm1(x), shift_msa, scale_msa)
+    if self.adaln:
+      x = modulate_fused(self.norm1(x), shift_msa, scale_msa)
+    else:
+      x = self.norm1(x)
 
     # get qkvs
     if mask is not None and not sample_mode:
@@ -399,17 +411,25 @@ class DDiTBlock(nn.Module):
     else:
       raise ValueError('Unknown attention backend')
 
-    x = bias_dropout_scale_fn(self.attn_out(x),
+    # mlp operation
+    if self.adaln:
+      x = bias_dropout_scale_fn(self.attn_out(x),
                               None,
                               gate_msa,
                               x_skip,
                               self.dropout)
-
-    # mlp operation
-    x = bias_dropout_scale_fn(
-      self.mlp(modulate_fused(
-        self.norm2(x), shift_mlp, scale_mlp)),
-      None, gate_mlp, x, self.dropout)
+      x = bias_dropout_scale_fn(
+        self.mlp(modulate_fused(
+          self.norm2(x), shift_mlp, scale_mlp)),
+        None, gate_mlp, x, self.dropout)
+    else:
+      x = bias_dropout_scale_fn(self.attn_out(x),
+        None, torch.ones_like(x), x_skip, self.dropout)
+      x = bias_dropout_scale_fn(
+        self.mlp(self.norm2(x)),
+        None, torch.ones_like(x), x, self.dropout)
+    if self.kv_cache is not None:
+      x = x[:, -self.block_size:]
     return x
 
 
@@ -424,23 +444,28 @@ class EmbeddingLayer(nn.Module):
 
 
 class DDitFinalLayer(nn.Module):
-  def __init__(self, hidden_size, out_channels, cond_dim):
+  def __init__(self, hidden_size, out_channels, cond_dim, adaln=True):
     super().__init__()
     self.norm_final = LayerNorm(hidden_size)
     self.linear = nn.Linear(hidden_size, out_channels)
     self.linear.weight.data.zero_()
     self.linear.bias.data.zero_()
 
-    self.adaLN_modulation = nn.Linear(cond_dim,
-                                      2 * hidden_size,
-                                      bias=True)
-    self.adaLN_modulation.weight.data.zero_()
-    self.adaLN_modulation.bias.data.zero_()
+    self.adaln = adaln
+    if self.adaln:
+      self.adaLN_modulation = nn.Linear(cond_dim,
+                                        2 * hidden_size,
+                                        bias=True)
+      self.adaLN_modulation.weight.data.zero_()
+      self.adaLN_modulation.bias.data.zero_()
 
 
   def forward(self, x, c):
-    shift, scale = self.adaLN_modulation(c)[:, None].chunk(2, dim=2)
-    x = modulate_fused(self.norm_final(x), shift, scale)
+    if self.adaln:
+      shift, scale = self.adaLN_modulation(c)[:, None].chunk(2, dim=2)
+      x = modulate_fused(self.norm_final(x), shift, scale)
+    else:
+      x = self.norm_final(x)
     x = self.linear(x)
     return x
 
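
The adaln branches reproduce standard DiT-style conditioning: a zero-initialized linear layer maps the conditioning vector to per-channel shift, scale, and gate terms, so at initialization every modulated branch contributes nothing and the block starts out as (near) identity. A self-contained sketch of the mechanism, assuming modulate_fused follows the usual DiT form x * (1 + scale) + shift and ignoring the fused bias/dropout path:

# Simplified sketch of the adaLN conditioning wired into DDiTBlock above.
import torch
import torch.nn as nn

dim, cond_dim = 16, 8
adaLN_modulation = nn.Linear(cond_dim, 6 * dim, bias=True)
nn.init.zeros_(adaLN_modulation.weight)   # zero init, as in the diff
nn.init.zeros_(adaLN_modulation.bias)

x = torch.randn(2, 5, dim)       # (batch, sequence, hidden)
c = torch.randn(2, cond_dim)     # conditioning, e.g. silu(sigma_map(sigma))
(shift_msa, scale_msa, gate_msa,
 shift_mlp, scale_mlp, gate_mlp) = adaLN_modulation(c)[:, None].chunk(6, dim=2)

h = nn.LayerNorm(dim)(x) * (1 + scale_msa) + shift_msa   # modulated attention input
out = x + gate_msa * h           # stand-in for gating attn_out(h) into the residual
print(out.shape)                 # torch.Size([2, 5, 16]); equals x at zero init
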
@@ -460,8 +485,10 @@ class DITBackbone(nn.Module):
     self.vocab_embed = EmbeddingLayer(
       config.hidden_dim,
       config.vocab_size)
-    self.sigma_map = TimestepEmbedder(
-      config.cond_dim)
+    self.adaln = config.adaln
+    if self.adaln:
+      self.sigma_map = TimestepEmbedder(
+        config.cond_dim)
     self.rotary_emb = Rotary(
       config.hidden_dim // config.n_heads)
 
@@ -472,14 +499,17 @@ class DITBackbone(nn.Module):
         config.hidden_dim,
         config.n_heads,
         config.cond_dim,
+        causal=config.causal,
         dropout=config.dropout,
+        adaln=config.adaln,
         attn_backend=config.attn_backend,))
     self.blocks = nn.ModuleList(blocks)
 
     self.output_layer = DDitFinalLayer(
       config.hidden_dim,
       config.vocab_size,
-      config.cond_dim)
+      config.cond_dim,
+      adaln=config.adaln)
     if self.cross_attn:
       self.gen_mask(config.model_length, self.block_size, attn_backend=config.attn_backend)
     self.precision = torch.float32
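
Because construction of the conditioning layers is now gated, a block built with adaln=False should carry no adaLN_modulation submodule at all. A hedged sanity check, assuming modeling_bd3lm.py and its imports resolve in the local environment:

# Sketch: the adaln flag decides whether the modulation layer exists.
from modeling_bd3lm import DDiTBlock

blk = DDiTBlock(n=1024, block_size=4, dim=768, n_heads=12, cond_dim=128,
                causal=False, adaln=False, attn_backend='sdpa')
assert not hasattr(blk, 'adaLN_modulation')

blk = DDiTBlock(n=1024, block_size=4, dim=768, n_heads=12, cond_dim=128,
                causal=False, adaln=True, attn_backend='sdpa')
assert hasattr(blk, 'adaLN_modulation')
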
@@ -505,21 +535,31 @@ class DITBackbone(nn.Module):
 
   def forward(self, indices, sigma, sample_mode=False,
               store_kv=False, output_hidden_states=False):
-    if not self.config.time_conditioning:
+    if not self.config.time_conditioning and self.adaln:
       sigma = torch.zeros_like(sigma)
     all_hidden_states = []
     x = self.vocab_embed(indices)
     if output_hidden_states:
       all_hidden_states.append(x)
-    c = F.silu(self.sigma_map(sigma))
+    c = None
+    if self.adaln:
+      c = F.silu(self.sigma_map(sigma))
     if self.cross_attn:
       n = self.mask.shape[-1] // 2
-      rotary_cos_sin = self.rotary_emb(x[:, :n])
-      mask = self.mask.to(x.device)
       # use block-causal mask only during sampling
       if sample_mode:
-        mask = mask[
-          n:n+x.shape[1], n:n+x.shape[1]]
+        if self.blocks[0].kv_cache is not None:
+          mask = None
+          accum_length = self.blocks[0].cache_idx + self.block_size
+          # positional encodings for cache
+          x_full = torch.zeros((
+            x.shape[0], accum_length, x.shape[2]), device=x.device)
+          rotary_cos_sin = self.rotary_emb(x_full)
+        else:
+          mask = self.mask.to(x.device)
+          rotary_cos_sin = self.rotary_emb(x[:, :n])
+          mask = mask[
+            n:n+x.shape[1], n:n+x.shape[1]]
     else:
       mask = None
       rotary_cos_sin = self.rotary_emb(x)
@@ -558,10 +598,16 @@ class BD3LM(transformers.PreTrainedModel):
     self.register_buffer(
       'sampling_eps_max',
       torch.tensor(config.sampling_eps_max))
-
+
   def reset_kv_cache(self):
-    for block in self.backbone.blocks:
-      block.kv_cache = None
+    for block in self.blocks:
+      block.kv_cache = torch.zeros(
+        self.config.loader.eval_batch_size,
+        self.n,
+        self.config.model.hidden_size * 3,
+        device='cuda',
+        dtype=torch.bfloat16)
+      block.cache_idx = 0
 
   def forward(
     self,
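
Finally, reset_kv_cache now preallocates the full (batch, n, 3 * hidden) buffer in bfloat16 instead of clearing the cache to None, which is what allows get_qkv to write each new block in place and slide the window. Since adaln: true and causal: false are recorded in the uploaded config.json, the exported checkpoint should also contain the conditionally-created conditioning weights; a hedged inspection sketch (module names taken from this diff, checkpoint download required):

# Sketch: check that the conditioning modules gated on `adaln` exist
# in the uploaded checkpoint, and that the new config keys round-trip.
import transformers

model = transformers.AutoModelForMaskedLM.from_pretrained(
    "kuleshov-group/bd3lm-owt-block_size4", trust_remote_code=True)

names = [n for n, _ in model.named_parameters()]
print(any("adaLN_modulation" in n for n in names))   # expected True (adaln: true)
print(any("sigma_map" in n for n in names))          # expected True
print(model.config.causal, model.config.block_size)  # expected: False 4
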