tolgacangoz
committed on
Upload matryoshka.py
Browse files- matryoshka.py +2 -2
matryoshka.py
CHANGED
@@ -1612,7 +1612,7 @@ class MatryoshkaFusedAttnProcessor1_0_or_2_0:
|
|
1612 |
hidden_states = attn.group_norm(hidden_states) # .transpose(1, 2)).transpose(1, 2)
|
1613 |
|
1614 |
# Reshape hidden_states to 2D tensor
|
1615 |
-
hidden_states = hidden_states.view(batch_size, channel, height * width).permute(0, 2, 1)
|
1616 |
# Now hidden_states.shape is [batch_size, height * width, channels]
|
1617 |
|
1618 |
if encoder_hidden_states is None:
|
@@ -1664,8 +1664,8 @@ class MatryoshkaFusedAttnProcessor1_0_or_2_0:
|
|
1664 |
dropout_p=attn.dropout,
|
1665 |
)
|
1666 |
|
1667 |
-
hidden_states = hidden_states.to(query.dtype)
|
1668 |
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
|
|
1669 |
|
1670 |
if self_attention_output is not None:
|
1671 |
hidden_states = hidden_states + self_attention_output
|
|
|
1612 |
hidden_states = attn.group_norm(hidden_states) # .transpose(1, 2)).transpose(1, 2)
|
1613 |
|
1614 |
# Reshape hidden_states to 2D tensor
|
1615 |
+
hidden_states = hidden_states.view(batch_size, channel, height * width).permute(0, 2, 1)#.contiguous()
|
1616 |
# Now hidden_states.shape is [batch_size, height * width, channels]
|
1617 |
|
1618 |
if encoder_hidden_states is None:
|
|
|
1664 |
dropout_p=attn.dropout,
|
1665 |
)
|
1666 |
|
|
|
1667 |
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
1668 |
+
hidden_states = hidden_states.to(query.dtype)
|
1669 |
|
1670 |
if self_attention_output is not None:
|
1671 |
hidden_states = hidden_states + self_attention_output
|