tolgacangoz committed: Upload matryoshka.py

matryoshka.py  CHANGED  (+22 -64)
@@ -420,6 +420,7 @@ class MatryoshkaDDIMScheduler(SchedulerMixin, ConfigMixin):
         self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))
 
         self.scales = None
+        self.schedule_shifted_power = 1.0
 
     def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
         """
@@ -532,6 +533,8 @@ class MatryoshkaDDIMScheduler(SchedulerMixin, ConfigMixin):
 
     def get_schedule_shifted(self, alpha_prod, scale_factor=None):
         if (scale_factor is not None) and (scale_factor > 1):  # rescale noise schedule
+            p = self.schedule_shifted_power
+            scale_factor = scale_factor ** p
             snr = alpha_prod / (1 - alpha_prod)
             scaled_snr = snr / scale_factor
             alpha_prod = 1 / (1 + 1 / scaled_snr)
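For reference, the shift above operates in SNR space: dividing the signal-to-noise ratio by the scale factor pushes the schedule toward higher noise, and the new schedule_shifted_power exponent (1.0 by default, raised to 2.0 by the pipeline for nesting_level == 2 in the last hunk) strengthens that shift. A minimal standalone sketch of the same arithmetic, not part of the file itself:

import torch

def shifted_alpha_prod(alpha_prod, scale_factor, power=1.0):
    # Mirrors MatryoshkaDDIMScheduler.get_schedule_shifted as changed in this commit.
    scale_factor = scale_factor ** power
    snr = alpha_prod / (1 - alpha_prod)   # SNR of the base schedule
    scaled_snr = snr / scale_factor       # lower SNR -> noisier shifted schedule
    return 1 / (1 + 1 / scaled_snr)

alphas = torch.tensor([0.9, 0.5, 0.1])
print(shifted_alpha_prod(alphas, scale_factor=2))             # power=1.0: tensor([0.8182, 0.3333, 0.0526])
print(shifted_alpha_prod(alphas, scale_factor=2, power=2.0))  # power=2.0: tensor([0.6923, 0.2000, 0.0270])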
@@ -1440,7 +1443,7 @@ class MatryoshkaTransformerBlock(nn.Module):
             bias=True,
             upcast_attention=upcast_attention,
             pre_only=True,
-            processor=MatryoshkaFusedAttnProcessor1_0_or_2_0(),
+            processor=MatryoshkaFusedAttnProcessor2_0(),
         )
         self.attn1.fuse_projections()
         del self.attn1.to_q
@@ -1458,7 +1461,7 @@ class MatryoshkaTransformerBlock(nn.Module):
             bias=True,
             upcast_attention=upcast_attention,
             pre_only=True,
-            processor=MatryoshkaFusedAttnProcessor1_0_or_2_0(),
+            processor=MatryoshkaFusedAttnProcessor2_0(),
         )
         self.attn2.fuse_projections()
         del self.attn2.to_q
@@ -1517,7 +1520,6 @@ class MatryoshkaTransformerBlock(nn.Module):
             # **cross_attention_kwargs,
         )
 
-        # attn_output_cond = attn_output_cond.permute(0, 2, 1).contiguous()
         attn_output_cond = self.proj_out(attn_output_cond)
         attn_output_cond = attn_output_cond.permute(0, 2, 1).reshape(batch_size, channels, *spatial_dims)
         hidden_states = hidden_states + attn_output_cond
@@ -1535,7 +1537,7 @@ class MatryoshkaTransformerBlock(nn.Module):
         return hidden_states
 
 
-class MatryoshkaFusedAttnProcessor1_0_or_2_0:
+class MatryoshkaFusedAttnProcessor2_0:
     r"""
     Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). It uses
     fused projection layers. For self-attention modules, all projection matrices (i.e., query, key, value) are fused.
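The "fused projection layers" mentioned in the docstring mean that a single linear layer produces the query, key and value in one matmul, and the processor splits the result afterwards (by three for self-attention, by two for the key/value projection in cross-attention, as the __call__ further down does with to_qkv/to_kv). A standalone sketch of that split, with illustrative dimensions:

import torch
import torch.nn as nn

dim = 64
to_qkv = nn.Linear(dim, dim * 3)            # fused projection: one matmul yields Q, K and V
tokens = torch.randn(2, 256, dim)           # (batch, seq_len, channels)

qkv = to_qkv(tokens)
split_size = qkv.shape[-1] // 3
query, key, value = torch.split(qkv, split_size, dim=-1)
print(query.shape, key.shape, value.shape)  # each: torch.Size([2, 256, 64])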
@@ -1548,28 +1550,11 @@ class MatryoshkaFusedAttnProcessor1_0_or_2_0:
     </Tip>
     """
 
-
-
-
-
-
-
-    # TODO: They seem to give different results; but nevertheless can I replace this with torch.nn.functional.scaled_dot_product_attention()?
-    def attention(self, q, k, v, num_heads, mask=None):
-        bs, width, length = q.shape
-        ch = width // num_heads
-        scale = 1 / torch.sqrt(torch.sqrt(torch.tensor(ch)))
-        weight = torch.einsum(
-            "bct,bcs->bts",
-            (q * scale).reshape(bs * num_heads, ch, length),
-            (k * scale).reshape(bs * num_heads, ch, -1),
-        )  # More stable with f16 than dividing afterwards
-        if mask is not None:
-            mask = mask.view(mask.size(0), 1, 1, mask.size(-1)).repeat(1, num_heads, 1, 1).flatten(0, 1)
-            weight = weight.masked_fill(mask == 0, float("-inf"))
-        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
-        a = torch.einsum("bts,bcs->bct", weight, v.reshape(bs * num_heads, ch, -1))
-        return a.reshape(bs, -1, length)
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError(
+                "MatryoshkaFusedAttnProcessor2_0 requires PyTorch 2.x, to use it. Please upgrade PyTorch to > 2.x."
+            )
 
     def __call__(
         self,
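The removed helper's TODO asked whether it could be replaced by torch.nn.functional.scaled_dot_product_attention; the switch to MatryoshkaFusedAttnProcessor2_0 answers that by relying on F.scaled_dot_product_attention directly. As a standalone sanity check (not part of the commit), the two formulations agree up to floating-point tolerance for the channels-first (batch, heads * head_dim, seq) layout the helper used, mask path omitted:

import torch
import torch.nn.functional as F

# q, k, v follow the removed helper's channels-first layout: (batch, heads * head_dim, seq_len).
bs, num_heads, ch, length = 2, 4, 8, 16
q, k, v = (torch.randn(bs, num_heads * ch, length) for _ in range(3))

def legacy_attention(q, k, v, num_heads):
    # Same math as the deleted helper (mask path omitted).
    bs, width, length = q.shape
    ch = width // num_heads
    scale = 1 / torch.sqrt(torch.sqrt(torch.tensor(ch)))
    weight = torch.einsum(
        "bct,bcs->bts",
        (q * scale).reshape(bs * num_heads, ch, length),
        (k * scale).reshape(bs * num_heads, ch, -1),
    )
    weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
    a = torch.einsum("bts,bcs->bct", weight, v.reshape(bs * num_heads, ch, -1))
    return a.reshape(bs, -1, length)

def sdpa_attention(q, k, v, num_heads):
    bs, width, length = q.shape
    ch = width // num_heads

    def to_sdpa(t):  # (batch, heads * head_dim, seq) -> (batch, heads, seq, head_dim)
        return t.reshape(bs, num_heads, ch, length).transpose(-1, -2)

    out = F.scaled_dot_product_attention(to_sdpa(q), to_sdpa(k), to_sdpa(v))
    return out.transpose(-1, -2).reshape(bs, width, length)  # back to (batch, heads * head_dim, seq)

print(torch.allclose(legacy_attention(q, k, v, num_heads), sdpa_attention(q, k, v, num_heads), atol=1e-5))  # True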
@@ -1593,26 +1578,12 @@ class MatryoshkaFusedAttnProcessor1_0_or_2_0:
 
         input_ndim = hidden_states.ndim
 
-        if input_ndim == 4:
-            batch_size, channel, height, width = hidden_states.shape
-            # hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
-
-        # batch_size, sequence_length, _ = (
-        #     hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-        # )
-
-        # if attention_mask is not None:
-        #     attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
-        #     # scaled_dot_product_attention expects attention_mask shape to be
-        #     # (batch, heads, source_length, target_length)
-        #     attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
-
         if attn.group_norm is not None:
-            hidden_states = attn.group_norm(hidden_states)
+            hidden_states = attn.group_norm(hidden_states)
 
-
-
-
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2).contiguous()
 
         if encoder_hidden_states is None:
             qkv = attn.to_qkv(hidden_states)
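A side note on the reordering above, not from the commit itself: nn.GroupNorm normalizes a channels-first (N, C, ...) tensor, while the fused attention path wants (N, H*W, C) tokens, so the norm is now applied before the view/transpose and the result is made contiguous. A minimal shape sketch under that assumption:

import torch
import torch.nn as nn

x = torch.randn(2, 64, 16, 16)                            # (batch, channels, height, width)
group_norm = nn.GroupNorm(num_groups=32, num_channels=64)

h = group_norm(x)                                         # GroupNorm expects channels-first input
h = h.view(2, 64, 16 * 16).transpose(1, 2).contiguous()   # -> (batch, tokens, channels)
print(h.shape)                                            # torch.Size([2, 256, 64])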
@@ -1630,11 +1601,6 @@ class MatryoshkaFusedAttnProcessor1_0_or_2_0:
             split_size = kv.shape[-1] // 2
             key, value = torch.split(kv, split_size, dim=-1)
 
-        # if self_attention_output is None:
-        #     query = query.permute(0, 2, 1)
-        #     key = key.permute(0, 2, 1)
-        #     value = value.permute(0, 2, 1)
-
         if attn.norm_q is not None:
             query = attn.norm_q(query)
         if attn.norm_k is not None:
@@ -1659,16 +1625,6 @@ class MatryoshkaFusedAttnProcessor1_0_or_2_0:
             query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
         )
 
-        # the output of sdp = (batch, num_heads, seq_len, head_dim)
-        # TODO: add support for attn.scale when we move to Torch 2.1 if F.scaled_dot_product_attention() is available
-        # hidden_states = self.attention(
-        #     query,
-        #     key,
-        #     value,
-        #     mask=attention_mask,
-        #     num_heads=attn.heads,
-        # )
-
         hidden_states = hidden_states.to(query.dtype)
 
         if self_attention_output is not None:
@@ -1956,7 +1912,7 @@ class MatryoshkaCombinedTimestepTextEmbedding(nn.Module):
         # if self.cond_emb is not None and not added_cond_kwargs.get("from_nested", False):
         return temb_micro_conditioning, conditioning_mask, cond_emb
 
-        return
+        return None, conditioning_mask, cond_emb
 
 
 @dataclass
@@ -3184,7 +3140,7 @@ class MatryoshkaUNet2DConditionModel(
                     encoder_hidden_states=encoder_hidden_states,
                     attention_mask=attention_mask,
                     cross_attention_kwargs=cross_attention_kwargs,
-                    encoder_attention_mask=encoder_attention_mask,
+                    encoder_attention_mask=encoder_attention_mask,
                     **additional_residuals,
                 )
             else:
@@ -3214,7 +3170,7 @@ class MatryoshkaUNet2DConditionModel(
                     encoder_hidden_states=encoder_hidden_states,
                     attention_mask=attention_mask,
                     cross_attention_kwargs=cross_attention_kwargs,
-                    encoder_attention_mask=encoder_attention_mask,
+                    encoder_attention_mask=encoder_attention_mask,
                 )
             else:
                 sample = self.mid_block(sample, emb)
@@ -3251,7 +3207,7 @@ class MatryoshkaUNet2DConditionModel(
                     cross_attention_kwargs=cross_attention_kwargs,
                     upsample_size=upsample_size,
                     attention_mask=attention_mask,
-                    encoder_attention_mask=encoder_attention_mask,
+                    encoder_attention_mask=encoder_attention_mask,
                 )
             else:
                 sample = upsample_block(
@@ -3699,7 +3655,7 @@ class NestedUNet2DConditionModel(MatryoshkaUNet2DConditionModel):
                     cross_attention_kwargs=cross_attention_kwargs,
                     upsample_size=upsample_size,
                     attention_mask=attention_mask,
-                    encoder_attention_mask=cond_mask[:bh] if cond_mask is not None else cond_mask,
+                    encoder_attention_mask=cond_mask[:bh] if cond_mask is not None else cond_mask,
                 )
             else:
                 sample = upsample_block(
@@ -3863,6 +3819,8 @@ class MatryoshkaPipeline(
 
         if hasattr(unet, "nest_ratio"):
            scheduler.scales = unet.nest_ratio + [1]
+            if nesting_level == 2:
+                scheduler.schedule_shifted_power = 2.0
 
         self.register_modules(
             text_encoder=text_encoder,
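Downstream of this last hunk, nesting_level selects both the per-level scales and the stronger schedule shift. A small wiring sketch with assumed values (the real nest_ratio comes from the loaded UNet checkpoint, so the numbers below are illustrative only):

# Illustrative only: nest_ratio and nesting_level normally come from the checkpoint
# and the MatryoshkaPipeline constructor, not from this snippet.
class _DummyUNet:
    nest_ratio = [4, 2]  # assumed inner-to-outer resolution ratios

unet, nesting_level = _DummyUNet(), 2

scales = unet.nest_ratio + [1]                              # -> [4, 2, 1]
schedule_shifted_power = 2.0 if nesting_level == 2 else 1.0

# Effective SNR divisor per resolution level, as used by get_schedule_shifted above:
print([s**schedule_shifted_power for s in scales])          # [16.0, 4.0, 1.0]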