Spaces:
Running
on
Zero
Running
on
Zero
from typing import Any, Dict, Optional | |
import torch | |
import torch.nn.functional as F | |
from torch import nn | |
from einops import rearrange, repeat | |
import random | |
def gaussian_kernel(kernel_size=3, sigma=1.0, channels=3):
    """Build a normalized 2-D Gaussian kernel replicated once per channel.

    Args:
        kernel_size: Side length of the square kernel.
        sigma: Standard deviation of the Gaussian, in kernel cells.
        channels: Number of identical copies to stack along dim 0.

    Returns:
        Tensor of shape ``(channels, 1, kernel_size, kernel_size)``; each
        ``kernel_size x kernel_size`` slice sums to 1.
    """
    # Signed distance of every tap from the kernel centre.
    center = (kernel_size - 1) / 2
    offsets = torch.arange(kernel_size) - center

    # 1-D Gaussian profile, normalized so that it sums to 1.
    profile = torch.exp(-(offsets ** 2) / (2 * sigma ** 2))
    profile = profile / profile.sum()

    # The 2-D kernel is separable: outer product of the profile with itself.
    kernel_2d = torch.outer(profile, profile)

    # One copy per channel, in the weight layout used by a grouped conv2d.
    stacked = kernel_2d.reshape(1, 1, kernel_size, kernel_size)
    return stacked.repeat(channels, 1, 1, 1)
def gaussian_filter(latents, kernel_size=3, sigma=1.0):
    """Blur ``latents`` with a per-channel (depthwise) Gaussian kernel.

    Dim 1 of ``latents`` is taken as the channel count, and the convolution
    runs over the last two dims with ``kernel_size // 2`` padding.

    Args:
        latents: Input tensor in the layout ``F.conv2d`` expects.
        kernel_size: Side length of the Gaussian kernel.
        sigma: Standard deviation of the Gaussian.

    Returns:
        The blurred tensor, on the same device and dtype as ``latents``.
    """
    num_channels = latents.shape[1]

    # Build the kernel, then move it next to the input so conv2d accepts it.
    weights = gaussian_kernel(kernel_size, sigma, num_channels)
    weights = weights.to(latents.device, latents.dtype)

    # groups == channels: every channel is filtered independently.
    return F.conv2d(
        latents,
        weights,
        padding=kernel_size // 2,
        groups=num_channels,
    )
def get_views(height, width, h_window_size=64, w_window_size=64, scale_factor=8):
    """Enumerate randomly jittered, half-overlapping windows over a grid.

    Window sizes are divided by ``scale_factor`` before use; the stride is
    half of the unscaled window, divided by the same factor. Every returned
    coordinate is shifted by the jitter range (one eighth of the scaled
    window), so the coordinates index a grid that has been padded by that
    range on each side.

    Args:
        height: Grid height.
        width: Grid width.
        h_window_size: Window height before division by ``scale_factor``.
        w_window_size: Window width before division by ``scale_factor``.
        scale_factor: Divisor applied to window sizes and strides.

    Returns:
        List of ``(h_start, h_end, w_start, w_end)`` tuples in row-major
        window order.
    """

    def _fit(start, size, limit):
        # Pull a window that overruns the far edge back inside, then make
        # sure it never starts before index 0.
        end = start + size
        if end > limit:
            start = int(start + limit - end)
            end = int(limit)
        if start < 0:
            end = int(end - start)
            start = 0
        return start, end

    def _draw_jitter(start, end, limit, span):
        # Windows touching an edge get a one-sided draw; windows touching
        # both edges are not jittered (and consume no random number).
        touches_low = start == 0
        touches_high = end == limit
        if touches_low and touches_high:
            return 0
        if touches_low:
            return random.randint(-span, 0)
        if touches_high:
            return random.randint(0, span)
        return random.randint(-span, span)

    height, width = int(height), int(width)

    # Strides derive from the *unscaled* window size, so compute them first.
    h_stride = int((h_window_size // 2) / scale_factor)
    w_stride = int((w_window_size // 2) / scale_factor)
    h_window = int(h_window_size / scale_factor)
    w_window = int(w_window_size / scale_factor)

    # The 1e-6 keeps an exact fit from producing a redundant extra window.
    rows = int((height - h_window) / h_stride - 1e-6) + 2 if height > h_window else 1
    cols = int((width - w_window) / w_stride - 1e-6) + 2 if width > w_window else 1

    h_span = h_window // 8
    w_span = w_window // 8

    views = []
    for row in range(rows):
        for col in range(cols):
            h_start, h_end = _fit(int(row * h_stride), h_window, height)
            w_start, w_end = _fit(int(col * w_stride), w_window, width)

            # Width is drawn before height so the random sequence is stable.
            w_shift = _draw_jitter(w_start, w_end, width, w_span) + w_span
            h_shift = _draw_jitter(h_start, h_end, height, h_span) + h_span

            views.append(
                (h_start + h_shift, h_end + h_shift, w_start + w_shift, w_end + w_shift)
            )
    return views
def _windowed_attention(attend, grid_states, views, window_h, pad_h, pad_w):
    """Average ``attend`` outputs over overlapping windows of a padded grid.

    ``grid_states`` is a ``(b, h, w, d)`` tensor. It is zero-padded by
    ``pad_h`` / ``pad_w`` on each side, ``attend`` is run on every window in
    ``views`` (coordinates in the padded frame), and overlapping results are
    averaged. The padding is cropped off before returning.
    """
    _, grid_h, grid_w, _ = grid_states.shape
    padded = F.pad(grid_states, (0, 0, pad_w, pad_w, pad_h, pad_h), 'constant', 0)

    total = torch.zeros_like(padded)
    # Coverage is identical for every batch item and feature, so a
    # broadcastable (1, H, W, 1) counter is enough.
    hits = torch.zeros_like(padded[:1, :, :, :1])

    for h_start, h_end, w_start, w_end in views:
        window = padded[:, h_start:h_end, w_start:w_end, :]
        window = rearrange(window, 'bh h w d -> bh (h w) d')
        window_out = attend(window)
        window_out = rearrange(window_out, 'bh (h w) d -> bh h w d', h=window_h)
        total[:, h_start:h_end, w_start:w_end, :] += window_out
        hits[:, h_start:h_end, w_start:w_end, :] += 1

    # Crop with explicit end indices: a negative end index (``pad:-pad``)
    # would select nothing whenever the pad is 0.
    total = total[:, pad_h:pad_h + grid_h, pad_w:pad_w + grid_w, :]
    hits = hits[:, pad_h:pad_h + grid_h, pad_w:pad_w + grid_w, :]
    return torch.where(hits > 0, total / hits, total)


def _strided_attention(attend, grid_states, stride_h, stride_w):
    """Run ``attend`` on each strided sub-grid and scatter the results back.

    The sub-grids ``[off_h::stride_h, off_w::stride_w]`` for all offsets
    cover each cell of ``grid_states`` exactly once, so results are written
    directly without averaging.
    """
    merged = torch.zeros_like(grid_states)
    for off_h in range(stride_h):
        for off_w in range(stride_w):
            subgrid = grid_states[:, off_h::stride_h, off_w::stride_w, :]
            # Remember the true sub-grid height instead of assuming a square.
            sub_rows = subgrid.shape[1]
            sub_out = attend(rearrange(subgrid, 'bh h w d -> bh (h w) d'))
            merged[:, off_h::stride_h, off_w::stride_w, :] = rearrange(
                sub_out, 'bh (h w) d -> bh h w d', h=sub_rows
            )
    return merged


def scale_forward(
    self,
    hidden_states: torch.FloatTensor,
    attention_mask: Optional[torch.FloatTensor] = None,
    encoder_hidden_states: Optional[torch.FloatTensor] = None,
    encoder_attention_mask: Optional[torch.FloatTensor] = None,
    timestep: Optional[torch.LongTensor] = None,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    class_labels: Optional[torch.LongTensor] = None,
):
    """Transformer-block forward pass with blended local/global self-attention.

    Self-attention is computed twice on the token grid: once over jittered,
    overlapping windows (local) and once globally. The result is the
    Gaussian-blurred local output plus the global output minus its own
    Gaussian blur. GLIGEN fusing, cross-attention and the feed-forward stage
    follow, each added residually.

    Reads from ``self``: ``current_hw``, ``fast_mode``, ``norm1``/``norm2``/
    ``norm3``, ``attn1``/``attn2``, ``ff``, ``fuser``, ``only_cross_attention``,
    ``use_ada_layer_norm``, ``use_ada_layer_norm_zero``, ``_chunk_size`` and
    ``_chunk_dim``.

    Args:
        hidden_states: 3-D tensor whose middle axis is the flattened
            ``h * w`` token grid.
        attention_mask: Forwarded to every ``attn1`` call.
        encoder_hidden_states: Conditioning for ``attn2`` (and for ``attn1``
            when ``self.only_cross_attention`` is set).
        encoder_attention_mask: Forwarded to ``attn2``.
        timestep: Passed to the adaptive norms when they are enabled.
        cross_attention_kwargs: Extra attention kwargs; a ``"gligen"`` entry
            is popped from a copy and routed to ``self.fuser``.
        class_labels: Passed to ``norm1`` when ``use_ada_layer_norm_zero``.

    Returns:
        Tensor with the same shape as ``hidden_states``.

    Raises:
        ValueError: If feed-forward chunking is enabled and the chunked
            dimension is not divisible by the chunk size.
    """
    # Notice that normalization is always applied before the real computation in the following blocks.
    # NOTE(review): a side of current_hw below 512 gives a scale of 0 and a
    # ZeroDivisionError in the geometry below - confirm callers never do this.
    if self.current_hw:
        current_scale_num_h, current_scale_num_w = self.current_hw[0] // 512, self.current_hw[1] // 512
    else:
        current_scale_num_h, current_scale_num_w = 1, 1

    # 0. Self-Attention
    if self.use_ada_layer_norm:
        norm_hidden_states = self.norm1(hidden_states, timestep)
    elif self.use_ada_layer_norm_zero:
        norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
            hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
        )
    else:
        norm_hidden_states = self.norm1(hidden_states)

    # 2. Prepare GLIGEN inputs
    cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
    gligen_kwargs = cross_attention_kwargs.pop("gligen", None)

    # Recover the 2-D token grid from the flattened sequence length and the
    # height/width ratio of the current resolution.
    ratio_hw = current_scale_num_h / current_scale_num_w
    latent_h = int((norm_hidden_states.shape[1] * ratio_hw) ** 0.5)
    latent_w = int(latent_h / ratio_hw)
    scale_factor = 64 * current_scale_num_h / latent_h
    if ratio_hw > 1:
        sub_h = 64
        sub_w = int(64 / ratio_hw)
    else:
        sub_h = int(64 * ratio_hw)
        sub_w = 64

    # Padding added around the grid; get_views offsets its coordinates by
    # the same one-eighth-of-a-window amount.
    h_jitter_range = int(sub_h / scale_factor // 8)
    w_jitter_range = int(sub_w / scale_factor // 8)
    views = get_views(latent_h, latent_w, sub_h, sub_w, scale_factor=scale_factor)

    current_scale_num = max(current_scale_num_h, current_scale_num_w)
    blur_kernel_size = 2 * current_scale_num - 1

    def attend(states):
        # Shared self-attention call used by every local and global pass.
        return self.attn1(
            states,
            encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
            attention_mask=attention_mask,
            **cross_attention_kwargs,
        )

    grid_states = rearrange(norm_hidden_states, 'bh (h w) d -> bh h w d', h=latent_h)

    # Local pass: attention inside jittered windows, averaged on overlaps.
    # NOTE(review): gaussian_filter treats dim 1 as channels, which for this
    # 'bh h w d' layout is the grid height - confirm that is intended.
    local_output = _windowed_attention(
        attend,
        grid_states,
        views,
        int(sub_h / scale_factor),
        h_jitter_range,
        w_jitter_range,
    )
    gaussian_local = gaussian_filter(local_output, kernel_size=blur_kernel_size, sigma=1.0)

    # Global pass: strided sub-grids in fast mode, the full sequence otherwise.
    if self.fast_mode:
        attn_output_global = _strided_attention(
            attend, grid_states, current_scale_num_h, current_scale_num_w
        )
    else:
        attn_output_global = attend(norm_hidden_states)
        attn_output_global = rearrange(attn_output_global, 'bh (h w) d -> bh h w d', h=latent_h)
    gaussian_global = gaussian_filter(attn_output_global, kernel_size=blur_kernel_size, sigma=1.0)

    # Blurred local result plus what the blur removes from the global result.
    attn_output = gaussian_local + (attn_output_global - gaussian_global)
    attn_output = rearrange(attn_output, 'bh h w d -> bh (h w) d')

    if self.use_ada_layer_norm_zero:
        attn_output = gate_msa.unsqueeze(1) * attn_output
    hidden_states = attn_output + hidden_states

    # 2.5 GLIGEN Control
    if gligen_kwargs is not None:
        hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
    # 2.5 ends

    # 3. Cross-Attention
    if self.attn2 is not None:
        norm_hidden_states = (
            self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
        )
        attn_output = self.attn2(
            norm_hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=encoder_attention_mask,
            **cross_attention_kwargs,
        )
        hidden_states = attn_output + hidden_states

    # 4. Feed-forward
    norm_hidden_states = self.norm3(hidden_states)
    if self.use_ada_layer_norm_zero:
        norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]

    if self._chunk_size is not None:
        # "feed_forward_chunk_size" can be used to save memory
        if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
            raise ValueError(
                f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
            )
        num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
        ff_output = torch.cat(
            [
                self.ff(hid_slice)
                for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)
            ],
            dim=self._chunk_dim,
        )
    else:
        ff_output = self.ff(norm_hidden_states)

    if self.use_ada_layer_norm_zero:
        ff_output = gate_mlp.unsqueeze(1) * ff_output
    hidden_states = ff_output + hidden_states
    return hidden_states
def ori_forward(
    self,
    hidden_states: torch.FloatTensor,
    attention_mask: Optional[torch.FloatTensor] = None,
    encoder_hidden_states: Optional[torch.FloatTensor] = None,
    encoder_attention_mask: Optional[torch.FloatTensor] = None,
    timestep: Optional[torch.LongTensor] = None,
    cross_attention_kwargs: Dict[str, Any] = None,
    class_labels: Optional[torch.LongTensor] = None,
):
    """Plain transformer-block forward pass.

    Applies, in order and each as a residual update: self-attention
    (``attn1``), the optional GLIGEN fuser, optional cross-attention
    (``attn2``) and the feed-forward network (``ff``). Every stage is
    preceded by its normalization layer.

    Args:
        hidden_states: Input activations.
        attention_mask: Forwarded to ``attn1``.
        encoder_hidden_states: Conditioning for ``attn2`` (and for ``attn1``
            when ``self.only_cross_attention`` is set).
        encoder_attention_mask: Forwarded to ``attn2``.
        timestep: Passed to the adaptive norms when they are enabled.
        cross_attention_kwargs: Extra attention kwargs; a ``"gligen"`` entry
            is popped from a copy and routed to ``self.fuser``.
        class_labels: Passed to ``norm1`` when ``use_ada_layer_norm_zero``.

    Returns:
        The updated hidden states.

    Raises:
        ValueError: If feed-forward chunking is enabled and the chunked
            dimension is not divisible by the chunk size.
    """
    # Normalization always precedes the real computation of each stage.

    # --- Self-attention -------------------------------------------------
    if self.use_ada_layer_norm:
        normed = self.norm1(hidden_states, timestep)
    elif self.use_ada_layer_norm_zero:
        # The zero-init adaptive norm also yields the gates/shifts used below.
        normed, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
            hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
        )
    else:
        normed = self.norm1(hidden_states)

    # Work on a copy so the caller's dict keeps its "gligen" entry.
    attn_kwargs = {} if cross_attention_kwargs is None else cross_attention_kwargs.copy()
    gligen_kwargs = attn_kwargs.pop("gligen", None)

    self_attn_context = encoder_hidden_states if self.only_cross_attention else None
    self_attn_out = self.attn1(
        normed,
        encoder_hidden_states=self_attn_context,
        attention_mask=attention_mask,
        **attn_kwargs,
    )
    if self.use_ada_layer_norm_zero:
        self_attn_out = gate_msa.unsqueeze(1) * self_attn_out
    hidden_states = self_attn_out + hidden_states

    # --- GLIGEN control ---------------------------------------------------
    if gligen_kwargs is not None:
        hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])

    # --- Cross-attention --------------------------------------------------
    if self.attn2 is not None:
        if self.use_ada_layer_norm:
            normed = self.norm2(hidden_states, timestep)
        else:
            normed = self.norm2(hidden_states)
        cross_attn_out = self.attn2(
            normed,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=encoder_attention_mask,
            **attn_kwargs,
        )
        hidden_states = cross_attn_out + hidden_states

    # --- Feed-forward -----------------------------------------------------
    norm_hidden_states = self.norm3(hidden_states)
    if self.use_ada_layer_norm_zero:
        norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]

    if self._chunk_size is None:
        ff_output = self.ff(norm_hidden_states)
    else:
        # Chunked feed-forward trades speed for lower peak memory.
        if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
            raise ValueError(
                f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
            )
        num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
        pieces = norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)
        processed = []
        for piece in pieces:
            processed.append(self.ff(piece))
        ff_output = torch.cat(processed, dim=self._chunk_dim)

    if self.use_ada_layer_norm_zero:
        ff_output = gate_mlp.unsqueeze(1) * ff_output

    return ff_output + hidden_states