'''
Copyright 2021 The Microsoft DeepSpeed Team
'''
# The file has been adapted from two fairscale files:
# (1) https://github.com/facebookresearch/fairscale/blob/master/fairscale/nn/moe/moe_layer.py
# (2) https://github.com/facebookresearch/fairscale/blob/master/fairscale/nn/moe/top2gate.py
# Git commit hash: 34df606902a240567a0d898037ece55c2f1336cf
# We retain the following license from the original files:
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from typing import Callable, Dict, TYPE_CHECKING, Any, Optional, Tuple, Union, cast
import time
from time import perf_counter
import torch
from torch import nn
from torch import Tensor
import torch.distributed as dist
from torch.nn import Module, ModuleList
import torch.nn.functional as F
from uniperceiver.utils.events import get_event_storage
from torch.cuda.amp import autocast
if TYPE_CHECKING:
Base = Module[Tensor]
else:
Base = Module
import torch.distributed.nn
from uniperceiver.utils import comm
from uniperceiver.modeling.layers import FP16LayerNorm

# per-device caches for the noise distributions used by the gates below
uniform_map: Dict[torch.device, Callable] = {}
gumbel_map: Dict[torch.device, Callable] = {}
# keyed by (device, num_expert): the Normal's std depends on num_expert
normal_map: Dict[Tuple, Callable] = {}
exp_selection_uniform_map: Dict[torch.device, Callable] = {}
def multiplicative_jitter(x, device: torch.device, epsilon=1e-2):
"""
    Adapted from the Switch Transformer paper (Mesh TensorFlow implementation).
Multiply values by a random number between 1-epsilon and 1+epsilon.
Makes models more resilient to rounding errors introduced by bfloat16.
This seems particularly important for logits.
Args:
x: a torch.tensor
device: torch.device
epsilon: a floating point value
Returns:
a jittered x.
"""
if epsilon == 0:
return x
uniform = uniform_map.get(device)
if uniform is None:
uniform = torch.distributions.uniform.Uniform(
low=torch.tensor(1.0 - epsilon, device=device),
high=torch.tensor(1.0 + epsilon,
device=device)).rsample # type: ignore
uniform_map[device] = uniform
return x * uniform(x.shape)
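# Usage sketch (illustrative, not from the original file): the jitter is
# intended for the gate input, in fp32, during training only, e.g.
#   if self.training:
#       x = multiplicative_jitter(x, device=x.device, epsilon=1e-2)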
def gumbel_rsample(shape: Tuple, device: torch.device) -> Tensor:
gumbel = gumbel_map.get(device)
if gumbel is None:
one = torch.tensor(1.0, device=device)
zero = torch.tensor(0.0, device=device)
gumbel = torch.distributions.gumbel.Gumbel(zero,
one).rsample # type: ignore
gumbel_map[device] = gumbel
return gumbel(shape)
def normal_rsample(shape: Tuple, device: torch.device, num_expert: int) -> Tensor:
    # cache per (device, num_expert): the std depends on num_expert, so a
    # device-only key would silently reuse a stale distribution
    normal = normal_map.get((device, num_expert))
    if normal is None:
        std = torch.tensor(1.0 / num_expert, device=device)
        mean = torch.tensor(0.0, device=device)
        normal = torch.distributions.normal.Normal(mean, std).rsample  # type: ignore
        normal_map[(device, num_expert)] = normal
    return normal(shape)
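# Note on the call sites in top1gating/top2gating below: they pass
# num_expert=num_experts/noise_std, so the effective std of the sampled
# noise is noise_std / num_experts (the 'vmoe'-style noisy routing).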
def one_hot_with_dtype(data, num_classes, dtype):
result = torch.zeros([data.size(0), num_classes],
device=data.device,
dtype=dtype)
result.scatter_(1, data.unsqueeze(-1), 1)
return result
@torch.jit.script
def _top_idx(source, k):
return torch.topk(source, k=k, dim=0)[1]
@torch.jit.script
def _one_hot_to_float(x, num_classes):
return F.one_hot(x, num_classes=num_classes).float()
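# Quick REPL check for the stateless helpers above (illustrative, not part
# of the original file):
#   x = torch.randn(4, 8)
#   multiplicative_jitter(x, device=x.device).shape        # torch.Size([4, 8])
#   gumbel_rsample((4, 8), device=x.device).shape          # torch.Size([4, 8])
#   normal_rsample((4, 8), device=x.device, num_expert=16).shape
#   one_hot_with_dtype(torch.tensor([1, 3]), 4, torch.float32)
#   _one_hot_to_float(torch.tensor([1, 3]), 4)             # [2, 4] float one-hot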
class TopKGate(nn.Module):
"""Gate module which implements Top2Gating as described in Gshard_.
::
gate = TopKGate(model_dim, num_experts)
l_aux, combine_weights, dispatch_mask = gate(input)
.. Gshard_: https://arxiv.org/pdf/2006.16668.pdf
Args:
model_dim (int):
size of model embedding dimension
        num_experts (int):
number of experts in model
"""
# wg: torch.nn.Linear
def __init__(self,
model_dim: int,
num_experts: int,
k: int = 1,
noisy_gate_policy: Optional[str] = None,
                 cfg: Optional[dict] = None,
                 moe_type: Optional[str] = None,
**kwargs):
        super().__init__()
if k != 1 and k != 2:
raise ValueError('Only top-1 and top-2 gatings are supported.')
self.model_dim = model_dim
self.k = k
self.cfg = cfg
self.noisy_gate_policy = noisy_gate_policy
self.noise_std = self.cfg.MOE.NOISE_STD
self.batch_prioritized_routing = self.cfg.MOE.BATCH_PRIO
self.gate = self.cfg.MOE.GATE_TYPE
self.layer_type = kwargs.pop('moe_type', 'ffn')
self.tag_transform_enable = self.cfg.MOE.TAG_Transform
        self.moe_type = moe_type
        # load_balance() below reads these two attributes; this excerpt never
        # defines them, so the config key and defaults here are assumptions.
        self.balance_loss = getattr(self.cfg.MOE, 'BALANCE_LOSS', False)
        self.task_weights = kwargs.pop('task_weights', {})
if self.cfg.SOLVER.FORCE_LN_FP16:
LayerNormModule = FP16LayerNorm
else:
LayerNormModule = torch.nn.LayerNorm
if self.tag_transform_enable and self.cfg.MOE.TAG_Transform_ACT:
self.tag_transform = torch.nn.Sequential(torch.nn.Linear(self.cfg.MOE.ATTRIBUTE_LENGTH, self.model_dim), torch.nn.GELU(),
LayerNormModule(self.model_dim))
else:
self.tag_transform = torch.nn.Sequential(torch.nn.Linear(self.cfg.MOE.ATTRIBUTE_LENGTH, self.model_dim), LayerNormModule(self.model_dim))
self.wg = torch.nn.Linear(model_dim, num_experts, bias=False).float()
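    # Construction sketch (illustrative; the cfg fields listed are the ones
    # accessed above, everything else about the config is an assumption):
    #   gate = TopKGate(model_dim=512, num_experts=8, k=2,
    #                   noisy_gate_policy='RSample', cfg=cfg)
    # where cfg exposes MOE.NOISE_STD, MOE.BATCH_PRIO, MOE.GATE_TYPE,
    # MOE.TAG_Transform, MOE.TAG_Transform_ACT, MOE.ATTRIBUTE_LENGTH
    # and SOLVER.FORCE_LN_FP16.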
def tag_gate(self, x, data_type=None, moe_embedding:torch.Tensor = None, **kwargs):
if self.cfg.MODEL.TAG_TRANSFORM_FP32:
with autocast(enabled=False):
gate_embed = self.tag_transform.float()(moe_embedding.float())
else:
gate_embed = self.tag_transform(moe_embedding)
return gate_embed
def forward(
self,
input,
**kwargs,
    ) -> Tuple[list, list]:  # type: ignore
if self.tag_transform_enable:
input = self.tag_gate(input, **kwargs)
if self.wg.weight.dtype != torch.float32:
self.wg = self.wg.float()
input_fp32 = input.float()
# input jittering
if self.noisy_gate_policy == 'Jitter' and self.training:
input_fp32 = multiplicative_jitter(input_fp32, device=input.device)
with autocast(enabled=not self.cfg.MODEL.GATE_FP32):
if self.cfg.SOLVER.FORCE_WG_RECAST:
            # used for debugging only
logits = self.wg.half().float()(input_fp32)
else:
logits = self.wg(input_fp32)
        if self.k == 1 and self.gate == 'deepspeed':
            gate_output = self.top1gating(
                logits,
                self.noisy_gate_policy if self.training else None,
                **kwargs)
        else:
            # tutel-style gate function
            gate_output = self.top2gating(
                logits,
                self.noisy_gate_policy if self.training else None,
                **kwargs)
return gate_output
def load_balance(self, gates, mask1, num_experts, data_type=None):
# Compute l_aux
if self.balance_loss and self.training:
            # TODO: for the retrieval task, some GPUs may not have this input
if data_type == 'INPUT':
if comm._LOCAL_IMAGE_LENGTH > 0 and not comm._LOCAL_UTOKEN_LENGTH + comm._LOCAL_GTOKEN_LENGTH > 0:
# input image features only
me = gates.sum(dim=0)
ce = mask1.sum(dim=0)
# maybe has retrieval pair
if comm._MOE_TARGET_MECE_LIST.get(str(comm._LOCAL_CURRENT_LAYER)+'_'+self.layer_type, None) is not None:
# if len(comm._MOE_TARGET_MECE_LIST) > 0:
me_t, ce_t = comm._MOE_TARGET_MECE_LIST[
str(comm._LOCAL_CURRENT_LAYER) + '_' +
self.layer_type]
me = me + me_t
ce = ce + ce_t
me = me * self.task_weights[comm._LOCAL_CURRENT_TASK]
ce = ce * self.task_weights[comm._LOCAL_CURRENT_TASK]
elif comm._LOCAL_IMAGE_LENGTH > 0 and comm._LOCAL_UTOKEN_LENGTH + comm._LOCAL_GTOKEN_LENGTH > 0:
# sum of these two distribution from two modalities
me = gates.sum(dim=0)
ce = mask1.sum(dim=0)
me = me * self.task_weights[comm._LOCAL_CURRENT_TASK]
ce = ce * self.task_weights[comm._LOCAL_CURRENT_TASK]
elif comm._LOCAL_IMAGE_LENGTH <= 0 and comm._LOCAL_UTOKEN_LENGTH + comm._LOCAL_GTOKEN_LENGTH > 0:
me = gates.sum(
dim=0) * self.task_weights[comm._LOCAL_CURRENT_TASK]
ce = mask1.sum(
dim=0) * self.task_weights[comm._LOCAL_CURRENT_TASK]
# raise NotImplementedError
else:
raise NotImplementedError
elif data_type == 'TARGET':
# the retrieval embedding
# only remove the padding contributions
comm._MOE_TARGET_MECE_LIST[str(comm._LOCAL_CURRENT_LAYER) + '_' +self.layer_type] = [gates.sum(dim=0), mask1.sum(dim=0)]
elif data_type == 'IN_LABEL':
# remove paddings contributions
me = gates.sum(dim=0)
ce = mask1.sum(dim=0)
elif data_type == 'WORD_VOCAB':
# do not need padding mask
me = gates.sum(dim=0)
ce = mask1.sum(dim=0)
else:
raise NotImplementedError
            # all-reduce the balance statistics across workers, then record
            # them for the balance-loss computation
            if data_type != 'TARGET':
                me = torch.distributed.nn.all_reduce(me) / comm.get_world_size()
                ce = torch.distributed.nn.all_reduce(ce) / comm.get_world_size()
                if data_type not in comm._MOE_LOSSES_COLLECTIONS['exp_balance']:
                    comm._MOE_LOSSES_COLLECTIONS['exp_balance'][data_type] = []
                comm._MOE_LOSSES_COLLECTIONS['exp_balance'][data_type].append(
                    [me, ce])
def top1gating(
self,
logits: Tensor,
noisy_gate_policy: Optional[str] = None,
**kwargs,
    ) -> Tuple[list, list]:
"""Implements Top1Gating on logits."""
logits_w_noise = None
if noisy_gate_policy == 'RSample':
logits_w_noise = logits + gumbel_rsample(logits.shape,
device=logits.device)
elif noisy_gate_policy == 'vmoe':
num_experts = int(logits.shape[-1])
logits_w_noise = logits + normal_rsample(logits.shape,
device=logits.device,
num_expert=num_experts/self.noise_std)
# everything is in fp32 in this function
gates = F.softmax(logits, dim=1)
# Create a mask for 1st's expert per token
# noisy gating
indices1_s = torch.argmax(logits_w_noise if logits_w_noise is not None else gates, dim=1)
num_experts = int(gates.shape[1])
mask1 = F.one_hot(indices1_s, num_classes=num_experts)
# gating decisions
exp_counts = torch.sum(mask1, dim=0).detach().to('cpu')
        self.load_balance(gates, mask1, num_experts,
                          data_type=kwargs.get('data_type', None))
        self.tb_output(data_type=kwargs.get('data_type', None),
                       mask1=mask1,
                       exp_counts=exp_counts,
                       gates=None)
gates = (gates*mask1).sum(dim=1)
        self.tb_output(data_type=kwargs.get('data_type', None),
                       mask1=None, exp_counts=None, gates=[gates])
return [indices1_s], [gates]
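    # Minimal consumption sketch (illustrative, not part of the original
    # file): route each token to its selected expert and scale the expert
    # output by the gate value.
    #   indices_s, gates_s = self.top1gating(logits)
    #   expert_id, gate_val = indices_s[0], gates_s[0]   # both [num_tokens]
    #   for e in range(num_experts):
    #       sel = expert_id == e
    #       if sel.any():
    #           out[sel] = experts[e](tokens[sel]) * gate_val[sel].unsqueeze(-1)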
def top2gating(
self,
logits: Tensor,
noisy_gate_policy: Optional[str] = None,
**kwargs,
    ) -> Tuple[list, list]:
"""Implements Top2Gating on logits."""
# everything is in fp32 in this function
num_experts = int(logits.shape[-1])
logits_w_noise = None
if noisy_gate_policy == 'RSample':
logits_w_noise = logits + gumbel_rsample(logits.shape,
device=logits.device) * self.noise_std
elif noisy_gate_policy == 'vmoe':
logits_w_noise = logits + normal_rsample(logits.shape,
device=logits.device,
num_expert=num_experts/self.noise_std)
# topk_indices = torch.topk(logits, self.k, dim=1).indices
topk_indices = torch.topk(
logits_w_noise
if logits_w_noise is not None else logits,
self.k,
dim=1).indices
indices_s = [x.view(-1) for x in topk_indices.chunk(self.k, dim=1)]
masks_se = [
one_hot_with_dtype(x, num_classes=num_experts, dtype=x.dtype)
for x in indices_s
]
if noisy_gate_policy == 'vmoe':
gates = F.softmax(logits_w_noise, dim=1)
else:
gates = F.softmax(logits, dim=1)
# self.load_balance(gates, masks_se[0], num_experts)
gates_s = [(gates * x).sum(dim=1) for x in masks_se]
# gating decisions
exp_counts = torch.sum(masks_se[0], dim=0).detach().to('cpu')
# self.tb_output(masks_se[0], exp_counts, gates=None)
# if self.k>1:
# for k in range(1, self.k):
# self.tb_output(masks_se[k], torch.sum(masks_se[k], dim=0).detach().to('cpu'), None, postfix='_top{}'.format(k+1))
if self.k > 1:
# Normalize Gate
denom_s = torch.clamp(sum(gates_s),
min=torch.finfo(gates_s[0].dtype).eps)
gates_s = [x / denom_s for x in gates_s]
# self.tb_output(mask1=None, exp_counts=None, gates=gates_s)
return indices_s, gates_s
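    # Worked example (illustrative): with k=2, a token whose top-2 softmax
    # gate values are g1=0.5 and g2=0.3 gets combine weights
    #   w1 = 0.5 / max(0.5 + 0.3, eps) = 0.625
    #   w2 = 0.3 / max(0.5 + 0.3, eps) = 0.375
    # so the two selected experts' outputs are mixed with weights summing to 1.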
def tb_output(self, data_type=None, mask1=None, exp_counts=None, gates=None, postfix=''):
        if not self.training:
            return
        storage = get_event_storage()
        # to save time, only log for a subset of tasks
        if not (comm._LOCAL_CURRENT_TASK == 'imagenet'
                or comm._LOCAL_CURRENT_TASK.startswith(
                    ('bookswiki', 'cc3m', 'cc12m', 'tqa'))):
            return
if (storage._iter+1)%(comm._EXPERT_LOG_INTERVAL//10) != 0:
# to save time
return
        if storage is not None and comm.is_main_process():
            # per-expert logging
if gates is not None:
if data_type == "INPUT" and comm._LOCAL_IMAGE_LENGTH > 0:
gate_logs = {
"logits_layer{}_expert_{}/top{}_{}_{}_v".format(
comm._LOCAL_CURRENT_LAYER, self.layer_type,
e_id+1, comm._LOCAL_CURRENT_TASK,
data_type): ratio[0]
for e_id, ratio in enumerate(gates)
}
storage.put_scalars(**gate_logs, avg_hint=True)
if gates[0].shape[0] > 1:
gates_t_logs = {
"logits_layer{}_expert_{}/top{}_{}_{}_t".
format(comm._LOCAL_CURRENT_LAYER,
self.layer_type, e_id+1,
comm._LOCAL_CURRENT_TASK,
data_type): ratio[1]
for e_id, ratio in enumerate(gates)
}
storage.put_scalars(**gates_t_logs, avg_hint=True)
elif data_type in ['IN_LABEL', 'WORD_VOCAB']:
gates_logs = {
"logits_layer{}_expert_{}/top{}_{}".format(
comm._LOCAL_CURRENT_LAYER, self.layer_type,
e_id+1, data_type): ratio[0]
for e_id, ratio in enumerate(gates)
}
storage.put_scalars(**gates_logs, avg_hint=True)
else:
gates_logs = {
"layer{}_expert_{}/top{}_{}_{}".format(
comm._LOCAL_CURRENT_LAYER, self.layer_type,
e_id+1, comm._LOCAL_CURRENT_TASK,
data_type): ratio[0]
for e_id, ratio in enumerate(gates)
}
storage.put_scalars(**gates_logs, avg_hint=True)
else:
if data_type == "INPUT" and comm._LOCAL_IMAGE_LENGTH > 0:
exp_counts_v = mask1[0]
exp_count_logs = {
"layer{}_expert_{}/E{}_{}_{}_v{}".format(
comm._LOCAL_CURRENT_LAYER, self.layer_type, e_id,
comm._LOCAL_CURRENT_TASK, data_type,
postfix): ratio
for e_id, ratio in enumerate((exp_counts_v /
exp_counts_v.sum()).tolist())
}
storage.put_scalars(**exp_count_logs, avg_hint=True)
if mask1.size(0)>1:
exp_counts_t = mask1[1]
exp_count_logs = {
"layer{}_expert_{}/E{}_{}_{}_t{}".format(
comm._LOCAL_CURRENT_LAYER, self.layer_type, e_id,
comm._LOCAL_CURRENT_TASK,
data_type, postfix): ratio
for e_id, ratio in enumerate((
exp_counts_t / exp_counts_t.sum()).tolist())
}
storage.put_scalars(**exp_count_logs, avg_hint=True)
elif data_type in ['IN_LABEL', 'WORD_VOCAB']:
exp_count_logs = {
"layer{}_expert_{}/E{}_{}{}".format(
comm._LOCAL_CURRENT_LAYER, self.layer_type, e_id,
data_type, postfix): ratio
for e_id, ratio in enumerate((exp_counts /
exp_counts.sum()).tolist())
}
storage.put_scalars(**exp_count_logs, avg_hint=True)
else:
exp_count_logs = {
"layer{}_expert_{}/E{}_{}_{}{}".format(
comm._LOCAL_CURRENT_LAYER, self.layer_type, e_id,
comm._LOCAL_CURRENT_TASK, data_type,
postfix): ratio
for e_id, ratio in enumerate((exp_counts /
exp_counts.sum()).tolist())
}
storage.put_scalars(**exp_count_logs, avg_hint=True)