'''
Copyright 2020 The Microsoft DeepSpeed Team
'''
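
# Fused mixture-of-experts expert container (adapted from DeepSpeed's MoE
# experts): whole modality segments are routed to experts rather than single
# tokens, and for top-2 routing the two selected experts' weights are merged
# into one linear layer before the matmul.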
import copy

import torch
import torch.nn.functional as F
from torch.cuda.amp import autocast

from .gate import one_hot_with_dtype
from uniperceiver.utils import comm


class FusedExperts(torch.nn.Module):
    def __init__(self, expert, cfg, num_local_experts=1):
        super(FusedExperts, self).__init__()
        self.cfg = cfg
        # Each local expert is an independent deep copy of the prototype module.
        self.deepspeed_experts = torch.nn.ModuleList(
            [copy.deepcopy(expert) for _ in range(num_local_experts)])
        self.num_local_experts = num_local_experts
        # Merge biases only when the expert linear layers actually have one.
        self.bias_merge = self.deepspeed_experts[0].bias is not None

    def top1_expert_forward(self, x, indice, gate, mode=None, **kwargs):
        assert mode is None, "unified qkv inference is not supported for top1"
        if indice.size(0) == 1:
            # unimodal: the whole sequence goes to a single expert
            x = self.deepspeed_experts[indice[0]](x) * gate[0].to(x)
        elif indice.size(0) == 2:
            # multimodal: split the sequence at the end of the first modality
            data1_length = kwargs['sample_info']['data_cum_length'][1]
            x = torch.cat([
                self.deepspeed_experts[indice[0]](x[:, :data1_length, :]) * gate[0].to(x),
                self.deepspeed_experts[indice[1]](x[:, data1_length:, :]) * gate[1].to(x)
            ], dim=1)
        else:
            raise NotImplementedError('only one or two modalities are supported')
        return x
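
    # Example of the multimodal split above, with hypothetical lengths: given
    # sample_info = {'data_cum_length': [0, 49, 79]}, the first routed expert
    # sees tokens [:, :49, :] (e.g. the image segment) and the second sees
    # tokens [:, 49:, :] (e.g. the text segment).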

    def mergelayer(self, x, index1, index2, gate1, gate2, mode=None):
        # Run a single F.linear whose weight and bias are the gate-weighted
        # sums of two experts' parameters (see the note after this method).
        w1 = self.deepspeed_experts[index1].weight
        w2 = self.deepspeed_experts[index2].weight
        b1 = self.deepspeed_experts[index1].bias
        b2 = self.deepspeed_experts[index2].bias

        # For fused qkv projections: 'q' (hidden_states) uses the first third
        # of the output rows, 'kv' (history_states) the remaining two thirds.
        if mode == 'q':
            rows = slice(0, w1.shape[0] // 3)
        elif mode == 'kv':
            rows = slice(w1.shape[0] // 3, None)
        else:
            rows = slice(None)

        w1, w2 = w1[rows, :], w2[rows, :]
        if self.bias_merge:
            b1, b2 = b1[rows], b2[rows]

        if self.cfg.SOLVER.FORCE_EXPERT_ADDING_FP16:
            # Optionally force the parameter merge to run in fp16.
            w1, w2 = w1.half(), w2.half()
            if self.bias_merge:
                b1, b2 = b1.half(), b2.half()

        return F.linear(
            x,
            w1 * gate1 + w2 * gate2,
            bias=b1 * gate1 + b2 * gate2 if self.bias_merge else None,
        )
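
    # Why merging is valid: F.linear is linear in its parameters, so for
    # scalar gates g1, g2 and experts (W1, b1), (W2, b2):
    #   F.linear(x, g1 * W1 + g2 * W2, g1 * b1 + g2 * b2)
    #     == g1 * F.linear(x, W1, b1) + g2 * F.linear(x, W2, b2)
    # i.e. one matmul over merged weights replaces two full expert passes.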

    def top2_expert_forward(self, x, indices, gates, mode=None, **kwargs):
        # caption generation mode: only a single (text) token is processed,
        # so the whole input is routed as the second modality
        if comm._CAPTION_GEN_MODE and x.shape[1] == 1:
            return self.mergelayer(x,
                                   indices[0][1], indices[1][1],
                                   gates[0][1], gates[1][1], mode=mode)
        if indices[0].size(0) == 1:
            # unimodal
            x = self.mergelayer(x, indices[0][0], indices[1][0], gates[0][0], gates[1][0], mode=mode)
        elif indices[0].size(0) == 2:
            # multimodal: split the sequence at the end of the first modality
            data1_length = kwargs['sample_info']['data_cum_length'][1]
            if mode == 'kv' and kwargs['sample_info'].get('pe_length', 0) > 0:
                # the kv sequence may be prefixed with prompt embeddings
                data1_length += kwargs['sample_info'].get('pe_length', 0)
            x = torch.cat([
                self.mergelayer(x[:, :data1_length, :], indices[0][0], indices[1][0], gates[0][0], gates[1][0], mode=mode),
                self.mergelayer(x[:, data1_length:, :], indices[0][1], indices[1][1], gates[0][1], gates[1][1], mode=mode)
            ], dim=1)
        else:
            raise NotImplementedError('only one or two modalities are supported')
        return x

    def forward(self, hidden_states, top_indices=None, gates=None, **kwargs):
        if len(top_indices) == 1:
            # top-1 routing
            out = self.top1_expert_forward(hidden_states, top_indices[0], gates[0], **kwargs)
        elif len(top_indices) == 2:
            # top-2 routing
            out = self.top2_expert_forward(hidden_states, top_indices, gates, **kwargs)
        else:
            raise NotImplementedError("only top1 and top2 routing are supported")
        # routing must not change the sequence length
        assert out.shape[1] == hidden_states.shape[1]
        return out
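

if __name__ == "__main__":
    # Minimal smoke-test sketch (not part of the original module): assumes
    # nn.Linear experts and a SimpleNamespace stand-in for the real cfg object.
    from types import SimpleNamespace

    cfg = SimpleNamespace(SOLVER=SimpleNamespace(FORCE_EXPERT_ADDING_FP16=False))
    experts = FusedExperts(torch.nn.Linear(8, 8), cfg, num_local_experts=2)
    x = torch.randn(2, 5, 8)  # (batch, seq_len, hidden)

    # top-1 routing, unimodal: every token goes to expert 0 with weight 1.0
    out = experts(x, top_indices=[torch.tensor([0])], gates=[torch.tensor([1.0])])
    assert out.shape == x.shape

    # top-1 routing, multimodal: tokens [:, :3] to expert 0, [:, 3:] to expert 1
    out = experts(x,
                  top_indices=[torch.tensor([0, 1])],
                  gates=[torch.tensor([1.0, 1.0])],
                  sample_info={'data_cum_length': [0, 3, 5]})
    assert out.shape == x.shape

    # top-2 routing, unimodal: blend experts 0 and 1 with gates 0.6 / 0.4
    out = experts(x,
                  top_indices=[torch.tensor([0]), torch.tensor([1])],
                  gates=[torch.tensor([0.6]), torch.tensor([0.4])])
    assert out.shape == x.shape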