"""Vision Transformer(VIT) model.""" |
|
|
|
import math |
|
import einops |
|
import torch |
|
import apex |
|
import torch.nn.functional as F |
|
from megatron import get_args |
|
from megatron.model.transformer import ParallelTransformer |
|
from megatron.model.utils import ( |
|
get_linear_layer, |
|
init_method_normal, |
|
scaled_init_method_normal, |
|
) |
|
from megatron.model.module import MegatronModule |
|
|
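# Number of learnable class tokens prepended to the patch sequence. A standard
# ViT uses a single [CLS] token; this backbone reserves 8 positions instead.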
CLASS_TOKEN_LENGTH = 8


class VitMlpHead(MegatronModule):
    """MLP head used for classification.

    Apply a linear transformation followed by a tanh to the hidden states,
    then project to the number of classes.

    Arguments:
        hidden_size: hidden size
        num_classes: number of output classes
    """

    def __init__(self, hidden_size, num_classes):
        super(VitMlpHead, self).__init__()
        self.dense_in = torch.nn.Linear(hidden_size, hidden_size)
        self.relu = torch.nn.ReLU()
        self.dense_out = torch.nn.Linear(hidden_size, num_classes)
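        # The output bias starts at -10, biasing initial class logits strongly negative.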
        torch.nn.init.constant_(self.dense_out.bias, -10)

    def forward(self, hidden_states):
        dense_in_result = self.dense_in(hidden_states)
        tanh_result = torch.tanh(dense_in_result)
        dense_out_result = self.dense_out(tanh_result)
        return dense_out_result


def isPerfectSquare(x):
    if x >= 0:
        sr = math.sqrt(x)
        return int(sr) * int(sr) == x
    return False


def twod_interpolate_position_embeddings_hook(
    state_dict,
    prefix,
    local_metadata,
    strict,
    missing_keys,
    unexpected_keys,
    error_msgs,
):
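    """Interpolate loaded ViT position embeddings onto the current patch grid.

    Registered as a load_state_dict pre-hook on the position embedding table.
    If the checkpoint was produced with a different image or patch size, the
    patch-grid part of the embeddings is reshaped to its original 2D grid,
    bilinearly resized to the new grid, and flattened back; class-token
    embeddings are carried over (or zero-initialized if the checkpoint has
    none) and re-attached when class tokens are in use.
    """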
    args = get_args()
    num_patches_per_dim_h = args.img_h // args.patch_dim
    num_patches_per_dim_w = args.img_w // args.patch_dim
    num_patches = num_patches_per_dim_h * num_patches_per_dim_w
    hidden_size = args.hidden_size

    key = prefix + "weight"

    assert key in state_dict
    if key in state_dict:
        input_param = state_dict[key]

        input_seq_len = input_param.shape[0]
        assert isPerfectSquare(input_seq_len) or isPerfectSquare(input_seq_len - CLASS_TOKEN_LENGTH)
        input_has_class_token = not isPerfectSquare(input_seq_len)
        num_tok_input = input_seq_len - CLASS_TOKEN_LENGTH if input_has_class_token else input_seq_len
        num_tok_output = num_patches
        output_has_class_token = args.class_token_present

        if input_has_class_token:
            input_param_tok = input_param[:CLASS_TOKEN_LENGTH, :]
            input_param_grid = input_param[CLASS_TOKEN_LENGTH:, :]
        else:
            input_param_tok = torch.zeros(CLASS_TOKEN_LENGTH, hidden_size)
            input_param_grid = input_param

        assert input_param.shape[1] == hidden_size

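        # If the number of patch tokens changed (different image resolution or
        # patch size), treat the embeddings as a gs_input x gs_input feature map,
        # bilinearly interpolate it to the new grid in fp32, then cast back to fp16.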
        if num_tok_input != num_tok_output:
            gs_input = int(math.sqrt(num_tok_input))
            gs_new = (num_patches_per_dim_h, num_patches_per_dim_w)

            input_param_grid = input_param_grid.transpose(0, 1).contiguous()
            input_param_grid = input_param_grid.reshape(
                (1, -1, gs_input, gs_input)
            )
            input_param_grid = input_param_grid.float()
            scale_factor = (gs_new[0] / gs_input, gs_new[1] / gs_input)

            input_param_grid = F.interpolate(
                input_param_grid, scale_factor=scale_factor, mode="bilinear"
            )

            input_param_grid = input_param_grid.half()
            input_param_grid = input_param_grid.reshape((-1, num_tok_output))
            input_param_grid = input_param_grid.transpose(0, 1).contiguous()

            assert input_param_grid.shape[1] == hidden_size

        input_param = input_param_grid
        assert (
            input_param.shape[0] == num_tok_output
            and input_param.shape[1] == hidden_size
        )

        if output_has_class_token:
            input_param = torch.cat((input_param_tok, input_param), dim=0)

        state_dict[key] = input_param


class VitBackbone(MegatronModule):
    """Vision Transformer Model."""

    def __init__(self,
                 pre_process=True,
                 post_process=True,
                 class_token=True,
                 single_token_output=False,
                 post_layer_norm=True,
                 drop_path_rate=0.0):
        super(VitBackbone, self).__init__(share_word_embeddings=False)
        args = get_args()

        self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
        if args.init_method_xavier_uniform:
            self.init_method = torch.nn.init.xavier_uniform_
            self.scaled_init_method = torch.nn.init.xavier_uniform_
        else:
            self.init_method = init_method_normal(args.init_method_std)
            self.scaled_init_method = scaled_init_method_normal(
                args.init_method_std, args.num_layers
            )

        self.pre_process = pre_process
        self.post_process = post_process
        self.class_token = class_token
        self.post_layer_norm = post_layer_norm
        self.hidden_size = args.hidden_size
        self.patch_dim = args.patch_dim
        self.img_h = args.img_h
        self.img_w = args.img_w
        self.micro_batch_size = args.micro_batch_size
        self.single_token_output = single_token_output
        self.drop_path_rate = drop_path_rate

        assert self.img_h % self.patch_dim == 0
        assert self.img_w % self.patch_dim == 0
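        # Tokenization geometry: the image is split into an
        # (img_h / patch_dim) x (img_w / patch_dim) grid of patches; each flattened
        # patch (patch_dim * patch_dim * num_channels values) is later projected to
        # hidden_size by linear_encoder.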
        self.num_patches_per_dim_h = self.img_h // self.patch_dim
        self.num_patches_per_dim_w = self.img_w // self.patch_dim
        self.num_patches = self.num_patches_per_dim_h * self.num_patches_per_dim_w
        self.seq_length = self.num_patches + (CLASS_TOKEN_LENGTH if self.class_token else 0)
        self.flatten_dim = self.patch_dim * self.patch_dim * args.num_channels
        self.input_tensor = None
        self.position_ids = None

        if self.pre_process:
            if self.class_token:
                self.cls_token = torch.nn.Parameter(
                    torch.randn(1, CLASS_TOKEN_LENGTH, self.hidden_size)
                )
                torch.nn.init.zeros_(self.cls_token)
            self.position_ids = torch.arange(self.seq_length).expand(1, -1).cuda()

            self.linear_encoder = torch.nn.Linear(
                self.flatten_dim, self.hidden_size
            )

            self.position_embeddings = torch.nn.Embedding(
                self.seq_length, self.hidden_size
            )
            init_method_normal(args.init_method_std)(
                self.position_embeddings.weight
            )

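            # Expose the class-token setting to the checkpoint hook and register it,
            # so position embeddings saved at another resolution are 2D-interpolated
            # on load.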
            args.class_token_present = self.class_token
            self.position_embeddings._register_load_state_dict_pre_hook(
                twod_interpolate_position_embeddings_hook
            )

            self.embedding_dropout = torch.nn.Dropout(args.hidden_dropout)

        self.transformer = ParallelTransformer(
            self.init_method,
            self.scaled_init_method,
            pre_process=self.pre_process,
            post_process=self.post_process,
            post_layer_norm=self.post_layer_norm,
            drop_path_rate=self.drop_path_rate
        )

    def set_input_tensor(self, input_tensor):
        """See megatron.model.transformer.set_input_tensor()"""
        self.transformer.set_input_tensor(input_tensor)

    def forward(self, input):

        if self.pre_process:
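            # [b, c, H, W] -> [b, num_patches, patch_dim * patch_dim * c]:
            # carve the image into non-overlapping patches and flatten each one.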
            rearranged_input = einops.rearrange(
                input,
                "b c (h p1) (w p2) -> b (h w) (p1 p2 c)",
                p1=self.patch_dim,
                p2=self.patch_dim,
            )

            assert rearranged_input.dtype == torch.half
            encoder_output = self.linear_encoder(rearranged_input)

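            # Prepend the learnable class tokens, then add learned absolute position
            # embeddings (sliced to the actual token count) and apply dropout.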
            concatenated_tokens = encoder_output
            if self.class_token:
                cls_tokens = self.cls_token.expand(encoder_output.shape[0], -1, -1)
                concatenated_tokens = torch.cat((cls_tokens, encoder_output), dim=1)

            token_embeddings = concatenated_tokens + \
                self.position_embeddings(self.position_ids[:, :concatenated_tokens.shape[1]])
            hidden_states = self.embedding_dropout(token_embeddings)
        else:
            hidden_states = input

        hidden_states = self.transformer(hidden_states, None)

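        # Optionally return only the hidden state of the first (class) token.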
        if self.single_token_output:
            hidden_states = hidden_states[:, 0, :]

        return hidden_states