marriola committed
Commit cff3f8f · verified · 1 Parent(s): d5ecd4b

Upload BD3LM

Files changed (1)
  1. modeling_bd3lm.py +9 -8
modeling_bd3lm.py CHANGED
@@ -302,7 +302,6 @@ class DDiTBlock(nn.Module):
   def __init__(self, n, block_size, dim, n_heads, cond_dim, mlp_ratio=4,
                dropout=0.1, max_seqlen=1024, attn_backend='flash_attn'):
     super().__init__()
-    self.max_seqlen = max_seqlen
     self.n = n
     self.block_size = block_size
     self.n_heads = n_heads
@@ -341,7 +340,7 @@ class DDiTBlock(nn.Module):
     qkv = self.attn_qkv(x)
     # store kv cache in a sliding window (can't exceed context len)
     if store_kv:
-      self.kv_cache = qkv[:, -(self.max_seqlen-self.block_size):]
+      self.kv_cache = qkv[:, -(self.n-self.block_size):]
 
     qkv = einops.rearrange(
       qkv,
@@ -389,8 +388,9 @@ class DDiTBlock(nn.Module):
 
     # get qkvs
     if mask is not None and not sample_mode:
-      qkv_x = self.get_qkv(x[:,:self.n], rotary_cos_sin)
-      qkv_x0 = self.get_qkv(x[:,self.n:], rotary_cos_sin)
+      n = mask.shape[-1] // 2
+      qkv_x = self.get_qkv(x[:,:n], rotary_cos_sin)
+      qkv_x0 = self.get_qkv(x[:,n:], rotary_cos_sin)
       qkv = torch.cat((qkv_x, qkv_x0), dim=1)
     else:
       qkv = self.get_qkv(x, rotary_cos_sin, store_kv=store_kv)
@@ -518,12 +518,13 @@ class DITBackbone(nn.Module):
       all_hidden_states.append(x)
     c = F.silu(self.sigma_map(sigma))
     if self.cross_attn:
-      rotary_cos_sin = self.rotary_emb(x[:, :self.n])
+      n = self.mask.shape[-1] // 2
+      rotary_cos_sin = self.rotary_emb(x[:, :n])
       mask = self.mask.to(x.device)
       # use block-causal mask only during sampling
       if sample_mode:
         mask = mask[
-          self.n:self.n+x.shape[1], self.n:self.n+x.shape[1]]
+          n:n+x.shape[1], n:n+x.shape[1]]
     else:
       mask = None
       rotary_cos_sin = self.rotary_emb(x)
@@ -540,8 +541,8 @@ class DITBackbone(nn.Module):
       all_hidden_states.append(x)
     logits = self.output_layer(x, c)
     if self.cross_attn and not sample_mode:
-      logits = logits[:, :self.n]
-      all_hidden_states = [hidden_states[:, :self.n] for hidden_states in all_hidden_states]
+      logits = logits[:, :n]
+      all_hidden_states = [hidden_states[:, :n] for hidden_states in all_hidden_states]
     return logits, all_hidden_states
 
 class BD3LM(transformers.PreTrainedModel):
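A quick note on the KV-cache hunks above: DDiTBlock no longer stores max_seqlen, and the sliding-window slice that the comment describes is now bounded by self.n (the attention context length) rather than the configured maximum sequence length. Below is a minimal sketch of the slicing behaviour only; the values of n and block_size and the qkv shape are made up for illustration, not the repository's API.

import torch

# Hypothetical values standing in for self.n and self.block_size.
n, block_size = 16, 4
# (batch, seqlen, 3 * dim) packed q/k/v projections; seqlen may exceed n.
qkv = torch.randn(1, 20, 8)

# After this commit: keep at most the last (n - block_size) positions,
# so the cache is capped by the attention context, not by max_seqlen.
kv_cache = qkv[:, -(n - block_size):]
print(kv_cache.shape)  # torch.Size([1, 12, 8])

With n = 16 and block_size = 4, only the 12 most recent positions are cached, which appears to leave room for the next block of block_size tokens inside the n-token window.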
 
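The remaining hunks drop the dependence on a stored self.n in favour of an n recovered from the mask itself (mask.shape[-1] // 2); the variable names suggest the training-time mask spans the concatenation of the sequences x and x0, so the mask is 2n x 2n, and logits and hidden states are truncated back to the first n positions afterwards. The following is a small illustrative sketch of that pattern with hypothetical shapes and names, not the model's real configuration.

import torch

n, vocab = 6, 10
# Mask over a concatenated sequence of length 2n, as the diff implies.
mask = torch.ones(2 * n, 2 * n)
# Logits produced over the full concatenated sequence.
logits = torch.randn(1, 2 * n, vocab)

# Recover n from the mask width instead of reading a stored attribute.
n_from_mask = mask.shape[-1] // 2
# Keep only the first n positions, mirroring `logits[:, :n]` in the diff.
logits = logits[:, :n_from_mask]
print(n_from_mask, logits.shape)  # 6 torch.Size([1, 6, 10])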