Spaces:
Build error
Build error
import torch | |
import torch.nn.functional as F | |
from diffusers.models.attention_processor import ( | |
Attention, | |
AttnProcessor2_0, | |
SlicedAttnProcessor, | |
XFormersAttnProcessor | |
) | |
# xformers is optional: when it cannot be imported the name is bound to None
# so the rest of this module still loads.
try:
    import xformers.ops
except Exception:  # broad on purpose (best-effort import), but no longer traps KeyboardInterrupt/SystemExit
    xformers = None

# Hypernetworks consulted by apply_hypernetworks(); each entry must expose
# ``forward(context_k, context_v)`` returning a (context_k, context_v) pair.
loaded_networks = []
def apply_single_hypernetwork(
    hypernetwork, hidden_states, encoder_hidden_states
):
    """Run a single hypernetwork and hand back its two outputs.

    Args:
        hypernetwork: Object exposing ``forward(hidden_states,
            encoder_hidden_states)`` that yields exactly two values.
        hidden_states: First argument forwarded to ``hypernetwork.forward``.
        encoder_hidden_states: Second argument forwarded to
            ``hypernetwork.forward``.

    Returns:
        A ``(context_k, context_v)`` tuple holding the two values produced by
        ``hypernetwork.forward``.
    """
    outputs = hypernetwork.forward(hidden_states, encoder_hidden_states)
    # Unpack explicitly so anything other than a two-item result fails here.
    key_context, value_context = outputs
    return key_context, value_context
def apply_hypernetworks(context_k, context_v, layer=None):
    """Pass the attention contexts through every loaded hypernetwork.

    The attention forwards in this module call this with ``hidden_states`` as
    ``context_k`` and ``encoder_hidden_states`` as ``context_v``; the returned
    pair is what they feed to ``attn.to_k`` and ``attn.to_v``.

    Args:
        context_k: Tensor handed to the first hypernetwork as its first
            argument. Its dtype is the dtype both results are cast to.
        context_v: Tensor handed to the first hypernetwork as its second
            argument. Returned for both slots when no network is loaded.
        layer: Unused; kept so existing callers that pass it keep working.

    Returns:
        ``(context_k, context_v)`` after chaining every entry of
        ``loaded_networks``, or ``(context_v, context_v)`` when the list is
        empty.
    """
    if not loaded_networks:
        # Deliberately ``context_v`` twice: with no hypernetwork, keys and
        # values are both projected from the same (encoder) context.
        return context_v, context_v

    # Capture the incoming dtype before the names are rebound. The previous
    # code cast context_k to its *own* dtype (a no-op), so outputs produced in
    # another precision were never brought back to the caller's dtype.
    target_dtype = context_k.dtype

    for hypernetwork in loaded_networks:
        context_k, context_v = hypernetwork.forward(context_k, context_v)

    context_k = context_k.to(dtype=target_dtype)
    context_v = context_v.to(dtype=target_dtype)
    return context_k, context_v
def xformers_forward(
    self: XFormersAttnProcessor,
    attn: Attention,
    hidden_states: torch.Tensor,
    encoder_hidden_states: torch.Tensor = None,
    attention_mask: torch.Tensor = None,
):
    """Hypernetwork-aware attention call using xformers' memory-efficient kernel.

    Installed as ``XFormersAttnProcessor.__call__`` by
    ``replace_attentions_for_hypernetwork``. The key/value inputs are routed
    through ``apply_hypernetworks`` before the ``to_k``/``to_v`` projections.

    Args:
        self: Processor instance; only ``self.attention_op`` is read.
        attn: Attention module supplying the projections and reshape helpers.
        hidden_states: Input to the query projection; unpacked as a 3-D
            tensor when no encoder states are given.
        encoder_hidden_states: Optional cross-attention context. When
            ``None``, ``hidden_states`` is used instead.
        attention_mask: Optional mask, passed through
            ``attn.prepare_attention_mask`` and used as ``attn_bias``.

    Returns:
        The attention output after ``attn.to_out[0]`` and ``attn.to_out[1]``.
    """
    # The mask is sized from whichever tensor provides the keys/values.
    shape_source = (
        hidden_states if encoder_hidden_states is None else encoder_hidden_states
    )
    batch_size, sequence_length, _ = shape_source.shape
    attention_mask = attn.prepare_attention_mask(
        attention_mask, sequence_length, batch_size
    )

    query = attn.to_q(hidden_states)

    if encoder_hidden_states is None:
        encoder_hidden_states = hidden_states
    elif attn.norm_cross:
        encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

    context_k, context_v = apply_hypernetworks(hidden_states, encoder_hidden_states)
    key = attn.to_k(context_k)
    value = attn.to_v(context_v)

    # Fold heads into the batch dimension; .contiguous() as in the original call.
    query, key, value = (
        attn.head_to_batch_dim(tensor).contiguous() for tensor in (query, key, value)
    )

    attended = xformers.ops.memory_efficient_attention(
        query,
        key,
        value,
        attn_bias=attention_mask,
        op=self.attention_op,
        scale=attn.scale,
    )
    attended = attn.batch_to_head_dim(attended.to(query.dtype))

    # to_out[0] is the linear projection, to_out[1] the dropout.
    projected = attn.to_out[0](attended)
    return attn.to_out[1](projected)
def sliced_attn_forward(
    self: SlicedAttnProcessor,
    attn: Attention,
    hidden_states: torch.Tensor,
    encoder_hidden_states: torch.Tensor = None,
    attention_mask: torch.Tensor = None,
):
    """Hypernetwork-aware attention call that computes attention slice by slice.

    Installed as ``SlicedAttnProcessor.__call__`` by
    ``replace_attentions_for_hypernetwork``. The key/value inputs are routed
    through ``apply_hypernetworks`` before the ``to_k``/``to_v`` projections.

    Args:
        self: Processor instance; only ``self.slice_size`` is read.
        attn: Attention module supplying the projections, reshape helpers and
            ``get_attention_scores``.
        hidden_states: Input to the query projection; unpacked as a 3-D
            tensor when no encoder states are given.
        encoder_hidden_states: Optional cross-attention context. When
            ``None``, ``hidden_states`` is used instead.
        attention_mask: Optional mask, passed through
            ``attn.prepare_attention_mask`` and sliced alongside the queries.

    Returns:
        The attention output after ``attn.to_out[0]`` and ``attn.to_out[1]``.
    """
    batch_size, sequence_length, _ = (
        hidden_states.shape
        if encoder_hidden_states is None
        else encoder_hidden_states.shape
    )
    attention_mask = attn.prepare_attention_mask(
        attention_mask, sequence_length, batch_size
    )

    query = attn.to_q(hidden_states)
    dim = query.shape[-1]
    query = attn.head_to_batch_dim(query)

    if encoder_hidden_states is None:
        encoder_hidden_states = hidden_states
    elif attn.norm_cross:
        encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

    context_k, context_v = apply_hypernetworks(hidden_states, encoder_hidden_states)
    key = attn.to_k(context_k)
    value = attn.to_v(context_v)
    key = attn.head_to_batch_dim(key)
    value = attn.head_to_batch_dim(value)

    batch_size_attention, query_tokens, _ = query.shape
    # Output buffer filled one slice at a time below.
    hidden_states = torch.zeros(
        (batch_size_attention, query_tokens, dim // attn.heads),
        device=query.device,
        dtype=query.dtype,
    )

    # Step by slice_size and let slicing clamp the final chunk. The previous
    # ``range(batch_size_attention // slice_size)`` skipped a trailing partial
    # slice (and every row when slice_size > batch), leaving zeros in the output.
    for start_idx in range(0, batch_size_attention, self.slice_size):
        end_idx = start_idx + self.slice_size

        query_slice = query[start_idx:end_idx]
        key_slice = key[start_idx:end_idx]
        attn_mask_slice = (
            attention_mask[start_idx:end_idx] if attention_mask is not None else None
        )

        attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
        attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])

        hidden_states[start_idx:end_idx] = attn_slice

    hidden_states = attn.batch_to_head_dim(hidden_states)

    # linear proj
    hidden_states = attn.to_out[0](hidden_states)
    # dropout
    hidden_states = attn.to_out[1](hidden_states)
    return hidden_states
def v2_0_forward(
    self: AttnProcessor2_0,
    attn: Attention,
    hidden_states,
    encoder_hidden_states=None,
    attention_mask=None,
):
    """Hypernetwork-aware attention call built on ``F.scaled_dot_product_attention``.

    Installed as ``AttnProcessor2_0.__call__`` by
    ``replace_attentions_for_hypernetwork``. The key/value inputs are routed
    through ``apply_hypernetworks`` before the ``to_k``/``to_v`` projections.

    Args:
        self: Processor instance (not read).
        attn: Attention module supplying the projections and ``heads``.
        hidden_states: Input to the query projection; unpacked as a 3-D
            tensor when no encoder states are given.
        encoder_hidden_states: Optional cross-attention context. When
            ``None``, ``hidden_states`` is used instead.
        attention_mask: Optional mask; prepared and reshaped to 4-D before
            being passed as ``attn_mask``.

    Returns:
        The attention output after ``attn.to_out[0]`` and ``attn.to_out[1]``.
    """
    batch_size, sequence_length, _ = (
        hidden_states.shape
        if encoder_hidden_states is None
        else encoder_hidden_states.shape
    )

    if attention_mask is not None:
        attention_mask = attn.prepare_attention_mask(
            attention_mask, sequence_length, batch_size
        )
        # scaled_dot_product_attention expects attention_mask shape to be
        # (batch, heads, source_length, target_length)
        attention_mask = attention_mask.view(
            batch_size, attn.heads, -1, attention_mask.shape[-1]
        )

    query = attn.to_q(hidden_states)

    if encoder_hidden_states is None:
        encoder_hidden_states = hidden_states
    elif attn.norm_cross:
        encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

    context_k, context_v = apply_hypernetworks(hidden_states, encoder_hidden_states)
    key = attn.to_k(context_k)
    value = attn.to_v(context_v)

    # Derive the per-head width from the projected key, not from the raw
    # input: the views below split the *projection* output across heads, so
    # using hidden_states.shape[-1] mis-sizes head_dim whenever the input
    # width differs from the projection width.
    inner_dim = key.shape[-1]
    head_dim = inner_dim // attn.heads

    query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
    key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
    value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

    # the output of sdp = (batch, num_heads, seq_len, head_dim)
    # TODO: add support for attn.scale when we move to Torch 2.1
    hidden_states = F.scaled_dot_product_attention(
        query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
    )

    hidden_states = hidden_states.transpose(1, 2).reshape(
        batch_size, -1, attn.heads * head_dim
    )
    hidden_states = hidden_states.to(query.dtype)

    # linear proj
    hidden_states = attn.to_out[0](hidden_states)
    # dropout
    hidden_states = attn.to_out[1](hidden_states)
    return hidden_states
def replace_attentions_for_hypernetwork():
    """Monkey-patch diffusers' attention processors with the forwards in this module.

    Rebinds ``__call__`` on ``XFormersAttnProcessor``, ``SlicedAttnProcessor``
    and ``AttnProcessor2_0`` in ``diffusers.models.attention_processor``. The
    assignment is made on the classes themselves, not on instances.
    """
    import diffusers.models.attention_processor as processors

    # Processor class name -> replacement ``__call__``, patched in this order.
    replacements = {
        "XFormersAttnProcessor": xformers_forward,
        "SlicedAttnProcessor": sliced_attn_forward,
        "AttnProcessor2_0": v2_0_forward,
    }
    for class_name, forward in replacements.items():
        setattr(getattr(processors, class_name), "__call__", forward)