from functools import partial
from typing import List, Optional, Tuple, Union
import torch
import torch.utils.checkpoint as ckpt
from wenet.transformer.attention import T_CACHE
from wenet.transformer.encoder_layer import TransformerEncoderLayer
from wenet.utils.class_utils import (WENET_ACTIVATION_CLASSES,
WENET_ATTENTION_CLASSES,
WENET_EMB_CLASSES, WENET_MLP_CLASSES,
WENET_NORM_CLASSES)
from wenet.utils.common import mask_to_bias


class DecoderOnly(torch.nn.Module):
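    """Decoder-only Transformer backbone (pre-norm, RoPE, optional GQA).

    Stacks ``num_blocks`` Transformer layers that use rotary position
    embeddings; ``n_kv_head`` and ``head_dim`` are forwarded to the
    attention class, which allows grouped-query attention when
    ``n_kv_head`` is smaller than ``attention_heads``.
    """
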
def __init__(
self,
n_kv_head: int,
head_dim: int,
hidden_size: int,
attention_heads: int = 4,
linear_units: int = 2048,
num_blocks: int = 6,
dropout_rate: float = 0.1,
positional_dropout_rate: float = 0.1,
attention_dropout_rate: float = 0.0,
normalize_before: bool = True,
query_bias: bool = False,
key_bias: bool = False,
value_bias: bool = False,
mlp_bias: bool = False,
activation_type: str = "gelu",
gelu_approximate: Union[str, None] = None,
max_position_embeding: int = 8192,
mlp_type: str = 'gated',
layer_norm_type: str = 'rms_norm',
norm_eps: float = 1e-5,
rms_norm_offset: bool = True,
selfattention_layer_type: str = "rope_abs_selfattn",
use_sdpa: bool = False,
gradient_checkpointing: bool = False,
rope_theta: float = 10000.0,
rope_style: str = 'google',
scale_embed: bool = True,
) -> None:
super().__init__()
assert selfattention_layer_type in ['rope_abs_selfattn']
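        # Rotary positional encoding: returns the (possibly scaled) input
        # together with the rotary table used inside each attention layer.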
self.pos_enc = WENET_EMB_CLASSES["rope_pos"](
hidden_size,
head_dim,
max_len=max_position_embeding,
dropout_rate=positional_dropout_rate,
rope_theta=rope_theta,
scale=scale_embed)
if activation_type == "gelu" and gelu_approximate is not None:
activation = WENET_ACTIVATION_CLASSES['gelu'](
approximate=gelu_approximate)
else:
activation = WENET_ACTIVATION_CLASSES[activation_type]()
mlp_class = WENET_MLP_CLASSES[mlp_type]
self.num_blocks = num_blocks
# TODO: support lora & refactor lora
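        # Each block is a standard pre-norm Transformer layer; passing
        # n_kv_head/head_dim to the attention class enables grouped-query
        # attention when n_kv_head < attention_heads.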
self.decoders = torch.nn.ModuleList([
TransformerEncoderLayer(
hidden_size,
WENET_ATTENTION_CLASSES[selfattention_layer_type](
attention_heads,
hidden_size,
attention_dropout_rate,
query_bias,
key_bias,
value_bias,
use_sdpa,
n_kv_head,
head_dim,
style=rope_style),
mlp_class(hidden_size, linear_units, dropout_rate, activation,
mlp_bias),
dropout_rate,
normalize_before,
layer_norm_type=layer_norm_type,
norm_eps=norm_eps,
rms_norm_offset=rms_norm_offset,
) for _ in range(self.num_blocks)
])
self.pre_norm = normalize_before
self.final_norm: Optional[torch.nn.Module] = None
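        # With pre-norm layers the residual stream is never normalized at
        # the output, so apply one final norm after the last block.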
if self.pre_norm:
norm_class = WENET_NORM_CLASSES[layer_norm_type]
if layer_norm_type == "rms_norm":
norm_class = partial(
norm_class,
add_unit_offset=rms_norm_offset,
)
self.final_norm = norm_class(hidden_size, eps=norm_eps)
self.n_kv_head = n_kv_head
self.head_dim = head_dim
self._hidden_size = hidden_size
self.use_sdpa = use_sdpa
        self.gradient_checkpointing = gradient_checkpointing

    def forward(
self,
input: torch.Tensor,
att_mask: torch.Tensor,
input_position: Union[int, torch.Tensor] = 0,
kv_caches: Optional[List[T_CACHE]] = None,
) -> Tuple[torch.Tensor, Union[List[T_CACHE], None]]:
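        """Run the decoder stack.

        Args:
            input: (batch, time, hidden_size) input embeddings.
            att_mask: boolean attention mask; converted to an additive
                bias when ``use_sdpa`` is set.
            input_position: rotary position offset of the first frame,
                used during incremental decoding.
            kv_caches: per-layer key/value caches; required in eval mode,
                ignored during training.
        """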
xs, pos_emb = self.pos_enc(input, offset=input_position)
if self.use_sdpa:
att_mask = mask_to_bias(att_mask, xs.dtype)
if self.gradient_checkpointing and self.training:
xs = self.forward_layers_checkpointed(xs, att_mask, pos_emb)
else:
xs, kv_caches = self.forward_layers(xs, att_mask, pos_emb,
kv_caches)
if self.pre_norm and self.final_norm is not None:
xs = self.final_norm(xs)
        return xs, kv_caches

    def forward_layers(
self,
xs: torch.Tensor,
att_mask: torch.Tensor,
pos_emb: torch.Tensor,
kv_caches: Optional[List[T_CACHE]] = None,
) -> Tuple[torch.Tensor, Union[List[T_CACHE], None]]:
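        # Training runs all layers without caches; eval threads each
        # layer's key/value cache through and collects the updates.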
if self.training:
for (i, layer) in enumerate(self.decoders):
xs, _, _, _ = layer(xs, att_mask, pos_emb)
new_kv_caches = kv_caches
else:
assert kv_caches is not None
new_kv_caches = []
for (i, layer) in enumerate(self.decoders):
xs, _, new_kv_cache, _ = layer(xs,
att_mask,
pos_emb,
att_cache=(kv_caches[i][0],
kv_caches[i][1]))
new_kv_caches.append(new_kv_cache)
        return xs, new_kv_caches

    @torch.jit.ignore(drop=True)
def forward_layers_checkpointed(self, xs: torch.Tensor,
att_mask: torch.Tensor,
pos_emb: torch.Tensor) -> torch.Tensor:
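        # Recompute each layer's activations during backward to save
        # memory; this path never produces kv caches (training only).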
for layer in self.decoders:
            xs, _, _, _ = ckpt.checkpoint(layer.__call__,
                                          xs,
                                          att_mask,
                                          pos_emb,
                                          use_reentrant=False)
        return xs

    @property
def hidden_size(self):
return self._hidden_size
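

# Hypothetical usage sketch, not part of the original module: a minimal
# smoke test assuming wenet's registered attention/MLP/norm classes are
# importable; all shapes and hyperparameters below are illustrative.
if __name__ == "__main__":
    model = DecoderOnly(n_kv_head=2,
                        head_dim=64,
                        hidden_size=256,
                        attention_heads=4,
                        linear_units=1024,
                        num_blocks=2)
    model.train()
    xs = torch.randn(1, 16, 256)  # (batch, time, hidden_size)
    # Lower-triangular causal mask: position t attends to positions <= t.
    att_mask = torch.tril(torch.ones(1, 16, 16, dtype=torch.bool))
    out, _ = model(xs, att_mask)  # kv_caches are unused in training mode
    print(out.shape)  # torch.Size([1, 16, 256])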