imagedream-ipmv-diffusers / attention_processor.py
kiigii's picture
Upload folder using huggingface_hub
aaa261a verified
raw
history blame
2.34 kB
from typing import Optional
import torch
import torch.nn as nn
from diffusers.models import UNet2DConditionModel
from diffusers.models.attention import Attention
from diffusers.models.attention_processor import AttnProcessor2_0
def add_imagedream_attn_processor(unet: UNet2DConditionModel) -> nn.Module:
    """Install ``ImageDreamAttnProcessor2_0`` on the ``attn1`` layers of *unet*.

    Every entry of ``unet.attn_processors`` whose key contains the substring
    ``"attn1"`` is replaced by a fresh ``ImageDreamAttnProcessor2_0``
    instance; all other entries keep their current processor. The new mapping
    is applied through ``unet.set_attn_processor``.

    Args:
        unet: Model whose attention processors are rewired. It is modified
            in place.

    Returns:
        The same ``unet`` object that was passed in.
    """
    # NOTE(review): "attn1" is presumably the self-attention layer name used
    # by the transformer blocks -- confirm against the diffusers version in use.
    replacements = {
        name: ImageDreamAttnProcessor2_0() if "attn1" in name else current
        for name, current in unet.attn_processors.items()
    }
    unet.set_attn_processor(replacements)
    return unet
class ImageDreamAttnProcessor2_0(AttnProcessor2_0):
    """Attention processor that attends jointly over all views of a sample.

    The incoming batch is expected to hold ``num_views`` consecutive entries
    per sample. Before delegating to ``AttnProcessor2_0`` the views of each
    sample are concatenated along the token axis, so attention runs over the
    tokens of every view at once; afterwards the original per-view batch
    layout is restored.
    """

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        temb: Optional[torch.Tensor] = None,
        num_views: int = 1,
        *args,
        **kwargs,
    ):
        """Run attention with the views of each sample merged into one sequence.

        Args:
            attn: Attention module, forwarded unchanged to the parent processor.
            hidden_states: Input of shape ``(B, N, C)`` or ``(B, C, H, W)``
                where ``B`` is a multiple of ``num_views``.
            encoder_hidden_states: Forwarded unchanged to the parent processor.
            attention_mask: Forwarded unchanged to the parent processor.
            temb: Forwarded unchanged to the parent processor.
            num_views: Number of consecutive batch entries that belong to the
                same sample. ``1`` makes this processor a plain pass-through.
            *args: Extra positional arguments for the parent processor.
            **kwargs: Extra keyword arguments for the parent processor.

        Returns:
            Whatever the parent processor returns, reshaped back to the
            batch size and rank of the incoming ``hidden_states``.

        Raises:
            ValueError: If ``num_views`` is smaller than 1, or the batch size
                is not a multiple of ``num_views``.
        """
        if num_views < 1:
            raise ValueError(
                f"`num_views` must be a positive integer (got {num_views})."
            )

        if num_views == 1:
            # The named arguments are forwarded positionally on purpose:
            # combining ``attn=...`` keywords with ``*args`` would make any
            # extra positional value collide with ``attn`` and raise TypeError.
            return super().__call__(
                attn,
                hidden_states,
                encoder_hidden_states,
                attention_mask,
                temb,
                *args,
                **kwargs,
            )

        input_ndim = hidden_states.ndim
        B = hidden_states.size(0)
        if B % num_views:
            raise ValueError(
                f"`batch_size`(got {B}) must be a multiple of `num_views`(got {num_views})."
            )
        real_B = B // num_views

        if input_ndim == 4:
            H, W = hidden_states.shape[2:]
            # (B, C, H, W) -> (B, H*W, C): turn spatial positions into tokens
            # so the views can be merged along the token axis below.
            # NOTE(review): the parent processor therefore sees a 3D tensor;
            # confirm no 4D-only preprocessing in the parent is needed for the
            # modules this processor is attached to.
            hidden_states = hidden_states.flatten(2).transpose(1, 2)

        # (B, N, C) -> (real_B, num_views * N, C): the views of one sample are
        # adjacent in the batch, so a plain reshape concatenates their tokens.
        hidden_states = hidden_states.reshape(real_B, -1, hidden_states.size(-1))

        hidden_states = super().__call__(
            attn,
            hidden_states,
            encoder_hidden_states,
            attention_mask,
            temb,
            *args,
            **kwargs,
        )

        # (real_B, num_views * N, C) -> (B, N, C): undo the view merge.
        hidden_states = hidden_states.reshape(B, -1, hidden_states.size(-1))
        if input_ndim == 4:
            # (B, H*W, C) -> (B, C, H, W): restore the original spatial layout.
            hidden_states = hidden_states.transpose(1, 2).reshape(B, -1, H, W)
        return hidden_states