tolgacangoz commited on
Commit
889cc98
·
verified ·
1 Parent(s): 2b2d901

Upload matryoshka.py

Browse files
Files changed (1) hide show
  1. matryoshka.py +2 -2
matryoshka.py CHANGED
@@ -1612,7 +1612,7 @@ class MatryoshkaFusedAttnProcessor1_0_or_2_0:
1612
  hidden_states = attn.group_norm(hidden_states) # .transpose(1, 2)).transpose(1, 2)
1613
 
1614
  # Reshape hidden_states to 2D tensor
1615
- hidden_states = hidden_states.view(batch_size, channel, height * width).permute(0, 2, 1).contiguous()
1616
  # Now hidden_states.shape is [batch_size, height * width, channels]
1617
 
1618
  if encoder_hidden_states is None:
@@ -1664,8 +1664,8 @@ class MatryoshkaFusedAttnProcessor1_0_or_2_0:
1664
  dropout_p=attn.dropout,
1665
  )
1666
 
1667
- hidden_states = hidden_states.to(query.dtype)
1668
  hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
 
1669
 
1670
  if self_attention_output is not None:
1671
  hidden_states = hidden_states + self_attention_output
 
1612
  hidden_states = attn.group_norm(hidden_states) # .transpose(1, 2)).transpose(1, 2)
1613
 
1614
  # Reshape hidden_states to 2D tensor
1615
+ hidden_states = hidden_states.view(batch_size, channel, height * width).permute(0, 2, 1)#.contiguous()
1616
  # Now hidden_states.shape is [batch_size, height * width, channels]
1617
 
1618
  if encoder_hidden_states is None:
 
1664
  dropout_p=attn.dropout,
1665
  )
1666
 
 
1667
  hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
1668
+ hidden_states = hidden_states.to(query.dtype)
1669
 
1670
  if self_attention_output is not None:
1671
  hidden_states = hidden_states + self_attention_output