# --------------------------------------------------------
# ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621)
# Github source: https://github.com/mbzuai-nlp/ArTST
# Based on speecht5, fairseq and espnet code bases
# https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
# --------------------------------------------------------
import math
import torch.nn as nn
import torch
import contextlib
from fairseq.modules.quant_noise import quant_noise as apply_quant_noise_
from fairseq.models.transformer import Linear #,LayerNorm
from fairseq.modules import (
    PositionalEmbedding,
    FairseqDropout,
    LayerNorm,
)


class TextDecoderPrenet(nn.Module):
    """Text decoder prenet.

    Embeds the previous output tokens, optionally projects them to the
    decoder dimension, and adds positional embeddings before passing the
    result to the shared decoder.

    Args:
        embed_tokens (nn.Embedding): token embedding table shared with the decoder
        args (argparse.Namespace): model configuration (dropout, embedding
            dimensions, positional-embedding settings, freezing schedule, etc.)
    """

    def __init__(self, embed_tokens, args):
        super(TextDecoderPrenet, self).__init__()
        self.dropout_module = FairseqDropout(
            args.dropout, module_name=self.__class__.__name__
        )
        self.decoder_layerdrop = args.decoder_layerdrop
        self.num_updates = 0

        input_embed_dim = embed_tokens.embedding_dim
        embed_dim = args.decoder_embed_dim
        self.embed_dim = embed_dim
        self.output_embed_dim = args.decoder_output_dim

        self.padding_idx = embed_tokens.padding_idx
        self.embed_tokens = embed_tokens
        self.embed_scale = 1.0 if args.no_scale_embedding else math.sqrt(embed_dim)

        # Quantization noise (iPQ) applied to the embedded tokens, if enabled.
        if not args.adaptive_input and args.quant_noise_pq > 0:
            self.quant_noise = apply_quant_noise_(
                nn.Linear(embed_dim, embed_dim, bias=False),
                args.quant_noise_pq,
                args.quant_noise_pq_block_size,
            )
        else:
            self.quant_noise = None

        # Project token embeddings to the decoder dimension if they differ.
        self.project_in_dim = (
            Linear(input_embed_dim, embed_dim, bias=False)
            if embed_dim != input_embed_dim
            else None
        )

        self.embed_positions = (
            PositionalEmbedding(
                args.max_text_positions,
                embed_dim,
                self.padding_idx,
                learned=args.decoder_learned_pos,
            )
            if not args.no_token_positional_embeddings
            else None
        )

        export = getattr(args, "export", False)
        if getattr(args, "layernorm_embedding", False):
            self.layernorm_embedding = LayerNorm(embed_dim, export=export)
        else:
            self.layernorm_embedding = None

        # Run the prenet without gradients for the first
        # `freeze_decoder_updates` training updates.
        self.freeze_decoder_updates = args.freeze_decoder_updates

    def forward(self, prev_output_tokens, incremental_state=None):
        # Only fine-tune (track gradients) once enough updates have been seen.
        ft = self.freeze_decoder_updates <= self.num_updates
        with torch.no_grad() if not ft else contextlib.ExitStack():
            return self._forward(prev_output_tokens, incremental_state)

    def _forward(self, prev_output_tokens, incremental_state=None):
        # Padding mask over the target tokens, if any padding is present.
        if prev_output_tokens.eq(self.padding_idx).any():
            x_mask = prev_output_tokens.eq(self.padding_idx)
        else:
            x_mask = None

        # embed positions
        positions = None
        if self.embed_positions is not None:
            positions = self.embed_positions(
                prev_output_tokens, incremental_state=incremental_state
            )

        # During incremental decoding only the last token is embedded.
        if incremental_state is not None:
            prev_output_tokens = prev_output_tokens[:, -1:]
            if positions is not None:
                positions = positions[:, -1:]

        # embed tokens and positions
        x = self.embed_scale * self.embed_tokens(prev_output_tokens)

        if self.quant_noise is not None:
            x = self.quant_noise(x)

        if self.project_in_dim is not None:
            x = self.project_in_dim(x)

        if positions is not None:
            x += positions

        if self.layernorm_embedding is not None:
            x = self.layernorm_embedding(x)

        x = self.dropout_module(x)

        return x, x_mask, incremental_state

    def set_num_updates(self, num_updates):
        """Set the number of parameter updates seen so far."""
        self.num_updates = num_updates
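

# --------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). The hyperparameter
# values below are illustrative assumptions rather than ArTST's actual
# configuration; they only exercise TextDecoderPrenet's forward pass.
# --------------------------------------------------------------------------
if __name__ == "__main__":
    from argparse import Namespace

    args = Namespace(
        dropout=0.1,
        decoder_layerdrop=0.0,
        decoder_embed_dim=768,
        decoder_output_dim=768,
        no_scale_embedding=False,
        adaptive_input=False,
        quant_noise_pq=0.0,
        quant_noise_pq_block_size=8,
        max_text_positions=600,
        decoder_learned_pos=False,
        no_token_positional_embeddings=False,
        layernorm_embedding=False,
        freeze_decoder_updates=0,
    )
    # Token embedding table shared with the decoder; padding index 1 is assumed.
    embed_tokens = nn.Embedding(1000, args.decoder_embed_dim, padding_idx=1)
    prenet = TextDecoderPrenet(embed_tokens, args)

    prev_output_tokens = torch.randint(2, 1000, (2, 10))  # (batch, tgt_len)
    x, x_mask, _ = prenet(prev_output_tokens)
    print(x.shape)  # expected: torch.Size([2, 10, 768])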