Upload BD3LM
modeling_bd3lm.py  (+9 -8)  CHANGED
@@ -302,7 +302,6 @@ class DDiTBlock(nn.Module):
   def __init__(self, n, block_size, dim, n_heads, cond_dim, mlp_ratio=4,
                dropout=0.1, max_seqlen=1024, attn_backend='flash_attn'):
     super().__init__()
-    self.max_seqlen = max_seqlen
     self.n = n
     self.block_size = block_size
     self.n_heads = n_heads
@@ -341,7 +340,7 @@ class DDiTBlock(nn.Module):
       qkv = self.attn_qkv(x)
     # store kv cache in a sliding window (can't exceed context len)
     if store_kv:
-      self.kv_cache = qkv[:, -(self.
+      self.kv_cache = qkv[:, -(self.n-self.block_size):]

     qkv = einops.rearrange(
       qkv,
@@ -389,8 +388,9 @@ class DDiTBlock(nn.Module):

     # get qkvs
     if mask is not None and not sample_mode:
-
-
+      n = mask.shape[-1] // 2
+      qkv_x = self.get_qkv(x[:,:n], rotary_cos_sin)
+      qkv_x0 = self.get_qkv(x[:,n:], rotary_cos_sin)
       qkv = torch.cat((qkv_x, qkv_x0), dim=1)
     else:
       qkv = self.get_qkv(x, rotary_cos_sin, store_kv=store_kv)
@@ -518,12 +518,13 @@ class DITBackbone(nn.Module):
       all_hidden_states.append(x)
     c = F.silu(self.sigma_map(sigma))
     if self.cross_attn:
-
+      n = self.mask.shape[-1] // 2
+      rotary_cos_sin = self.rotary_emb(x[:, :n])
       mask = self.mask.to(x.device)
       # use block-causal mask only during sampling
       if sample_mode:
         mask = mask[
-
+          n:n+x.shape[1], n:n+x.shape[1]]
     else:
       mask = None
       rotary_cos_sin = self.rotary_emb(x)
@@ -540,8 +541,8 @@ class DITBackbone(nn.Module):
       all_hidden_states.append(x)
     logits = self.output_layer(x, c)
     if self.cross_attn and not sample_mode:
-      logits = logits[:, :
-      all_hidden_states = [hidden_states[:, :
+      logits = logits[:, :n]
+      all_hidden_states = [hidden_states[:, :n] for hidden_states in all_hidden_states]
     return logits, all_hidden_states

 class BD3LM(transformers.PreTrainedModel):
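The changes above share one piece of slicing arithmetic: with cross-attention enabled, the backbone is fed the concatenation [x_t, x_0] of the noisy and clean sequences, so the mask is (2n, 2n) and the sequence length is recovered as n = mask.shape[-1] // 2, with only the first n positions kept in the logits; during sampling, the KV cache keeps only the trailing n - block_size positions so the cache plus the next block never exceeds the context length. The snippet below is a minimal standalone sketch of that arithmetic only; the sizes, random tensors, and vocabulary are placeholders I chose for illustration, not code from this repository.

import torch

# Placeholder sizes for illustration; real values come from the model config.
n = 8            # context length (self.n in the diff)
block_size = 2   # tokens generated per block (self.block_size)
batch, dim, vocab = 1, 4, 10

# Training with cross-attention: the input is the concatenation [x_t, x_0],
# so hidden states have length 2n and the attention mask is (2n, 2n).
# That is why the diff recovers the sequence length as mask.shape[-1] // 2.
mask = torch.ones(2 * n, 2 * n, dtype=torch.bool)   # stand-in for self.mask
x = torch.randn(batch, 2 * n, dim)                  # stand-in hidden states
half = mask.shape[-1] // 2
assert half == n

# Only the noisy half x_t is scored, matching `logits = logits[:, :n]`.
logits = torch.randn(batch, 2 * n, vocab)           # stand-in output-layer result
logits = logits[:, :half]
assert logits.shape[1] == n

# Sliding-window KV cache: keep at most n - block_size cached positions so the
# cache plus the next block of block_size tokens never exceeds the context
# length, matching `self.kv_cache = qkv[:, -(self.n-self.block_size):]`.
qkv = torch.randn(batch, n, 3 * dim)                # stand-in fused q/k/v activations
kv_cache = qkv[:, -(n - block_size):]
next_block = torch.randn(batch, block_size, 3 * dim)
assert kv_cache.shape[1] + next_block.shape[1] <= n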