Upload BD3LM
- config.json +2 -0
- configuration_bd3lm.py +4 -0
- modeling_bd3lm.py +86 -40
config.json
CHANGED
@@ -1,5 +1,6 @@
 {
   "_name_or_path": "kuleshov-group/bd3lm-owt-block_size4",
+  "adaln": true,
   "architectures": [
     "BD3LM"
   ],
@@ -9,6 +10,7 @@
     "AutoModelForMaskedLM": "modeling_bd3lm.BD3LM"
   },
   "block_size": 4,
+  "causal": false,
   "cond_dim": 128,
   "cross_attn": true,
   "dropout": 0.1,
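
The two added keys mirror the new model options wired through configuration_bd3lm.py and modeling_bd3lm.py below: adaln toggles the adaptive-LayerNorm time-conditioning path, and causal switches self-attention to causal masking. A minimal sketch of reading them back through the auto_map entries (assumes only the transformers package; this snippet is not part of the upload):

import transformers

config = transformers.AutoConfig.from_pretrained(
    "kuleshov-group/bd3lm-owt-block_size4", trust_remote_code=True)
print(config.adaln, config.causal)  # True False for this revision

model = transformers.AutoModelForMaskedLM.from_pretrained(
    "kuleshov-group/bd3lm-owt-block_size4", trust_remote_code=True)
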
configuration_bd3lm.py
CHANGED
@@ -15,7 +15,9 @@ class BD3LMConfig(transformers.PretrainedConfig):
     vocab_size: int = 50258,
     model_length: int = 1024,
     cross_attn: bool = True,
+    adaln: bool = True,
     attn_backend: str = 'flex',
+    causal: bool = False,
     hidden_dim: int = 768,
     cond_dim: int = 129,
     n_blocks: int = 12,
@@ -29,7 +31,9 @@ class BD3LMConfig(transformers.PretrainedConfig):
     super().__init__(**kwargs)
     self.block_size = block_size
     self.cross_attn = cross_attn
+    self.adaln = adaln
     self.attn_backend = attn_backend
+    self.causal = causal
     self.vocab_size = vocab_size
     self.model_length = model_length
     self.hidden_dim = hidden_dim
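
Both new constructor arguments default to the previous behaviour (adaln=True, causal=False), so configs saved before this change deserialize into an equivalent model. A small sketch, assuming the uploaded configuration_bd3lm.py has been downloaded next to the script:

from configuration_bd3lm import BD3LMConfig

cfg = BD3LMConfig(block_size=4)  # new fields take their defaults
assert cfg.adaln and not cfg.causal

# opt in to the new behaviour explicitly
cfg_causal = BD3LMConfig(block_size=4, adaln=False, causal=True)
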
modeling_bd3lm.py
CHANGED
@@ -296,14 +296,16 @@ def regular_attention_multi_headed(qkv):
 
 
 class DDiTBlock(nn.Module):
-  def __init__(self, n, block_size, dim, n_heads, cond_dim, mlp_ratio=4,
-               dropout=0.1, attn_backend='sdpa'):
+  def __init__(self, n, block_size, dim, n_heads, cond_dim, causal=False,
+               mlp_ratio=4, dropout=0.1, adaln=True, attn_backend='sdpa'):
     super().__init__()
     self.n = n
     self.block_size = block_size
     self.n_heads = n_heads
     self.attn_backend = attn_backend
     self.kv_cache = None
+    self.cache_idx = 0
+    self.causal = causal
 
     self.norm1 = LayerNorm(dim)
     self.attn_qkv = nn.Linear(dim, 3 * dim, bias=False)
@@ -317,10 +319,11 @@ class DDiTBlock(nn.Module):
       nn.Linear(mlp_ratio * dim, dim, bias=True))
     self.dropout2 = nn.Dropout(dropout)
     self.dropout = dropout
-
-    self.adaLN_modulation = nn.Linear(cond_dim, 6 * dim, bias=True)
-    self.adaLN_modulation.weight.data.zero_()
-    self.adaLN_modulation.bias.data.zero_()
+    self.adaln = adaln
+    if self.adaln:
+      self.adaLN_modulation = nn.Linear(cond_dim, 6 * dim, bias=True)
+      self.adaLN_modulation.weight.data.zero_()
+      self.adaLN_modulation.bias.data.zero_()
 
   def _get_bias_dropout_scale(self):
     if self.training:
@@ -331,13 +334,18 @@ class DDiTBlock(nn.Module):
   def get_qkv(self, x, rotary_cos_sin, store_kv=False):
     # compute qkv (potentially use cache)
    if self.kv_cache is not None:
-      new_qkv = self.attn_qkv(x[:, -self.block_size:])
-      qkv = torch.cat((self.kv_cache, new_qkv), dim=1)
+      new_qkv = self.attn_qkv(x)
+      self.kv_cache[:, self.cache_idx:self.cache_idx+self.block_size] = new_qkv
+      qkv = self.kv_cache[:, :self.cache_idx+self.block_size].clone()
     else:
       qkv = self.attn_qkv(x)
     # store kv cache in a sliding window (can't exceed context len)
     if store_kv:
-      self.kv_cache = qkv[:, -(self.n - self.block_size):]
+      self.cache_idx += self.block_size
+      if self.cache_idx >= self.n:
+        # left-shift the cache
+        self.cache_idx = self.n - self.block_size
+        self.kv_cache[:, :-self.block_size] = self.kv_cache[:, self.block_size:].clone()
 
     qkv = einops.rearrange(
       qkv,
@@ -359,7 +367,7 @@ class DDiTBlock(nn.Module):
         key=qkv[:, :, 1],
         value=qkv[:, :, 2],
         attn_mask=mask,
-        is_causal=False,
+        is_causal=self.causal,
         scale=1 / math.sqrt(scale))
       x = x.transpose(1, 2)
       x = einops.rearrange(x, 'b s h d -> b s (h d)')
@@ -376,12 +384,16 @@ class DDiTBlock(nn.Module):
              sample_mode=False, store_kv=False):
     bias_dropout_scale_fn = self._get_bias_dropout_scale()
 
-    (shift_msa, scale_msa, gate_msa, shift_mlp,
-     scale_mlp, gate_mlp) = self.adaLN_modulation(c)[:, None].chunk(6, dim=2)
+    if self.adaln:
+      (shift_msa, scale_msa, gate_msa, shift_mlp,
+       scale_mlp, gate_mlp) = self.adaLN_modulation(c)[:, None].chunk(6, dim=2)
 
     # attention operation
     x_skip = x
-    x = modulate_fused(self.norm1(x), shift_msa, scale_msa)
+    if self.adaln:
+      x = modulate_fused(self.norm1(x), shift_msa, scale_msa)
+    else:
+      x = self.norm1(x)
 
     # get qkvs
     if mask is not None and not sample_mode:
@@ -399,17 +411,25 @@ class DDiTBlock(nn.Module):
     else:
       raise ValueError('Unknown attention backend')
 
-    x = bias_dropout_scale_fn(self.attn_out(x),
+    # mlp operation
+    if self.adaln:
+      x = bias_dropout_scale_fn(self.attn_out(x),
                               None,
                               gate_msa,
                               x_skip,
                               self.dropout)
-
-    # mlp operation
-    x = bias_dropout_scale_fn(
-      self.mlp(modulate_fused(
-        self.norm2(x), shift_mlp, scale_mlp)),
-      None, gate_mlp, x, self.dropout)
+      x = bias_dropout_scale_fn(
+        self.mlp(modulate_fused(
+          self.norm2(x), shift_mlp, scale_mlp)),
+        None, gate_mlp, x, self.dropout)
+    else:
+      x = bias_dropout_scale_fn(self.attn_out(x),
+        None, torch.ones_like(x), x_skip, self.dropout)
+      x = bias_dropout_scale_fn(
+        self.mlp(self.norm2(x)),
+        None, torch.ones_like(x), x, self.dropout)
+    if self.kv_cache is not None:
+      x = x[:, -self.block_size:]
     return x
 
 
@@ -424,23 +444,28 @@ class EmbeddingLayer(nn.Module):
 
 
 class DDitFinalLayer(nn.Module):
-  def __init__(self, hidden_size, out_channels, cond_dim):
+  def __init__(self, hidden_size, out_channels, cond_dim, adaln=True):
     super().__init__()
     self.norm_final = LayerNorm(hidden_size)
     self.linear = nn.Linear(hidden_size, out_channels)
     self.linear.weight.data.zero_()
     self.linear.bias.data.zero_()
 
-    self.adaLN_modulation = nn.Linear(cond_dim,
-                                      2 * hidden_size,
-                                      bias=True)
-    self.adaLN_modulation.weight.data.zero_()
-    self.adaLN_modulation.bias.data.zero_()
+    self.adaln = adaln
+    if self.adaln:
+      self.adaLN_modulation = nn.Linear(cond_dim,
+                                        2 * hidden_size,
+                                        bias=True)
+      self.adaLN_modulation.weight.data.zero_()
+      self.adaLN_modulation.bias.data.zero_()
 
 
   def forward(self, x, c):
-    shift, scale = self.adaLN_modulation(c)[:, None].chunk(2, dim=2)
-    x = modulate_fused(self.norm_final(x), shift, scale)
+    if self.adaln:
+      shift, scale = self.adaLN_modulation(c)[:, None].chunk(2, dim=2)
+      x = modulate_fused(self.norm_final(x), shift, scale)
+    else:
+      x = self.norm_final(x)
     x = self.linear(x)
     return x
 
@@ -460,8 +485,10 @@ class DITBackbone(nn.Module):
     self.vocab_embed = EmbeddingLayer(
      config.hidden_dim,
      config.vocab_size)
-    self.sigma_map = TimestepEmbedder(
-      config.cond_dim)
+    self.adaln = config.adaln
+    if self.adaln:
+      self.sigma_map = TimestepEmbedder(
+        config.cond_dim)
     self.rotary_emb = Rotary(
      config.hidden_dim // config.n_heads)
 
@@ -472,14 +499,17 @@ class DITBackbone(nn.Module):
        config.hidden_dim,
        config.n_heads,
        config.cond_dim,
+        causal=config.causal,
        dropout=config.dropout,
+        adaln=config.adaln,
        attn_backend=config.attn_backend,))
     self.blocks = nn.ModuleList(blocks)
 
     self.output_layer = DDitFinalLayer(
      config.hidden_dim,
      config.vocab_size,
-      config.cond_dim)
+      config.cond_dim,
+      adaln=config.adaln)
     if self.cross_attn:
       self.gen_mask(config.model_length, self.block_size, attn_backend=config.attn_backend)
     self.precision = torch.float32
@@ -505,21 +535,31 @@
 
   def forward(self, indices, sigma, sample_mode=False,
              store_kv=False, output_hidden_states=False):
-    if not self.config.time_conditioning:
+    if not self.config.time_conditioning and self.adaln:
      sigma = torch.zeros_like(sigma)
    all_hidden_states = []
    x = self.vocab_embed(indices)
    if output_hidden_states:
      all_hidden_states.append(x)
-    c = F.silu(self.sigma_map(sigma))
+    c = None
+    if self.adaln:
+      c = F.silu(self.sigma_map(sigma))
    if self.cross_attn:
      n = self.mask.shape[-1] // 2
-      rotary_cos_sin = self.rotary_emb(x[:, :n])
-      mask = self.mask.to(x.device)
      # use block-causal mask only during sampling
      if sample_mode:
-        mask = mask[
-          n:n+x.shape[1], n:n+x.shape[1]]
+        if self.blocks[0].kv_cache is not None:
+          mask = None
+          accum_length = self.blocks[0].cache_idx + self.block_size
+          # positional encodings for cache
+          x_full = torch.zeros((
+            x.shape[0], accum_length, x.shape[2]), device=x.device)
+          rotary_cos_sin = self.rotary_emb(x_full)
+        else:
+          mask = self.mask.to(x.device)
+          rotary_cos_sin = self.rotary_emb(x[:, :n])
+          mask = mask[
+            n:n+x.shape[1], n:n+x.shape[1]]
      else:
        mask = None
        rotary_cos_sin = self.rotary_emb(x)
@@ -558,10 +598,16 @@ class BD3LM(transformers.PreTrainedModel):
     self.register_buffer(
      'sampling_eps_max',
      torch.tensor(config.sampling_eps_max))
-
+
   def reset_kv_cache(self):
-    for block in self.backbone.blocks:
-      block.kv_cache = None
+    for block in self.blocks:
+      block.kv_cache = torch.zeros(
+        self.config.loader.eval_batch_size,
+        self.n,
+        self.config.model.hidden_size * 3,
+        device='cuda',
+        dtype=torch.bfloat16)
+      block.cache_idx = 0
 
   def forward(
     self,