tolgacangoz committed (verified) · Commit 8a785c1 · 1 Parent(s): 9f82cb6

Upload matryoshka.py

Files changed (1):
  1. matryoshka.py  +22 -64
matryoshka.py CHANGED
@@ -420,6 +420,7 @@ class MatryoshkaDDIMScheduler(SchedulerMixin, ConfigMixin):
         self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))
 
         self.scales = None
+        self.schedule_shifted_power = 1.0
 
     def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
         """
@@ -532,6 +533,8 @@ class MatryoshkaDDIMScheduler(SchedulerMixin, ConfigMixin):
 
     def get_schedule_shifted(self, alpha_prod, scale_factor=None):
         if (scale_factor is not None) and (scale_factor > 1):  # rescale noise schedule
+            p = self._config.schedule_shifted_power
+            scale_factor = scale_factor ** p
             snr = alpha_prod / (1 - alpha_prod)
             scaled_snr = snr / scale_factor
             alpha_prod = 1 / (1 + 1 / scaled_snr)
@@ -1440,7 +1443,7 @@ class MatryoshkaTransformerBlock(nn.Module):
             bias=True,
             upcast_attention=upcast_attention,
             pre_only=True,
-            processor=MatryoshkaFusedAttnProcessor1_0_or_2_0(),
+            processor=MatryoshkaFusedAttnProcessor2_0(),
         )
         self.attn1.fuse_projections()
         del self.attn1.to_q
@@ -1458,7 +1461,7 @@ class MatryoshkaTransformerBlock(nn.Module):
             bias=True,
             upcast_attention=upcast_attention,
             pre_only=True,
-            processor=MatryoshkaFusedAttnProcessor1_0_or_2_0(),
+            processor=MatryoshkaFusedAttnProcessor2_0(),
         )
         self.attn2.fuse_projections()
         del self.attn2.to_q
@@ -1517,7 +1520,6 @@ class MatryoshkaTransformerBlock(nn.Module):
             # **cross_attention_kwargs,
         )
 
-        # attn_output_cond = attn_output_cond.permute(0, 2, 1).contiguous()
         attn_output_cond = self.proj_out(attn_output_cond)
         attn_output_cond = attn_output_cond.permute(0, 2, 1).reshape(batch_size, channels, *spatial_dims)
         hidden_states = hidden_states + attn_output_cond
@@ -1535,7 +1537,7 @@ class MatryoshkaTransformerBlock(nn.Module):
         return hidden_states
 
 
-class MatryoshkaFusedAttnProcessor1_0_or_2_0:
+class MatryoshkaFusedAttnProcessor2_0:
     r"""
     Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). It uses
     fused projection layers. For self-attention modules, all projection matrices (i.e., query, key, value) are fused.
@@ -1548,28 +1550,11 @@ class MatryoshkaFusedAttnProcessor1_0_or_2_0:
     </Tip>
     """
 
-    # def __init__(self):
-    #     if not hasattr(F, "scaled_dot_product_attention"):
-    #         raise ImportError(
-    #             "MatryoshkaFusedAttnProcessor2_0 requires PyTorch 2.x, to use it. Please upgrade PyTorch to > 2.x."
-    #         )
-
-    # TODO: They seem to give different results; but nevertheless can I replace this with torch.nn.functional.scaled_dot_product_attention()?
-    def attention(self, q, k, v, num_heads, mask=None):
-        bs, width, length = q.shape
-        ch = width // num_heads
-        scale = 1 / torch.sqrt(torch.sqrt(torch.tensor(ch)))
-        weight = torch.einsum(
-            "bct,bcs->bts",
-            (q * scale).reshape(bs * num_heads, ch, length),
-            (k * scale).reshape(bs * num_heads, ch, -1),
-        )  # More stable with f16 than dividing afterwards
-        if mask is not None:
-            mask = mask.view(mask.size(0), 1, 1, mask.size(-1)).repeat(1, num_heads, 1, 1).flatten(0, 1)
-            weight = weight.masked_fill(mask == 0, float("-inf"))
-        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
-        a = torch.einsum("bts,bcs->bct", weight, v.reshape(bs * num_heads, ch, -1))
-        return a.reshape(bs, -1, length)
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError(
+                "MatryoshkaFusedAttnProcessor2_0 requires PyTorch 2.x, to use it. Please upgrade PyTorch to > 2.x."
+            )
 
     def __call__(
         self,
@@ -1593,26 +1578,12 @@ class MatryoshkaFusedAttnProcessor1_0_or_2_0:
 
         input_ndim = hidden_states.ndim
 
-        if input_ndim == 4:
-            batch_size, channel, height, width = hidden_states.shape
-            # hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
-
-        # batch_size, sequence_length, _ = (
-        #     hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-        # )
-
-        # if attention_mask is not None:
-        #     attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
-        #     # scaled_dot_product_attention expects attention_mask shape to be
-        #     # (batch, heads, source_length, target_length)
-        #     attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
-
         if attn.group_norm is not None:
-            hidden_states = attn.group_norm(hidden_states)  # .transpose(1, 2)).transpose(1, 2)
+            hidden_states = attn.group_norm(hidden_states)
 
-        # Reshape hidden_states to 2D tensor
-        hidden_states = hidden_states.view(batch_size, channel, height * width).permute(0, 2, 1).contiguous()
-        # Now hidden_states.shape is [batch_size, height * width, channels]
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2).contiguous()
 
         if encoder_hidden_states is None:
             qkv = attn.to_qkv(hidden_states)
@@ -1630,11 +1601,6 @@ class MatryoshkaFusedAttnProcessor1_0_or_2_0:
             split_size = kv.shape[-1] // 2
             key, value = torch.split(kv, split_size, dim=-1)
 
-        # if self_attention_output is None:
-        #     query = query.permute(0, 2, 1)
-        #     key = key.permute(0, 2, 1)
-        #     value = value.permute(0, 2, 1)
-
         if attn.norm_q is not None:
             query = attn.norm_q(query)
         if attn.norm_k is not None:
@@ -1659,16 +1625,6 @@ class MatryoshkaFusedAttnProcessor1_0_or_2_0:
             query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )
 
-        # the output of sdp = (batch, num_heads, seq_len, head_dim)
-        # TODO: add support for attn.scale when we move to Torch 2.1 if F.scaled_dot_product_attention() is available
-        # hidden_states = self.attention(
-        #     query,
-        #     key,
-        #     value,
-        #     mask=attention_mask,
-        #     num_heads=attn.heads,
-        # )
-
         hidden_states = hidden_states.to(query.dtype)
 
         if self_attention_output is not None:
@@ -1956,7 +1912,7 @@ class MatryoshkaCombinedTimestepTextEmbedding(nn.Module):
         # if self.cond_emb is not None and not added_cond_kwargs.get("from_nested", False):
             return temb_micro_conditioning, conditioning_mask, cond_emb
 
-        return cond_emb, conditioning_mask, cond_emb
+        return None, conditioning_mask, cond_emb
 
 
 @dataclass
@@ -3184,7 +3140,7 @@ class MatryoshkaUNet2DConditionModel(
                     encoder_hidden_states=encoder_hidden_states,
                     attention_mask=attention_mask,
                     cross_attention_kwargs=cross_attention_kwargs,
-                    encoder_attention_mask=encoder_attention_mask,  # cond_mask?
+                    encoder_attention_mask=encoder_attention_mask,
                     **additional_residuals,
                 )
             else:
@@ -3214,7 +3170,7 @@ class MatryoshkaUNet2DConditionModel(
                 encoder_hidden_states=encoder_hidden_states,
                 attention_mask=attention_mask,
                 cross_attention_kwargs=cross_attention_kwargs,
-                encoder_attention_mask=encoder_attention_mask,  # cond_mask?
+                encoder_attention_mask=encoder_attention_mask,
             )
         else:
             sample = self.mid_block(sample, emb)
@@ -3251,7 +3207,7 @@ class MatryoshkaUNet2DConditionModel(
                     cross_attention_kwargs=cross_attention_kwargs,
                     upsample_size=upsample_size,
                     attention_mask=attention_mask,
-                    encoder_attention_mask=encoder_attention_mask,  # cond_mask?
+                    encoder_attention_mask=encoder_attention_mask,
                )
            else:
                sample = upsample_block(
@@ -3699,7 +3655,7 @@ class NestedUNet2DConditionModel(MatryoshkaUNet2DConditionModel):
                    cross_attention_kwargs=cross_attention_kwargs,
                    upsample_size=upsample_size,
                    attention_mask=attention_mask,
-                    encoder_attention_mask=cond_mask[:bh] if cond_mask is not None else cond_mask,  # cond_mask?
+                    encoder_attention_mask=cond_mask[:bh] if cond_mask is not None else cond_mask,
                )
            else:
                sample = upsample_block(
@@ -3863,6 +3819,8 @@ class MatryoshkaPipeline(
 
         if hasattr(unet, "nest_ratio"):
             scheduler.scales = unet.nest_ratio + [1]
+            if nesting_level == 2:
+                scheduler.schedule_shifted_power = 2.0
 
         self.register_modules(
             text_encoder=text_encoder,
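
The substantive change in this commit is the new schedule_shifted_power knob: MatryoshkaDDIMScheduler now initializes it to 1.0, get_schedule_shifted raises the nesting scale factor to that power before rescaling the SNR, and MatryoshkaPipeline bumps it to 2.0 when nesting_level == 2. Below is a minimal, self-contained sketch of that shifted-schedule math (not part of the commit; shift_alpha_prod is a hypothetical helper that mirrors get_schedule_shifted on a single scalar):

def shift_alpha_prod(alpha_prod: float, scale_factor: float, schedule_shifted_power: float = 1.0) -> float:
    # Mirrors MatryoshkaDDIMScheduler.get_schedule_shifted: rescale the SNR by the
    # (optionally exponentiated) nesting scale factor, then map back to alpha_prod.
    if scale_factor is not None and scale_factor > 1:
        scale_factor = scale_factor**schedule_shifted_power
        snr = alpha_prod / (1 - alpha_prod)
        scaled_snr = snr / scale_factor
        alpha_prod = 1 / (1 + 1 / scaled_snr)
    return alpha_prod


print(shift_alpha_prod(0.9, scale_factor=4))                              # power 1.0 (old behavior): 9/13 ~= 0.692
print(shift_alpha_prod(0.9, scale_factor=4, schedule_shifted_power=2.0))  # power 2.0 (nesting_level == 2): 9/25 = 0.36

A larger power pushes the shifted schedule toward lower alpha_prod (noisier) values at the same timestep, so the effect of the nesting scale factor is amplified for the deeper-nested UNet.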
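The commit also removes the unused hand-rolled einsum attention helper, along with the TODO questioning whether torch.nn.functional.scaled_dot_product_attention could replace it, and renames the processor to MatryoshkaFusedAttnProcessor2_0 with a PyTorch 2.x guard in __init__. A small standalone sketch (not part of the commit; arbitrary shapes, no attention mask) showing that the two computations agree once the channel-first layout of the removed helper is transposed into the (batch, heads, seq_len, head_dim) layout SDPA expects:

import torch
import torch.nn.functional as F

# Arbitrary example shapes: batch 2, 4 heads of 8 channels each, sequence length 16.
bs, num_heads, ch, length = 2, 4, 8, 16
q = torch.randn(bs, num_heads * ch, length)
k = torch.randn(bs, num_heads * ch, length)
v = torch.randn(bs, num_heads * ch, length)

# Removed einsum path: q and k are each scaled by ch**-0.25 (1/sqrt(ch) overall).
scale = 1 / torch.sqrt(torch.sqrt(torch.tensor(float(ch))))
weight = torch.einsum(
    "bct,bcs->bts",
    (q * scale).reshape(bs * num_heads, ch, length),
    (k * scale).reshape(bs * num_heads, ch, -1),
)
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
manual = torch.einsum("bts,bcs->bct", weight, v.reshape(bs * num_heads, ch, -1)).reshape(bs, -1, length)

# SDPA path: expects (batch, heads, seq_len, head_dim), so move channels last.
sdpa = F.scaled_dot_product_attention(
    q.reshape(bs, num_heads, ch, length).transpose(-1, -2),
    k.reshape(bs, num_heads, ch, length).transpose(-1, -2),
    v.reshape(bs, num_heads, ch, length).transpose(-1, -2),
)
sdpa = sdpa.transpose(-1, -2).reshape(bs, -1, length)

print(torch.allclose(manual, sdpa, atol=1e-5))  # True: the two paths agree without a mask

The masked case differs only in convention: the removed helper filled mask == 0 positions with -inf before the softmax, which corresponds to passing a boolean attn_mask (True = attend) to scaled_dot_product_attention.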