|
import torch |
|
from einops import rearrange |
|
from diffusers.models.attention import Attention |
|
from .globals import get_enhance_weight, get_num_frames |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_feta_scores(
    attn: Attention,
    query: torch.Tensor,
    key: torch.Tensor,
    head_dim: int,
    text_seq_length: int,
) -> torch.Tensor:
    """Regroup the non-text part of ``query``/``key`` by frame and score it.

    The first ``text_seq_length`` positions along dim 2 are dropped from both
    tensors. The remaining positions are split into ``T`` frames (``T`` comes
    from ``get_num_frames()``) of ``S`` positions each, and ``S`` is folded
    into the leading axis, giving tensors laid out as ``(B S) N T C``. Those
    are passed to ``feta_score``.

    Args:
        attn: Attention module; only ``attn.heads`` is read, as the size of
            axis ``N`` (dim 1).
        query: Tensor laid out as ``B N L C`` where ``L`` covers the text
            positions followed by ``T * S`` further positions.
        key: Tensor with the same layout as ``query``.
        head_dim: Size of the last axis ``C``; also forwarded to
            ``feta_score``.
        text_seq_length: Number of leading positions along dim 2 to skip.

    Returns:
        The value returned by ``feta_score`` for the regrouped tensors.
    """
    frame_count = get_num_frames()
    # Positions per frame once the text prefix is removed (truncating division).
    positions_per_frame = int((query.shape[2] - text_seq_length) / frame_count)

    layout = "B N (T S) C -> (B S) N T C"
    axis_sizes = {
        "N": attn.heads,
        "T": frame_count,
        "S": positions_per_frame,
        "C": head_dim,
    }

    def _group_by_frame(tensor):
        # Strip the text prefix, then move the per-frame position axis into
        # the batch axis so the sequence axis indexes frames only.
        return rearrange(tensor[:, :, text_seq_length:], layout, **axis_sizes)

    return feta_score(
        _group_by_frame(query),
        _group_by_frame(key),
        head_dim,
        frame_count,
    )
|
|
|
def feta_score(query_image, key_image, head_dim, num_frames):
    """Return a scalar score derived from off-diagonal frame-to-frame attention.

    Steps:
      1. Scale ``query_image`` by ``head_dim ** -0.5`` and multiply it with
         ``key_image`` transposed over its last two axes.
      2. Cast the product to float32 and apply softmax over the last axis.
      3. View the result as a stack of ``num_frames x num_frames`` matrices,
         zero each diagonal, and average the remaining
         ``num_frames * num_frames - num_frames`` entries per matrix.
      4. Average over all matrices, multiply by
         ``num_frames + get_enhance_weight()`` and clamp to a minimum of 1.

    Args:
        query_image: Query tensor. Assumed to have trailing axes
            ``(num_frames, head_dim)``, as produced by ``get_feta_scores``
            -- TODO confirm for other callers.
        key_image: Key tensor with the same layout as ``query_image``.
        head_dim: Per-head channel size used for the attention scale.
        num_frames: Side length of each attention matrix.

    Returns:
        A 0-dim float32 tensor that is always >= 1.
    """
    if num_frames == 1:
        # A 1x1 matrix has no off-diagonal entries, so the normaliser below
        # would be zero and the result NaN (which clamp does not repair).
        # Return the neutral lower bound of the clamp instead.
        return torch.ones((), dtype=torch.float32, device=query_image.device)

    scale = head_dim**-0.5
    attn_temp = (query_image * scale) @ key_image.transpose(-2, -1)
    # Softmax is evaluated in float32 regardless of the input dtype.
    attn_temp = attn_temp.to(torch.float32).softmax(dim=-1)

    # Flatten every leading axis: one (num_frames, num_frames) matrix per row.
    attn_temp = attn_temp.reshape(-1, num_frames, num_frames)

    # Zero the diagonal of every matrix; the 2-D mask broadcasts over the
    # leading axis.
    diag_mask = torch.eye(num_frames, device=attn_temp.device).bool()
    attn_wo_diag = attn_temp.masked_fill(diag_mask, 0)

    # Mean over the n*n - n off-diagonal entries of each matrix.
    num_off_diag = num_frames * num_frames - num_frames
    mean_scores = attn_wo_diag.sum(dim=(1, 2)) / num_off_diag

    enhance_scores = mean_scores.mean() * (num_frames + get_enhance_weight())
    # Lower bound of 1 on the returned score.
    return enhance_scores.clamp(min=1)
|
|