import copy
import numpy as np
import torch
import torch.nn as nn
from src.backbones.positional_encoding import PositionalEncoder
class LTAE2d(nn.Module):
def __init__(
self,
in_channels=128,
n_head=16,
d_k=4,
mlp=[256, 128],
dropout=0.2,
d_model=256,
T=1000,
return_att=False,
positional_encoding=True,
use_dropout=True
):
"""
Lightweight Temporal Attention Encoder (L-TAE) for image time series.
Attention-based sequence encoding that maps a sequence of images to a single feature map.
A shared L-TAE is applied to all pixel positions of the image sequence.
Args:
in_channels (int): Number of channels of the input embeddings.
n_head (int): Number of attention heads.
d_k (int): Dimension of the key and query vectors.
mlp (List[int]): Widths of the layers of the MLP that processes the concatenated outputs of the attention heads.
            dropout (float): Dropout rate applied to the MLP-processed values.
            d_model (int, optional): If specified, the input tensors will first be processed by a fully connected layer
                to project them into a feature space of dimension d_model.
            T (int): Period to use for the positional encoding.
            return_att (bool): If True, the module returns the attention masks along with the embeddings (default False).
            positional_encoding (bool): If False, no positional encoding is used (default True).
            use_dropout (bool): If True, apply dropout to the attention masks (default True).
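        Example (shape sketch only, assuming a batch of 2 sequences of length 4 with 128-channel
        32x32 feature maps and the default arguments above):
            ltae = LTAE2d(in_channels=128, d_model=256, mlp=[256, 128], return_att=True)
            x = torch.rand(2, 4, 128, 32, 32)              # B x T x C x H x W
            dates = torch.arange(4).float().repeat(2, 1)   # B x T acquisition positions
            out, attn = ltae(x, batch_positions=dates)
            # out:  B x mlp[-1] x H x W     -> torch.Size([2, 128, 32, 32])
            # attn: n_head x B x T x H x W  -> torch.Size([16, 2, 4, 32, 32])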
"""
super(LTAE2d, self).__init__()
self.in_channels = in_channels
self.mlp = copy.deepcopy(mlp)
self.return_att = return_att
self.n_head = n_head
if d_model is not None:
self.d_model = d_model
self.inconv = nn.Conv1d(in_channels, d_model, 1)
else:
self.d_model = in_channels
self.inconv = None
assert self.mlp[0] == self.d_model
if positional_encoding:
self.positional_encoder = PositionalEncoder(
self.d_model // n_head, T=T, repeat=n_head
)
else:
self.positional_encoder = None
self.attention_heads = MultiHeadAttention(
n_head=n_head, d_k=d_k, d_in=self.d_model, use_dropout=use_dropout
)
self.in_norm = nn.GroupNorm(
num_groups=n_head,
num_channels=self.in_channels,
)
self.out_norm = nn.GroupNorm(
num_groups=n_head,
num_channels=mlp[-1],
)
layers = []
for i in range(len(self.mlp) - 1):
layers.extend(
[
nn.Linear(self.mlp[i], self.mlp[i + 1]),
nn.BatchNorm1d(self.mlp[i + 1]),
nn.ReLU(),
]
)
self.mlp = nn.Sequential(*layers)
self.dropout = nn.Dropout(dropout)
def forward(self, x, batch_positions=None, pad_mask=None, return_comp=False):
sz_b, seq_len, d, h, w = x.shape
if pad_mask is not None:
pad_mask = (
pad_mask.unsqueeze(-1)
.repeat((1, 1, h))
.unsqueeze(-1)
.repeat((1, 1, 1, w))
) # BxTxHxW
pad_mask = (
pad_mask.permute(0, 2, 3, 1).contiguous().view(sz_b * h * w, seq_len)
)
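        # flatten the spatial dimensions so the shared L-TAE runs on each pixel's temporal sequence:
        # B x T x C x H x W  ->  (B*H*W) x T x C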
out = x.permute(0, 3, 4, 1, 2).contiguous().view(sz_b * h * w, seq_len, d)
out = self.in_norm(out.permute(0, 2, 1)).permute(0, 2, 1)
if self.inconv is not None:
out = self.inconv(out.permute(0, 2, 1)).permute(0, 2, 1)
if self.positional_encoder is not None:
bp = (
batch_positions.unsqueeze(-1)
.repeat((1, 1, h))
.unsqueeze(-1)
.repeat((1, 1, 1, w))
) # BxTxHxW
bp = bp.permute(0, 2, 3, 1).contiguous().view(sz_b * h * w, seq_len)
out = out + self.positional_encoder(bp)
        # the attention heads return attn re-shaped to [h x B*H*W x T], e.g. torch.Size([16, 2048, 4])
        # (in utae.py this becomes torch.Size([h, B, T, 32, 32]))
        # and output re-shaped to [h x B*H*W x d_in/h], e.g. torch.Size([16, 2048, 16])
        # (in utae.py this becomes torch.Size([B, 128, 32, 32]))
out, attn = self.attention_heads(out, pad_mask=pad_mask)
out = (
out.permute(1, 0, 2).contiguous().view(sz_b * h * w, -1)
) # Concatenate heads, out is now [B*H*W x d_in/h * h], e.g. [2048 x 256]
        # shared MLP applied to the concatenated heads, followed by dropout
out = self.dropout(self.mlp(out))
# after MLP, out is of shape [B*H*W x outputLayerOfMLP], e.g. [2048 x 128]
out = self.out_norm(out) if self.out_norm is not None else out
out = out.view(sz_b, h, w, -1).permute(0, 3, 1, 2)
attn = attn.view(self.n_head, sz_b, h, w, seq_len).permute(
0, 1, 4, 2, 3
)
# out is of shape [B x outputLayerOfMLP x h x w], e.g. [2, 128, 32, 32]
# attn is of shape [h x B x T x H x W], e.g. [16, 2, 4, 32, 32]
if self.return_att:
return out, attn
else:
return out
class LTAE2dtiny(nn.Module):
def __init__(
self,
in_channels=128,
n_head=16,
d_k=4,
d_model=256,
T=1000,
positional_encoding=True,
):
"""
Lightweight Temporal Attention Encoder (L-TAE) for image time series.
Attention-based sequence encoding that maps a sequence of images to a single feature map.
A shared L-TAE is applied to all pixel positions of the image sequence.
        This is the tiny version, which does not further process the attention-weighted values v
        (i.e. no MLP) and only returns the attention matrix attn itself.
Args:
in_channels (int): Number of channels of the input embeddings.
n_head (int): Number of attention heads.
d_k (int): Dimension of the key and query vectors.
            d_model (int, optional): If specified, the input tensors will first be processed by a fully connected layer
                to project them into a feature space of dimension d_model.
T (int): Period to use for the positional encoding.
positional_encoding (bool): If False, no positional encoding is used (default True).
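        Example (shape sketch only, assuming a 2x4x128x32x32 input and the default arguments above):
            ltae = LTAE2dtiny(in_channels=128, d_model=256)
            x = torch.rand(2, 4, 128, 32, 32)              # B x T x C x H x W
            dates = torch.arange(4).float().repeat(2, 1)   # B x T acquisition positions
            attn = ltae(x, batch_positions=dates)          # n_head x B x T x H x W -> torch.Size([16, 2, 4, 32, 32])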
"""
super(LTAE2dtiny, self).__init__()
self.in_channels = in_channels
self.n_head = n_head
if d_model is not None:
self.d_model = d_model
self.inconv = nn.Conv1d(in_channels, d_model, 1)
else:
self.d_model = in_channels
self.inconv = None
if positional_encoding:
self.positional_encoder = PositionalEncoder(
self.d_model // n_head, T=T, repeat=n_head
)
else:
self.positional_encoder = None
self.attention_heads = MultiHeadAttentionSmall(
n_head=n_head, d_k=d_k, d_in=self.d_model
)
self.in_norm = nn.GroupNorm(
num_groups=n_head,
num_channels=self.in_channels,
)
def forward(self, x, batch_positions=None, pad_mask=None):
sz_b, seq_len, d, h, w = x.shape
if pad_mask is not None:
pad_mask = (
pad_mask.unsqueeze(-1)
.repeat((1, 1, h))
.unsqueeze(-1)
.repeat((1, 1, 1, w))
) # BxTxHxW
pad_mask = (
pad_mask.permute(0, 2, 3, 1).contiguous().view(sz_b * h * w, seq_len)
)
out = x.permute(0, 3, 4, 1, 2).contiguous().view(sz_b * h * w, seq_len, d)
out = self.in_norm(out.permute(0, 2, 1)).permute(0, 2, 1)
if self.inconv is not None:
out = self.inconv(out.permute(0, 2, 1)).permute(0, 2, 1)
if self.positional_encoder is not None:
bp = (
batch_positions.unsqueeze(-1)
.repeat((1, 1, h))
.unsqueeze(-1)
.repeat((1, 1, 1, w))
) # BxTxHxW
bp = bp.permute(0, 2, 3, 1).contiguous().view(sz_b * h * w, seq_len)
out = out + self.positional_encoder(bp)
        # the attention heads return only attn, re-shaped to [h x B*H*W x T], e.g. torch.Size([16, 2048, 4])
        # (in utae.py this becomes torch.Size([h, B, T, 32, 32]))
attn = self.attention_heads(out, pad_mask=pad_mask)
attn = attn.view(self.n_head, sz_b, h, w, seq_len).permute(
0, 1, 4, 2, 3
)
        # attn is of shape [h x B x T x H x W], e.g. [16, 2, 4, 32, 32]
return attn
# this class still uses ScaledDotProductAttention (including dropout)
# and always computes and returns att*v
class MultiHeadAttention(nn.Module):
"""Multi-Head Attention module
Modified from github.com/jadore801120/attention-is-all-you-need-pytorch
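    Unlike generic self-attention, the queries are a single learned, input-independent parameter
    per head (self.Q), so each head collapses the temporal dimension to one attention-weighted
    value per pixel.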
"""
def __init__(self, n_head, d_k, d_in, use_dropout=True):
super().__init__()
self.n_head = n_head
self.d_k = d_k
self.d_in = d_in # e.g. self.d_model in LTAE2d
# define H x k queries, they are input-independent in LTAE
self.Q = nn.Parameter(torch.zeros((n_head, d_k))).requires_grad_(True)
nn.init.normal_(self.Q, mean=0, std=np.sqrt(2.0 / (d_k)))
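        # keys are obtained from the input values v through a shared linear layer (one d_k-dimensional
        # key per head); the input-independent queries need no projection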
self.fc1_k = nn.Linear(d_in, n_head * d_k)
nn.init.normal_(self.fc1_k.weight, mean=0, std=np.sqrt(2.0 / (d_k)))
        attn_dropout = 0.1 if use_dropout else 0.0
self.attention = ScaledDotProductAttention(temperature=np.power(d_k, 0.5), attn_dropout=attn_dropout)
def forward(self, v, pad_mask=None, return_comp=False):
d_k, d_in, n_head = self.d_k, self.d_in, self.n_head
# values v are of shapes [B*H*W, T, self.d_in=self.d_model], e.g. [2*32*32=2048 x 4 x 256] (see: sz_b * h * w, seq_len, d)
# where self.d_in=self.d_model is the output dimension of the FC-projected features
sz_b, seq_len, _ = v.size()
q = torch.stack([self.Q for _ in range(sz_b)], dim=1).view(-1, d_k) # (n*b) x d_k
k = self.fc1_k(v).view(sz_b, seq_len, n_head, d_k)
k = k.permute(2, 0, 1, 3).contiguous().view(-1, seq_len, d_k) # (n*b) x lk x dk
if pad_mask is not None:
pad_mask = pad_mask.repeat(
(n_head, 1)
) # replicate pad_mask for each head (nxb) x lk
# attn is of shape [B*H*W*h, 1, T], e.g. [2*32*32*16=32768 x 1 x 4], e.g. Size([32768, 1, 4])
# v is of shape [B*H*W*h, T, self.d_in/h], e.g. [2*32*32*16=32768 x 4 x 256/16=16], e.g. Size([32768, 4, 16])
        # output is of shape [B*H*W*h, 1, d_in/h], e.g. [2*32*32*16=32768 x 1 x 256/16=16], e.g. Size([32768, 1, 16])
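        # split the channels into n_head groups of d_in/n_head and fold the heads into the batch dimension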
v = torch.stack(v.split(v.shape[-1] // n_head, dim=-1)).view(n_head * sz_b, seq_len, -1)
if return_comp:
output, attn, comp = self.attention(
q, k, v, pad_mask=pad_mask, return_comp=return_comp
)
else:
output, attn = self.attention(
q, k, v, pad_mask=pad_mask, return_comp=return_comp
)
attn = attn.view(n_head, sz_b, 1, seq_len)
attn = attn.squeeze(dim=2)
output = output.view(n_head, sz_b, 1, d_in // n_head)
output = output.squeeze(dim=2)
# re-shaped attn to [h x B*H*W x T], e.g. torch.Size([16, 2048, 4])
# in utae.py this is torch.Size([h, B, T, 32, 32])
# re-shaped output to [h x B*H*W x d_in/h], e.g. torch.Size([16, 2048, 16])
# in utae.py this is torch.Size([B, 128, 32, 32])
if return_comp:
return output, attn, comp
else:
return output, attn
# this class uses ScaledDotProductAttentionSmall (excluding dropout)
# and only optionally computes and returns att*v
class MultiHeadAttentionSmall(nn.Module):
"""Multi-Head Attention module
Modified from github.com/jadore801120/attention-is-all-you-need-pytorch
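    Variant used by LTAE2dtiny: attention dropout is omitted, and the attention-weighted values
    are only computed when forward() is called with weight_v=True; otherwise only the attention
    masks are returned.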
"""
def __init__(self, n_head, d_k, d_in):
super().__init__()
self.n_head = n_head # e.g. 16
self.d_k = d_k # e.g. 4, number of keys per head
self.d_in = d_in # e.g. 256, self.d_model in LTAE2d
# define H x k queries, they are input-independent in LTAE
self.Q = nn.Parameter(torch.zeros((n_head, d_k))).requires_grad_(True)
nn.init.normal_(self.Q, mean=0, std=np.sqrt(2.0 / (d_k)))
self.fc1_k = nn.Linear(d_in, n_head * d_k)
"""
# consider using deeper mappings with nonlinearities,
# but this is somewhat against the original Transformer spirit
self.fc1_k = nn.Linear(d_in, d_in)
        self.bn1_k = nn.BatchNorm1d(d_in)
self.fc2_k = nn.Linear(d_in, n_head * d_k)
self.bn2_k = nn.BatchNorm1d(n_head * d_k)
"""
nn.init.normal_(self.fc1_k.weight, mean=0, std=np.sqrt(2.0 / (d_k)))
#nn.init.normal_(self.fc2_k.weight, mean=0, std=np.sqrt(2.0 / (d_k)))
self.attention = ScaledDotProductAttentionSmall(temperature=np.power(d_k, 0.5))
def forward(self, v, pad_mask=None, return_comp=False, weight_v=False):
d_k, d_in, n_head = self.d_k, self.d_in, self.n_head
# values v are of shapes [B*H*W, T, self.d_in=self.d_model], e.g. [2*32*32=2048 x 4 x 256] (see: sz_b * h * w, seq_len, d)
# where self.d_in=self.d_model is the output dimension of the FC-projected features
sz_b, seq_len, _ = v.size()
q = torch.stack([self.Q for _ in range(sz_b)], dim=1).view(-1, d_k) # (n*b) x d_k
k = self.fc1_k(v).view(sz_b, seq_len, n_head, d_k)
k = k.permute(2, 0, 1, 3).contiguous().view(-1, seq_len, d_k) # (n*b) x lk x dk
if pad_mask is not None:
pad_mask = pad_mask.repeat(
(n_head, 1)
) # replicate pad_mask for each head (nxb) x lk
# attn is of shape [B*H*W*h, 1, T], e.g. [2*32*32*16=32768 x 1 x 4], e.g. Size([32768, 1, 4])
# v is of shape [B*H*W*h, T, self.d_in/h], e.g. [2*32*32*16=32768 x 4 x 256/16=16], e.g. Size([32768, 4, 16])
        # output is of shape [B*H*W*h, 1, d_in/h], e.g. [2*32*32*16=32768 x 1 x 256/16=16], e.g. Size([32768, 1, 16])
v = torch.stack(v.split(v.shape[-1] // n_head, dim=-1)).view(n_head * sz_b, seq_len, -1)
        if weight_v:
            if return_comp:
                output, attn, comp = self.attention(q, k, v, pad_mask=pad_mask, return_comp=return_comp, weight_v=weight_v)
            else:
                output, attn = self.attention(q, k, v, pad_mask=pad_mask, return_comp=return_comp, weight_v=weight_v)
        else:
            attn = self.attention(q, k, v, pad_mask=pad_mask, return_comp=return_comp, weight_v=weight_v)
attn = attn.view(n_head, sz_b, 1, seq_len)
attn = attn.squeeze(dim=2)
if weight_v:
output = output.view(n_head, sz_b, 1, d_in // n_head)
output = output.squeeze(dim=2)
# re-shaped attn to [h x B*H*W x T], e.g. torch.Size([16, 2048, 4])
# in utae.py this is torch.Size([h, B, T, 32, 32])
# re-shaped output to [h x B*H*W x d_in/h], e.g. torch.Size([16, 2048, 16])
# in utae.py this is torch.Size([B, 128, 32, 32])
if return_comp:
return output, attn, comp
else:
return output, attn
return attn
class ScaledDotProductAttention(nn.Module):
"""Scaled Dot-Product Attention
Modified from github.com/jadore801120/attention-is-all-you-need-pytorch
"""
def __init__(self, temperature, attn_dropout=0.1):
super().__init__()
self.temperature = temperature
self.dropout = nn.Dropout(attn_dropout)
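        # softmax over dim=2, i.e. over the temporal (key) dimension of the (n_head*B*H*W) x 1 x T scores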
self.softmax = nn.Softmax(dim=2)
def forward(self, q, k, v, pad_mask=None, return_comp=False):
attn = torch.matmul(q.unsqueeze(1), k.transpose(1, 2))
attn = attn / self.temperature
if pad_mask is not None:
attn = attn.masked_fill(pad_mask.unsqueeze(1), -1e3)
if return_comp:
comp = attn
# attn is of shape [B*H*W*h, 1, T], e.g. [2*32*32*16=32768 x 1 x 4]
# v is of shape [B*H*W*h, T, self.d_in/h], e.g. [2*32*32*16=32768 x 4 x 256/16=16]
        # output is of shape [B*H*W*h, 1, d_in/h], e.g. [2*32*32*16=32768 x 1 x 256/16=16], e.g. Size([32768, 1, 16])
attn = self.softmax(attn)
attn = self.dropout(attn)
output = torch.matmul(attn, v)
if return_comp:
return output, attn, comp
else:
return output, attn
# no longer using dropout (before upsampling)
# but optionally doing attn*v weighting
class ScaledDotProductAttentionSmall(nn.Module):
"""Scaled Dot-Product Attention
Modified from github.com/jadore801120/attention-is-all-you-need-pytorch
"""
def __init__(self, temperature):
super().__init__()
self.temperature = temperature
#self.dropout = nn.Dropout(attn_dropout) # moved dropout after bilinear interpolation
self.softmax = nn.Softmax(dim=2)
def forward(self, q, k, v, pad_mask=None, return_comp=False, weight_v=False):
attn = torch.matmul(q.unsqueeze(1), k.transpose(1, 2))
attn = attn / self.temperature
if pad_mask is not None:
attn = attn.masked_fill(pad_mask.unsqueeze(1), -1e3)
if return_comp:
comp = attn
# attn is of shape [B*H*W*h, 1, T], e.g. [2*32*32*16=32768 x 1 x 4]
# v is of shape [B*H*W*h, T, self.d_in/h], e.g. [2*32*32*16=32768 x 4 x 256/16=16]
        # output is of shape [B*H*W*h, 1, d_in/h], e.g. [2*32*32*16=32768 x 1 x 256/16=16], e.g. Size([32768, 1, 16])
attn = self.softmax(attn)
"""
# no longer using dropout on attention matrices before the upsampling
# this is now done after bilinear interpolation only
attn = self.dropout(attn)
"""
if weight_v:
# optionally using the weighted values
output = torch.matmul(attn, v)
if return_comp:
return output, attn, comp
else:
return output, attn
return attn
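

# Minimal smoke test (a sketch, not part of the original interface): builds both encoders with their
# default settings on random data and prints the resulting shapes. It relies only on the classes and
# the PositionalEncoder import defined above.
if __name__ == "__main__":
    B, T_len, C, H, W = 2, 4, 128, 32, 32
    x = torch.rand(B, T_len, C, H, W)                    # B x T x C x H x W image time series
    dates = torch.arange(T_len).float().repeat(B, 1)     # B x T acquisition positions

    ltae = LTAE2d(in_channels=C, d_model=256, mlp=[256, 128], return_att=True)
    out, attn = ltae(x, batch_positions=dates)
    print(out.shape, attn.shape)    # torch.Size([2, 128, 32, 32]) torch.Size([16, 2, 4, 32, 32])

    ltae_tiny = LTAE2dtiny(in_channels=C, d_model=256)
    attn_tiny = ltae_tiny(x, batch_positions=dates)
    print(attn_tiny.shape)          # torch.Size([16, 2, 4, 32, 32])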