|
from typing import Optional |
|
|
|
import torch |
|
import torch.nn as nn |
|
from diffusers.models import UNet2DConditionModel |
|
from diffusers.models.attention import Attention |
|
from diffusers.models.attention_processor import AttnProcessor2_0 |
|
|
|
|
|
|
|
def add_imagedream_attn_processor(unet: UNet2DConditionModel) -> nn.Module:
    """Install ``ImageDreamAttnProcessor2_0`` on the ``attn1`` layers of *unet*.

    Every entry of ``unet.attn_processors`` whose key contains the substring
    ``"attn1"`` is replaced by a fresh ``ImageDreamAttnProcessor2_0``
    instance; all other entries keep the processor they already have.
    (``attn1`` is presumably the self-attention layer of each transformer
    block -- confirm against the UNet's module naming.)

    Args:
        unet: Model whose attention processors are updated in place.

    Returns:
        The same ``unet`` object, after ``set_attn_processor`` was called.
    """
    replacements = {
        name: ImageDreamAttnProcessor2_0() if "attn1" in name else current
        for name, current in unet.attn_processors.items()
    }
    unet.set_attn_processor(replacements)
    return unet
|
|
|
|
|
class ImageDreamAttnProcessor2_0(AttnProcessor2_0):
    """``AttnProcessor2_0`` variant that attends jointly across several views.

    When ``num_views > 1`` the leading batch axis is interpreted as
    ``real_batch * num_views``, with the views of one sample stored next to
    each other (this layout is what the reshapes below imply -- confirm
    against the caller). The views of each sample are merged into a single
    sequence, the parent processor is run on it, and the result is split
    back to the original batch layout.
    """

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        temb: Optional[torch.Tensor] = None,
        num_views: int = 1,
        *args,
        **kwargs,
    ) -> torch.Tensor:
        """Run attention, optionally merging ``num_views`` views per sample.

        Args:
            attn: Attention module forwarded to the parent processor.
            hidden_states: 3D ``(batch, tokens, dim)`` or 4D
                ``(batch, channels, height, width)`` input.
            encoder_hidden_states: Forwarded unchanged to the parent.
            attention_mask: Forwarded unchanged to the parent.
            temb: Forwarded unchanged to the parent.
            num_views: Number of views folded into the batch axis. ``1``
                disables the multi-view behaviour entirely.
            *args: Extra positional arguments forwarded to the parent.
            **kwargs: Extra keyword arguments forwarded to the parent.

        Returns:
            A tensor with the same batch layout as the input
            ``hidden_states``.

        Raises:
            ValueError: If ``num_views`` is smaller than 1, or if the batch
                size is not a multiple of ``num_views``.
        """
        if num_views == 1:
            # The named arguments are forwarded positionally: mixing
            # ``attn=attn`` with ``*args`` would bind the first extra
            # positional argument to ``attn`` a second time (TypeError).
            return super().__call__(
                attn,
                hidden_states,
                encoder_hidden_states,
                attention_mask,
                temb,
                *args,
                **kwargs,
            )

        if num_views < 1:
            raise ValueError(f"`num_views` must be >= 1 (got {num_views}).")

        input_ndim = hidden_states.ndim
        B = hidden_states.size(0)
        if B % num_views:
            raise ValueError(
                f"`batch_size`(got {B}) must be a multiple of `num_views`(got {num_views})."
            )
        real_B = B // num_views

        if input_ndim == 4:
            C, H, W = hidden_states.shape[1:]
            # Stack the views along the height axis so that the parent
            # processor flattens every view's pixels into one sequence:
            # (B*V, C, H, W) -> (B, V, C, H, W) -> (B, C, V, H, W)
            #                -> (B, C, V*H, W)
            hidden_states = (
                hidden_states.reshape(real_B, num_views, C, H, W)
                .permute(0, 2, 1, 3, 4)
                .reshape(real_B, C, num_views * H, W)
            )
        else:
            # (B*V, N, D) -> (B, V*N, D): concatenate the tokens of all views.
            hidden_states = hidden_states.reshape(
                real_B, -1, hidden_states.size(-1)
            )

        hidden_states = super().__call__(
            attn,
            hidden_states,
            encoder_hidden_states,
            attention_mask,
            temb,
            *args,
            **kwargs,
        )

        if input_ndim == 4:
            # Undo the view stacking:
            # (B, C, V*H, W) -> (B, C, V, H, W) -> (B, V, C, H, W)
            #                -> (B*V, C, H, W)
            hidden_states = (
                hidden_states.reshape(real_B, -1, num_views, H, W)
                .permute(0, 2, 1, 3, 4)
                .reshape(B, -1, H, W)
            )
        else:
            # (B, V*N, D) -> (B*V, N, D)
            hidden_states = hidden_states.reshape(
                B, -1, hidden_states.size(-1)
            )
        return hidden_states
|
|