from functools import partial
from typing import List, Optional, Tuple, Union

import torch
import torch.utils.checkpoint as ckpt

from wenet.transformer.attention import T_CACHE
from wenet.transformer.encoder_layer import TransformerEncoderLayer
from wenet.utils.class_utils import (WENET_ACTIVATION_CLASSES,
                                     WENET_ATTENTION_CLASSES,
                                     WENET_EMB_CLASSES, WENET_MLP_CLASSES,
                                     WENET_NORM_CLASSES)
from wenet.utils.common import mask_to_bias


class DecoderOnly(torch.nn.Module):

    def __init__(
        self,
        n_kv_head: int,
        head_dim: int,
        hidden_size: int,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        normalize_before: bool = True,
        query_bias: bool = False,
        key_bias: bool = False,
        value_bias: bool = False,
        mlp_bias: bool = False,
        activation_type: str = "gelu",
        gelu_approximate: Union[str, None] = None,
        max_position_embeding: int = 8192,
        mlp_type: str = 'gated',
        layer_norm_type: str = 'rms_norm',
        norm_eps: float = 1e-5,
        rms_norm_offset: bool = True,
        selfattention_layer_type: str = "rope_abs_selfattn",
        use_sdpa: bool = False,
        gradient_checkpointing: bool = False,
        rope_theta: float = 10000.0,
        rope_style: str = 'google',
        scale_embed: bool = True,
    ) -> None:
        super().__init__()

        assert selfattention_layer_type in ['rope_abs_selfattn']
        self.pos_enc = WENET_EMB_CLASSES["rope_pos"](
            hidden_size,
            head_dim,
            max_len=max_position_embeding,
            dropout_rate=positional_dropout_rate,
            rope_theta=rope_theta,
            scale=scale_embed)
        if activation_type == "gelu" and gelu_approximate is not None:
            activation = WENET_ACTIVATION_CLASSES['gelu'](
                approximate=gelu_approximate)
        else:
            activation = WENET_ACTIVATION_CLASSES[activation_type]()

        mlp_class = WENET_MLP_CLASSES[mlp_type]
        self.num_blocks = num_blocks
        # TODO: support lora & refactor lora
        self.decoders = torch.nn.ModuleList([
            TransformerEncoderLayer(
                hidden_size,
                WENET_ATTENTION_CLASSES[selfattention_layer_type](
                    attention_heads,
                    hidden_size,
                    attention_dropout_rate,
                    query_bias,
                    key_bias,
                    value_bias,
                    use_sdpa,
                    n_kv_head,
                    head_dim,
                    style=rope_style),
                mlp_class(hidden_size, linear_units, dropout_rate, activation,
                          mlp_bias),
                dropout_rate,
                normalize_before,
                layer_norm_type=layer_norm_type,
                norm_eps=norm_eps,
                rms_norm_offset=rms_norm_offset,
            ) for _ in range(self.num_blocks)
        ])
        self.pre_norm = normalize_before
        self.final_norm: Optional[torch.nn.Module] = None
        if self.pre_norm:
            norm_class = WENET_NORM_CLASSES[layer_norm_type]
            if layer_norm_type == "rms_norm":
                norm_class = partial(
                    norm_class,
                    add_unit_offset=rms_norm_offset,
                )
            self.final_norm = norm_class(hidden_size, eps=norm_eps)
        self.n_kv_head = n_kv_head
        self.head_dim = head_dim
        self._hidden_size = hidden_size
        self.use_sdpa = use_sdpa
        self.gradient_checkpointing = gradient_checkpointing

    def forward(
        self,
        input: torch.Tensor,
        att_mask: torch.Tensor,
        input_position: Union[int, torch.Tensor] = 0,
        kv_caches: Optional[List[T_CACHE]] = None,
    ) -> Tuple[torch.Tensor, Union[List[T_CACHE], None]]:
        xs, pos_emb = self.pos_enc(input, offset=input_position)
        if self.use_sdpa:
            att_mask = mask_to_bias(att_mask, xs.dtype)

        if self.gradient_checkpointing and self.training:
            xs = self.forward_layers_checkpointed(xs, att_mask, pos_emb)
        else:
            xs, kv_caches = self.forward_layers(xs, att_mask, pos_emb,
                                                kv_caches)
        if self.pre_norm and self.final_norm is not None:
            xs = self.final_norm(xs)
        return xs, kv_caches

    def forward_layers(
        self,
        xs: torch.Tensor,
        att_mask: torch.Tensor,
        pos_emb: torch.Tensor,
        kv_caches: Optional[List[T_CACHE]] = None,
    ) -> Tuple[torch.Tensor, Union[List[T_CACHE], None]]:
        if self.training:
            for (i, layer) in enumerate(self.decoders):
                xs, _, _, _ = layer(xs, att_mask, pos_emb)
            new_kv_caches = kv_caches
        else:
            assert kv_caches is not None
            new_kv_caches = []
            for (i, layer) in enumerate(self.decoders):
                xs, _, new_kv_cache, _ = layer(
                    xs,
                    att_mask,
                    pos_emb,
                    att_cache=(kv_caches[i][0], kv_caches[i][1]))
                new_kv_caches.append(new_kv_cache)

        return xs, new_kv_caches

    @torch.jit.ignore(drop=True)
    def forward_layers_checkpointed(self, xs: torch.Tensor,
                                    att_mask: torch.Tensor,
                                    pos_emb: torch.Tensor) -> torch.Tensor:
        for layer in self.decoders:
            xs, _, _, _ = ckpt.checkpoint(layer.__call__, xs, att_mask,
                                          pos_emb)
        return xs

    @property
    def hidden_size(self):
        return self._hidden_size
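

if __name__ == "__main__":
    # Minimal usage sketch (not part of the upstream module): one
    # training-style forward pass and one cached inference pass.  The sizes
    # below and the zero-sized "empty" per-layer KV caches are illustrative
    # assumptions, not values from any released WeNet config; `input` is
    # expected to be already-embedded hidden states of shape
    # (batch, time, hidden_size), since the token embedding lives outside
    # this module.
    decoder = DecoderOnly(n_kv_head=2,
                          head_dim=64,
                          hidden_size=256,
                          attention_heads=4,
                          linear_units=1024,
                          num_blocks=2)

    batch, seq_len = 1, 8
    embeds = torch.rand(batch, seq_len, decoder.hidden_size)
    # Causal (lower-triangular) self-attention mask, shape (batch, time, time).
    att_mask = torch.tril(torch.ones(batch, seq_len, seq_len)).bool()

    # Training path: no KV cache is built, the second return value stays None.
    decoder.train()
    out, _ = decoder(embeds, att_mask)
    print(out.shape)  # torch.Size([1, 8, 256])

    # Inference path: pass one (key, value) cache per block.  Zero-sized
    # tensors are assumed to mean "no cache yet", mirroring the att_cache
    # default of wenet's TransformerEncoderLayer.
    decoder.eval()
    with torch.no_grad():
        kv_caches = [(torch.zeros(0, 0, 0, 0), torch.zeros(0, 0, 0, 0))
                     for _ in range(decoder.num_blocks)]
        out, kv_caches = decoder(embeds,
                                 att_mask,
                                 input_position=0,
                                 kv_caches=kv_caches)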