'''
Copyright 2020 The Microsoft DeepSpeed Team
'''
import torch.nn.init as init
import torch
from torch import nn
import torch.distributed as dist
from .gate import TopKGate
import copy
import typing
from .experts import FusedExperts as Experts


class TaskMoE(torch.nn.Module):
    def __init__(self,
                 hidden_size,
                 expert,
                 num_experts=1,
                 k=1,
                 capacity_factor=1.,
                 eval_capacity_factor=1.,
                 min_capacity=4,
                 noisy_gate_policy: typing.Optional[str] = None,
                 drop_tokens: bool = True,
                 use_rts=True,
                 use_tutel: bool = False,
                 cfg=None):
        """Initialize an MoE layer.

        Arguments:
            hidden_size (int): the hidden dimension of the model; importantly, this is also the input and output dimension.
            expert (torch.nn.Module): the torch module that defines the expert (e.g., an MLP or torch.nn.Linear).
            num_experts (int, optional): default=1, the total number of experts per layer.
            k (int, optional): default=1, top-k gating value; only k=1 and k=2 are supported.
            capacity_factor (float, optional): default=1.0, the capacity of each expert at training time.
            eval_capacity_factor (float, optional): default=1.0, the capacity of each expert at eval time.
            min_capacity (int, optional): default=4, the minimum capacity per expert regardless of the capacity_factor.
            noisy_gate_policy (str, optional): default=None, noisy gate policy; valid options are 'Jitter', 'RSample', or 'None'.
            drop_tokens (bool, optional): default=True, whether to drop tokens (setting this to False is equivalent to infinite capacity).
            use_rts (bool, optional): default=True, whether to use Random Token Selection.
            use_tutel (bool, optional): default=False, whether to use Tutel optimizations (if installed).
            cfg (optional): default=None, configuration object forwarded to the experts and the gate.
        """
        super().__init__()
        self.num_experts = num_experts

        # Record the expert type; the gate needs to know whether the experts
        # are linear or attention modules.
        if isinstance(expert, nn.Linear):
            self.expert_type = 'linear'
        elif isinstance(expert, nn.MultiheadAttention):
            self.expert_type = 'attention'
        else:
            raise NotImplementedError(
                f'unsupported expert type {type(expert)}; expected nn.Linear or nn.MultiheadAttention')

        experts = Experts(expert, cfg, num_experts)

        self.gate = TopKGate(hidden_size,
                             num_experts,
                             k,
                             noisy_gate_policy,
                             cfg,
                             moe_type=self.expert_type)
        self.experts = experts

    def forward(self, hidden_states, gate_decision=None, **kwargs):
        """MoE forward.

        Arguments:
            hidden_states (Tensor): input to the layer.
            gate_decision (optional): a previously computed [top_indices, gates] pair;
                when provided, the gate is skipped and this routing decision is reused.

        Returns:
            A tuple of:
            * output (Tensor): output of the expert computation.
            * gate_decision (list): the [top_indices, gates] pair that was used,
              which can be passed to another layer to share the same routing.
        """
        if gate_decision is not None:
            # Reuse a routing decision computed by an earlier layer.
            top_indices, gates = gate_decision
        else:
            top_indices, gates = self.gate(hidden_states, **kwargs)

        expert_output = self.experts(hidden_states, top_indices, gates, **kwargs)

        return expert_output, [top_indices, gates]
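

# Example usage (a minimal, hypothetical sketch): a TaskMoE layer wrapping
# nn.Linear experts. The shapes, hidden_size, and cfg=None below are
# illustrative assumptions; TopKGate and FusedExperts may require a populated
# cfg in practice.
#
#   expert = nn.Linear(16, 16)
#   moe = TaskMoE(hidden_size=16, expert=expert, num_experts=4, k=1, cfg=None)
#   x = torch.randn(8, 16)                         # (tokens, hidden_size)
#   out, gate_decision = moe(x)                    # route with this layer's gate
#   out2, _ = moe(x, gate_decision=gate_decision)  # reuse the same routing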