# megatron/model/vision/vit_backbone.py
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Vision Transformer(VIT) model."""
import math
import einops
import torch
import apex
import torch.nn.functional as F
from megatron import get_args
from megatron.model.transformer import ParallelTransformer
from megatron.model.utils import (
get_linear_layer,
init_method_normal,
scaled_init_method_normal,
)
from megatron.model.module import MegatronModule

CLASS_TOKEN_LENGTH = 8


class VitMlpHead(MegatronModule):
"""Pooler layer.
Pool hidden states of a specific token (for example start of the
sequence) and add a linear transformation followed by a tanh.
Arguments:
hidden_size: hidden size
init_method: weight initialization method for the linear layer.
bias is set to zero.
"""
def __init__(self, hidden_size, num_classes):
super(VitMlpHead, self).__init__()
self.dense_in = torch.nn.Linear(hidden_size, hidden_size)
        self.relu = torch.nn.ReLU()  # unused; forward() applies tanh instead
self.dense_out = torch.nn.Linear(hidden_size, num_classes)
        # Bias the output logits toward large negative values at initialization
        # (sigmoid(-10) is roughly 4.5e-5).
        torch.nn.init.constant_(self.dense_out.bias, -10)

    def forward(self, hidden_states):
        # hidden_states: pooled features, e.g. [b, h] or [b, s, h]
dense_in_result = self.dense_in(hidden_states)
tanh_result = torch.tanh(dense_in_result)
dense_out_result = self.dense_out(tanh_result)
return dense_out_result
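

# Illustrative usage of VitMlpHead (hypothetical sizes, not executed at import):
#
#   head = VitMlpHead(hidden_size=1024, num_classes=1000)
#   pooled = torch.randn(2, 1024)   # [b, h] pooled hidden states
#   logits = head(pooled)           # -> [2, 1000]
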

def isPerfectSquare(x):
    if x >= 0:
        sr = math.sqrt(x)
        return int(sr) * int(sr) == x
    return False
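
# For example, isPerfectSquare(196) is True (14 * 14) and isPerfectSquare(200)
# is False; the hook below uses this to detect whether a checkpointed
# position-embedding table still carries the CLASS_TOKEN_LENGTH class-token rows.

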
def twod_interpolate_position_embeddings_hook(
state_dict,
prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
):
    """Load-state-dict pre-hook that resizes ViT position embeddings.

    If a checkpoint was trained with a different image or patch size, the grid
    portion of the position-embedding table is bilinearly interpolated to the
    current number of patches before being loaded.
    """
    args = get_args()
num_patches_per_dim_h = args.img_h // args.patch_dim
num_patches_per_dim_w = args.img_w // args.patch_dim
num_patches = num_patches_per_dim_h * num_patches_per_dim_w
hidden_size = args.hidden_size
key = prefix + "weight"
assert key in state_dict
if key in state_dict:
input_param = state_dict[key]
input_seq_len = input_param.shape[0]
        assert (
            isPerfectSquare(input_seq_len)
            or isPerfectSquare(input_seq_len - CLASS_TOKEN_LENGTH)
        )
        input_has_class_token = not isPerfectSquare(input_seq_len)
        num_tok_input = (
            input_seq_len - CLASS_TOKEN_LENGTH
            if input_has_class_token
            else input_seq_len
        )
        num_tok_output = num_patches
        output_has_class_token = args.class_token_present
# update input_param and load it to state_dict[key]
if input_has_class_token:
input_param_tok = input_param[:CLASS_TOKEN_LENGTH, :]
input_param_grid = input_param[CLASS_TOKEN_LENGTH:, :]
else:
            # Match the dtype of the loaded parameter so the later torch.cat works.
            input_param_tok = torch.zeros(
                CLASS_TOKEN_LENGTH, hidden_size, dtype=input_param.dtype
            )
input_param_grid = input_param
assert input_param.shape[1] == hidden_size
        if num_tok_input != num_tok_output:
            # Interpolate the grid embeddings: reshape to [1, hidden, gs, gs],
            # bilinearly resize to the new patch grid, then flatten back.
            gs_input = int(math.sqrt(num_tok_input))
            gs_new = (num_patches_per_dim_h, num_patches_per_dim_w)
input_param_grid = input_param_grid.transpose(0, 1).contiguous()
input_param_grid = input_param_grid.reshape(
(1, -1, gs_input, gs_input)
)
input_param_grid = input_param_grid.float()
scale_factor = (gs_new[0] / gs_input, gs_new[1] / gs_input)
input_param_grid = F.interpolate(
input_param_grid, scale_factor=scale_factor, mode="bilinear"
)
input_param_grid = input_param_grid.half()
input_param_grid = input_param_grid.reshape((-1, num_tok_output))
input_param_grid = input_param_grid.transpose(0, 1).contiguous()
assert input_param_grid.shape[1] == hidden_size
input_param = input_param_grid
assert (
input_param.shape[0] == num_tok_output
and input_param.shape[1] == hidden_size
)
if output_has_class_token:
input_param = torch.cat((input_param_tok, input_param), dim=0)
state_dict[key] = input_param
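

# Worked example for the hook above (illustrative numbers): with patch_dim = 16,
# a checkpoint trained at 224x224 stores a 14 x 14 = 196-token grid (plus
# CLASS_TOKEN_LENGTH class-token rows when present). Loading it with
# img_h = img_w = 384 requires 24 x 24 = 576 patches, so the 196 grid rows are
# reshaped to [1, hidden, 14, 14], bilinearly interpolated to [1, hidden, 24, 24],
# and flattened back to [576, hidden] before the state dict is applied.
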
class VitBackbone(MegatronModule):
"""Vision Transformer Model."""
def __init__(self,
pre_process=True,
post_process=True,
class_token=True,
single_token_output=False,
post_layer_norm=True,
drop_path_rate=0.0):
super(VitBackbone, self).__init__(share_word_embeddings=False)
args = get_args()
self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
if args.init_method_xavier_uniform:
self.init_method = torch.nn.init.xavier_uniform_
self.scaled_init_method = torch.nn.init.xavier_uniform_
else:
self.init_method = init_method_normal(args.init_method_std)
self.scaled_init_method = scaled_init_method_normal(
args.init_method_std, args.num_layers
)
self.pre_process = pre_process
self.post_process = post_process
self.class_token = class_token
self.post_layer_norm = post_layer_norm
self.hidden_size = args.hidden_size
self.patch_dim = args.patch_dim
self.img_h = args.img_h
self.img_w = args.img_w
self.micro_batch_size = args.micro_batch_size
self.single_token_output = single_token_output
self.drop_path_rate = drop_path_rate
assert self.img_h % self.patch_dim == 0
assert self.img_w % self.patch_dim == 0
self.num_patches_per_dim_h = self.img_h // self.patch_dim
self.num_patches_per_dim_w = self.img_w // self.patch_dim
self.num_patches = self.num_patches_per_dim_h * self.num_patches_per_dim_w
self.seq_length = self.num_patches + (CLASS_TOKEN_LENGTH if self.class_token else 0)
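        # Example (illustrative numbers): img_h = img_w = 224 and patch_dim = 16
        # give 14 * 14 = 196 patches, so seq_length = 196 + 8 = 204 with the
        # class token enabled.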
self.flatten_dim = self.patch_dim * self.patch_dim * args.num_channels
self.input_tensor = None
self.position_ids = None
if self.pre_process:
# cls_token
if self.class_token:
self.cls_token = torch.nn.Parameter(
torch.randn(1, CLASS_TOKEN_LENGTH, self.hidden_size)
)
torch.nn.init.zeros_(self.cls_token)
self.position_ids = torch.arange(self.seq_length).expand(1, -1).cuda()
# Linear encoder
self.linear_encoder = torch.nn.Linear(
self.flatten_dim, self.hidden_size
)
# embedding
self.position_embeddings = torch.nn.Embedding(
self.seq_length, self.hidden_size
)
init_method_normal(args.init_method_std)(
self.position_embeddings.weight
)
args.class_token_present = self.class_token
self.position_embeddings._register_load_state_dict_pre_hook(
twod_interpolate_position_embeddings_hook
)
self.embedding_dropout = torch.nn.Dropout(args.hidden_dropout)
# Transformer
self.transformer = ParallelTransformer(
self.init_method,
self.scaled_init_method,
pre_process=self.pre_process,
post_process=self.post_process,
post_layer_norm=self.post_layer_norm,
drop_path_rate=self.drop_path_rate
)

    def set_input_tensor(self, input_tensor):
        """See megatron.model.transformer.set_input_tensor()"""
        self.transformer.set_input_tensor(input_tensor)

def forward(self, input):
if self.pre_process:
rearranged_input = einops.rearrange(
input,
"b c (h p1) (w p2) -> b (h w) (p1 p2 c)",
p1=self.patch_dim,
p2=self.patch_dim,
)
            # The patch tokens are expected in fp16 (half precision).
            assert rearranged_input.dtype == torch.half
encoder_output = self.linear_encoder(rearranged_input)
concatenated_tokens = encoder_output
if self.class_token:
cls_tokens = self.cls_token.expand(encoder_output.shape[0], -1, -1)
concatenated_tokens = torch.cat((cls_tokens, encoder_output), dim=1)
token_embeddings = concatenated_tokens + \
self.position_embeddings(self.position_ids[:, :concatenated_tokens.shape[1]])
hidden_states = self.embedding_dropout(token_embeddings)
else:
hidden_states = input
hidden_states = self.transformer(hidden_states, None)
if self.single_token_output:
hidden_states = hidden_states[:,0,:]
return hidden_states
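

# Rough shape flow through VitBackbone.forward (illustrative, assuming a fp16
# input of shape [b, c, 224, 224], patch_dim = 16, class_token=True, and a
# ParallelTransformer that keeps the [b, s, h] layout used here):
#   einops.rearrange   -> [b, 196, 16 * 16 * c]
#   linear_encoder     -> [b, 196, hidden_size]
#   cat cls_token      -> [b, 204, hidden_size]
#   + position_embeddings, dropout, transformer -> [b, 204, hidden_size]
#   single_token_output=True then keeps the first token -> [b, hidden_size]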