# NOTE: non-Python upload artifact commented out ("herrius's picture / Upload 259 files / 32b542e")
'''
Copyright 2020 The Microsoft DeepSpeed Team
'''
import torch
import copy
from .gate import one_hot_with_dtype
from uniperceiver.utils import comm
import torch.nn.functional as F
from torch.cuda.amp import autocast
class FusedExperts(torch.nn.Module):
    """Mixture-of-Experts layer that applies gated expert outputs for 1- or 2-modality input.

    Holds ``num_local_experts`` deep copies of ``expert`` (a linear-like module:
    it must expose ``.weight`` and ``.bias``) and combines them according to the
    routing indices/gates computed upstream. For top-2 routing, the two selected
    experts' parameters are merged (gate-weighted sum) into a single linear call.
    """

    def __init__(self, expert, cfg, num_local_experts=1):
        """
        Args:
            expert: template module (linear-like) replicated per local expert.
            cfg: config node; only ``cfg.SOLVER.FORCE_EXPERT_ADDING_FP16`` is read here.
            num_local_experts: number of deep copies of ``expert`` to keep.
        """
        super().__init__()
        self.cfg = cfg
        self.deepspeed_experts = torch.nn.ModuleList(
            [copy.deepcopy(expert) for _ in range(num_local_experts)])
        self.num_local_experts = num_local_experts
        # Merge biases only when the template expert actually has one.
        self.bias_merge = self.deepspeed_experts[0].bias is not None

    def top1_expert_forward(self, x, indice, gate, mode=None, **kwargs):
        """Route ``x`` through the single top-1 expert of each modality.

        ``indice``/``gate`` hold one entry per modality (1 or 2). With two
        modalities, ``x`` is split along dim 1 at
        ``kwargs['sample_info']['data_cum_length'][1]`` and each segment goes
        through its own expert; outputs are re-concatenated along dim 1.
        """
        assert mode is None, "unified qkv inference is not supported for top1"
        n_modalities = indice.size(0)
        if n_modalities == 1:
            # unimodal: one expert handles the whole sequence
            return self.deepspeed_experts[indice[0]](x) * gate[0].to(x)
        if n_modalities == 2:
            # multimodal: split at the first modality's cumulative length
            boundary = kwargs['sample_info']['data_cum_length'][1]
            left = self.deepspeed_experts[indice[0]](x[:, :boundary, :]) * gate[0].to(x)
            right = self.deepspeed_experts[indice[1]](x[:, boundary:, :]) * gate[1].to(x)
            return torch.cat([left, right], dim=1)
        raise NotImplementedError('only support one or two modality')

    def _merged_linear(self, x, index1, index2, gate1, gate2, row_slice, to_half):
        """Linear projection whose weight/bias are the gate-weighted sum of two experts.

        ``row_slice`` selects the output rows to use; ``to_half`` casts the
        parameters to fp16 *before* merging (cfg.SOLVER.FORCE_EXPERT_ADDING_FP16
        path of the original implementation).
        """
        e1 = self.deepspeed_experts[index1]
        e2 = self.deepspeed_experts[index2]
        w1, w2 = e1.weight[row_slice], e2.weight[row_slice]
        if to_half:
            w1, w2 = w1.half(), w2.half()
        weight = w1 * gate1 + w2 * gate2
        bias = None
        if self.bias_merge:
            b1, b2 = e1.bias[row_slice], e2.bias[row_slice]
            if to_half:
                b1, b2 = b1.half(), b2.half()
            bias = b1 * gate1 + b2 * gate2
        return F.linear(x, weight, bias=bias)

    def mergelayer(self, x, index1, index2, gate1, gate2, mode=None):
        """Project ``x`` with parameters merged from experts ``index1``/``index2``.

        mode 'q'  -> first third of the weight rows (hidden_states),
        mode 'kv' -> remaining two thirds (history_states),
        otherwise -> the full weight.
        """
        # NOTE(review): 'q'/'kv' assume the expert is a fused qkv projection whose
        # output rows split as [q | k v] at one third -- confirm with callers.
        out_rows = self.deepspeed_experts[index1].weight.shape[0]
        if mode == 'q':
            row_slice = slice(0, out_rows // 3)
        elif mode == 'kv':
            row_slice = slice(out_rows // 3, None)
        else:
            row_slice = slice(None)
        return self._merged_linear(
            x, index1, index2, gate1, gate2, row_slice,
            to_half=self.cfg.SOLVER.FORCE_EXPERT_ADDING_FP16)

    def top2_expert_forward(self, x, indices, gates, mode=None, **kwargs):
        """Route ``x`` through the gate-weighted merge of the top-2 experts per modality."""
        # caption generation: a single generated token uses the second modality slot
        if comm._CAPTION_GEN_MODE and x.shape[1] == 1:
            return self.mergelayer(x, indices[0][1], indices[1][1],
                                   gates[0][1], gates[1][1], mode=mode)
        n_modalities = indices[0].size(0)
        if n_modalities == 1:
            # unimodal
            return self.mergelayer(x, indices[0][0], indices[1][0],
                                   gates[0][0], gates[1][0], mode=mode)
        if n_modalities == 2:
            boundary = kwargs['sample_info']['data_cum_length'][1]
            if mode == 'kv' and kwargs['sample_info'].get('pe_length', 0) > 0:
                # kv states may be prefixed with prompt embeddings
                boundary += kwargs['sample_info'].get('pe_length', 0)
            left = self.mergelayer(x[:, :boundary, :], indices[0][0], indices[1][0],
                                   gates[0][0], gates[1][0], mode=mode)
            right = self.mergelayer(x[:, boundary:, :], indices[0][1], indices[1][1],
                                    gates[0][1], gates[1][1], mode=mode)
            return torch.cat([left, right], dim=1)
        raise NotImplementedError('only support one or two modality')

    def forward(self, hidden_states, top_indices=None, gates=None, **kwargs):
        """Dispatch to top-1 or top-2 expert routing based on ``len(top_indices)``.

        The output keeps the input's sequence length (dim 1).
        """
        k = len(top_indices)
        if k == 1:
            out = self.top1_expert_forward(hidden_states, top_indices[0], gates[0], **kwargs)
        elif k == 2:
            out = self.top2_expert_forward(hidden_states, top_indices, gates, **kwargs)
        else:
            raise NotImplementedError("only support top1 and top2 ")
        assert out.shape[1] == hidden_states.shape[1]
        return out