.
+ kernel_size (int): The convolution kernel size of the middle layer.
+ padding (int): Padding value of the convolution in the middle layer.
+ dilation (int, optional): Dilation value of the convolution in the middle layer.
+ no_residual (bool, optional): Disable residual block/output.
+
+ Note:
+ This implementation corresponds to the "non-causal" setting in the paper.
+ """
+
+ def __init__(
+ self,
+ io_channels: int,
+ hidden_channels: int,
+ kernel_size: int,
+ padding: int,
+ dilation: int = 1,
+ no_residual: bool = False,
+ ):
+ super().__init__()
+
+ self.conv_layers = torch.nn.Sequential(
+ torch.nn.Conv1d(in_channels=io_channels, out_channels=hidden_channels, kernel_size=1),
+ torch.nn.PReLU(),
+ torch.nn.GroupNorm(num_groups=1, num_channels=hidden_channels, eps=1e-08),
+ torch.nn.Conv1d(
+ in_channels=hidden_channels,
+ out_channels=hidden_channels,
+ kernel_size=kernel_size,
+ padding=padding,
+ dilation=dilation,
+ groups=hidden_channels,
+ ),
+ torch.nn.PReLU(),
+ torch.nn.GroupNorm(num_groups=1, num_channels=hidden_channels, eps=1e-08),
+ )
+
+ self.res_out = (
+ None
+ if no_residual
+ else torch.nn.Conv1d(in_channels=hidden_channels, out_channels=io_channels, kernel_size=1)
+ )
+ self.skip_out = torch.nn.Conv1d(in_channels=hidden_channels, out_channels=io_channels, kernel_size=1)
+
+ def forward(self, input: torch.Tensor) -> Tuple[Optional[torch.Tensor], torch.Tensor]:
+ feature = self.conv_layers(input)
+ if self.res_out is None:
+ residual = None
+ else:
+ residual = self.res_out(feature)
+ skip_out = self.skip_out(feature)
+ return residual, skip_out
+
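+# Shape sketch (illustrative, not part of the module): with io_channels=128 and
+# hidden_channels=512, an input of shape (batch, 128, frames) yields
+#   residual: (batch, 128, frames)  -- added back onto the block input by the caller
+#   skip_out: (batch, 128, frames)  -- summed over all blocks by MaskGenerator
+# For the final block (no_residual=True), residual is None.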
+
+class MaskGenerator(torch.nn.Module):
+ """TCN (Temporal Convolution Network) Separation Module
+
+ Generates masks for separation.
+
+ Args:
+ input_dim (int): Input feature dimension.
+ num_sources (int): The number of sources to separate.
+ kernel_size (int): The convolution kernel size of conv blocks.
+ num_feats (int): Input/output feature dimension of conv blocks.
+ num_hidden (int): Intermediate feature dimension of conv blocks.
+ num_layers (int): The number of conv blocks in one stack.
+ num_stacks (int): The number of conv block stacks.
+ msk_activate (str): The activation function of the mask output.
+
+ Note:
+ This implementation corresponds to the "non-causal" setting in the paper.
+ """
+
+ def __init__(
+ self,
+ input_dim: int,
+ num_sources: int,
+ kernel_size: int,
+ num_feats: int,
+ num_hidden: int,
+ num_layers: int,
+ num_stacks: int,
+ msk_activate: str,
+ ):
+ super().__init__()
+
+ self.input_dim = input_dim
+ self.num_sources = num_sources
+
+ self.input_norm = torch.nn.GroupNorm(num_groups=1, num_channels=input_dim, eps=1e-8)
+ self.input_conv = torch.nn.Conv1d(in_channels=input_dim, out_channels=num_feats, kernel_size=1)
+
+ self.receptive_field = 0
+ self.conv_layers = torch.nn.ModuleList([])
+ for s in range(num_stacks):
+ for l in range(num_layers):
+ multi = 2**l
+ self.conv_layers.append(
+ ConvBlock(
+ io_channels=num_feats,
+ hidden_channels=num_hidden,
+ kernel_size=kernel_size,
+ dilation=multi,
+ padding=multi,
+ # The last ConvBlock does not need residual
+ no_residual=(l == (num_layers - 1) and s == (num_stacks - 1)),
+ )
+ )
+ self.receptive_field += kernel_size if s == 0 and l == 0 else (kernel_size - 1) * multi
+ self.output_prelu = torch.nn.PReLU()
+ self.output_conv = torch.nn.Conv1d(
+ in_channels=num_feats,
+ out_channels=input_dim * num_sources,
+ kernel_size=1,
+ )
+ if msk_activate == "sigmoid":
+ self.mask_activate = torch.nn.Sigmoid()
+ elif msk_activate == "relu":
+ self.mask_activate = torch.nn.ReLU()
+ else:
+ raise ValueError(f"Unsupported activation {msk_activate}")
+
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
+ """Generate separation mask.
+
+ Args:
+ input (torch.Tensor): 3D Tensor with shape [batch, features, frames]
+
+ Returns:
+ Tensor: shape [batch, num_sources, features, frames]
+ """
+ batch_size = input.shape[0]
+ feats = self.input_norm(input)
+ feats = self.input_conv(feats)
+ output = 0.0
+ for layer in self.conv_layers:
+ residual, skip = layer(feats)
+ if residual is not None: # the last conv layer does not produce residual
+ feats = feats + residual
+ output = output + skip
+ output = self.output_prelu(output)
+ output = self.output_conv(output)
+ output = self.mask_activate(output)
+ return output.view(batch_size, self.num_sources, self.input_dim, -1)
+
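+# Worked example (illustrative, using the defaults of ConvTasNet below: kernel_size=3,
+# num_layers=8, num_stacks=3): the first block adds kernel_size frames to the receptive
+# field and every later block with dilation 2**l adds (kernel_size - 1) * 2**l, so
+#   receptive_field = 3 + 2 * (2 + 4 + ... + 128) + 2 * 2 * (1 + 2 + ... + 128)
+#                   = 3 + 508 + 1020 = 1531 feature frames.
+# The returned mask of shape (batch, num_sources, features, frames) is multiplied
+# element-wise with the encoder output in ConvTasNet.forward.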
+
+class ConvTasNet(torch.nn.Module):
+ """Conv-TasNet architecture introduced in
+ *Conv-TasNet: Surpassing Ideal Time–Frequency Magnitude Masking for Speech Separation*
+ :cite:`Luo_2019`.
+
+ Note:
+ This implementation corresponds to the "non-causal" setting in the paper.
+
+ See Also:
+ * :class:`torchaudio.pipelines.SourceSeparationBundle`: Source separation pipeline with pre-trained models.
+
+ Args:
+ num_sources (int, optional): The number of sources to split.
+ enc_kernel_size (int, optional): The convolution kernel size of the encoder/decoder.
+ enc_num_feats (int, optional): The feature dimensions passed to mask generator.
+ msk_kernel_size (int, optional): The convolution kernel size of the mask generator.
+ msk_num_feats (int, optional): The input/output feature dimension of conv block in the mask generator.
+ msk_num_hidden_feats (int, optional): The internal feature dimension of conv block of the mask generator.
+ msk_num_layers (int, optional): The number of layers in one conv block of the mask generator.
+ msk_num_stacks (int, optional): The number of conv blocks of the mask generator.
+ msk_activate (str, optional): The activation function of the mask output (Default: ``sigmoid``).
+ """
+
+ def __init__(
+ self,
+ num_sources: int = 2,
+ # encoder/decoder parameters
+ enc_kernel_size: int = 16,
+ enc_num_feats: int = 512,
+ # mask generator parameters
+ msk_kernel_size: int = 3,
+ msk_num_feats: int = 128,
+ msk_num_hidden_feats: int = 512,
+ msk_num_layers: int = 8,
+ msk_num_stacks: int = 3,
+ msk_activate: str = "sigmoid",
+ ):
+ super().__init__()
+
+ self.num_sources = num_sources
+ self.enc_num_feats = enc_num_feats
+ self.enc_kernel_size = enc_kernel_size
+ self.enc_stride = enc_kernel_size // 2
+
+ self.encoder = torch.nn.Conv1d(
+ in_channels=1,
+ out_channels=enc_num_feats,
+ kernel_size=enc_kernel_size,
+ stride=self.enc_stride,
+ padding=self.enc_stride,
+ bias=False,
+ )
+ self.mask_generator = MaskGenerator(
+ input_dim=enc_num_feats,
+ num_sources=num_sources,
+ kernel_size=msk_kernel_size,
+ num_feats=msk_num_feats,
+ num_hidden=msk_num_hidden_feats,
+ num_layers=msk_num_layers,
+ num_stacks=msk_num_stacks,
+ msk_activate=msk_activate,
+ )
+ self.decoder = torch.nn.ConvTranspose1d(
+ in_channels=enc_num_feats,
+ out_channels=1,
+ kernel_size=enc_kernel_size,
+ stride=self.enc_stride,
+ padding=self.enc_stride,
+ bias=False,
+ )
+
+ def _align_num_frames_with_strides(self, input: torch.Tensor) -> Tuple[torch.Tensor, int]:
+ """Pad input Tensor so that the end of the input tensor corresponds with
+
+ 1. (if kernel size is odd) the center of the last convolution kernel
+ or 2. (if kernel size is even) the end of the first half of the last convolution kernel
+
+ Assumption:
+ The resulting Tensor will be padded with the size of stride (== kernel_width // 2)
+ on both ends in Conv1D
+
+ |<--- k_1 --->|
+ | | |<-- k_n-1 -->|
+ | | | |<--- k_n --->|
+ | | | | |
+ | | | | |
+ | v v v |
+ |<---->|<--- input signal --->|<--->|<---->|
+ stride PAD stride
+
+ Args:
+ input (torch.Tensor): 3D Tensor with shape (batch_size, channels==1, frames)
+
+ Returns:
+ Tensor: Padded Tensor
+ int: Number of paddings performed
+ """
+ batch_size, num_channels, num_frames = input.shape
+ is_odd = self.enc_kernel_size % 2
+ num_strides = (num_frames - is_odd) // self.enc_stride
+ num_remainings = num_frames - (is_odd + num_strides * self.enc_stride)
+ if num_remainings == 0:
+ return input, 0
+
+ num_paddings = self.enc_stride - num_remainings
+ pad = torch.zeros(
+ batch_size,
+ num_channels,
+ num_paddings,
+ dtype=input.dtype,
+ device=input.device,
+ )
+ return torch.cat([input, pad], 2), num_paddings
+
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
+ """Perform source separation. Generate audio source waveforms.
+
+ Args:
+ input (torch.Tensor): 3D Tensor with shape [batch, channel==1, frames]
+
+ Returns:
+ Tensor: 3D Tensor with shape [batch, channel==num_sources, frames]
+ """
+ if input.ndim != 3 or input.shape[1] != 1:
+ raise ValueError(f"Expected 3D tensor (batch, channel==1, frames). Found: {input.shape}")
+
+ # B: batch size
+ # L: input frame length
+ # L': padded input frame length
+ # F: feature dimension
+ # M: feature frame length
+ # S: number of sources
+
+ padded, num_pads = self._align_num_frames_with_strides(input) # B, 1, L'
+ batch_size, num_padded_frames = padded.shape[0], padded.shape[2]
+ feats = self.encoder(padded) # B, F, M
+ masked = self.mask_generator(feats) * feats.unsqueeze(1) # B, S, F, M
+ masked = masked.view(batch_size * self.num_sources, self.enc_num_feats, -1) # B*S, F, M
+ decoded = self.decoder(masked) # B*S, 1, L'
+ output = decoded.view(batch_size, self.num_sources, num_padded_frames) # B, S, L'
+ if num_pads > 0:
+ output = output[..., :-num_pads] # B, S, L
+ return output
+
+
+def conv_tasnet_base(num_sources: int = 2) -> ConvTasNet:
+ r"""Builds non-causal version of :class:`~torchaudio.models.ConvTasNet`.
+
+ The parameter settings follow the ones with the highest Si-SNR metric score in the paper,
+ except the mask activation function is changed from "sigmoid" to "relu" for performance improvement.
+
+ Args:
+ num_sources (int, optional): Number of sources in the output.
+ (Default: 2)
+ Returns:
+ ConvTasNet:
+ ConvTasNet model.
+ """
+ return ConvTasNet(
+ num_sources=num_sources,
+ enc_kernel_size=16,
+ enc_num_feats=512,
+ msk_kernel_size=3,
+ msk_num_feats=128,
+ msk_num_hidden_feats=512,
+ msk_num_layers=8,
+ msk_num_stacks=3,
+ msk_activate="relu",
+ )
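+
+
+# Usage sketch (illustrative only, not part of the module): separate a batch of
+# single-channel mixtures into num_sources estimated source waveforms.
+#
+#   model = conv_tasnet_base(num_sources=2)
+#   mixture = torch.rand(4, 1, 32000)   # hypothetical batch of 2-second clips at 16 kHz
+#   estimates = model(mixture)          # -> shape (4, 2, 32000)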
diff --git a/MLPY/Lib/site-packages/torchaudio/models/decoder/__init__.py b/MLPY/Lib/site-packages/torchaudio/models/decoder/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3b13d3b9e3567347ab494fddd3ee2b0106fec22a
--- /dev/null
+++ b/MLPY/Lib/site-packages/torchaudio/models/decoder/__init__.py
@@ -0,0 +1,46 @@
+_CTC_DECODERS = [
+ "CTCHypothesis",
+ "CTCDecoder",
+ "CTCDecoderLM",
+ "CTCDecoderLMState",
+ "ctc_decoder",
+ "download_pretrained_files",
+]
+_CUDA_CTC_DECODERS = [
+ "CUCTCDecoder",
+ "CUCTCHypothesis",
+ "cuda_ctc_decoder",
+]
+
+
+def __getattr__(name: str):
+ if name in _CTC_DECODERS:
+ try:
+ from . import _ctc_decoder
+ except Exception as err:
+ raise RuntimeError(
+ "CTC Decoder suit requires flashlight-text package and optionally KenLM. Please install them."
+ ) from err
+
+ item = getattr(_ctc_decoder, name)
+ globals()[name] = item
+ return item
+ elif name in _CUDA_CTC_DECODERS:
+ try:
+ from . import _cuda_ctc_decoder
+ except AttributeError as err:
+ raise RuntimeError(
+ "To use CUCTC decoder, please set BUILD_CUDA_CTC_DECODER=1 when building from source."
+ ) from err
+
+ item = getattr(_cuda_ctc_decoder, name)
+ globals()[name] = item
+ return item
+ raise AttributeError(f"module {__name__} has no attribute {name}")
+
+
+def __dir__():
+ return sorted(__all__)
+
+
+__all__ = _CTC_DECODERS + _CUDA_CTC_DECODERS
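+
+# Note (illustrative): the module-level ``__getattr__`` above implements PEP 562 lazy
+# imports, so the optional decoder dependencies are only imported when one of the
+# exported names is first accessed, e.g.
+#
+#   from torchaudio.models.decoder import ctc_decoder   # triggers the lazy import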
diff --git a/MLPY/Lib/site-packages/torchaudio/models/decoder/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torchaudio/models/decoder/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..07527eb57d9705989d8282a8b9ca30caa1bfb1f2
Binary files /dev/null and b/MLPY/Lib/site-packages/torchaudio/models/decoder/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torchaudio/models/decoder/__pycache__/_ctc_decoder.cpython-39.pyc b/MLPY/Lib/site-packages/torchaudio/models/decoder/__pycache__/_ctc_decoder.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..20988170f7b1b08f334b5f8dbc9b061b4d186244
Binary files /dev/null and b/MLPY/Lib/site-packages/torchaudio/models/decoder/__pycache__/_ctc_decoder.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torchaudio/models/decoder/__pycache__/_cuda_ctc_decoder.cpython-39.pyc b/MLPY/Lib/site-packages/torchaudio/models/decoder/__pycache__/_cuda_ctc_decoder.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7702b880abcbc52663c8dccf555e1b5911b3bd42
Binary files /dev/null and b/MLPY/Lib/site-packages/torchaudio/models/decoder/__pycache__/_cuda_ctc_decoder.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torchaudio/models/decoder/_ctc_decoder.py b/MLPY/Lib/site-packages/torchaudio/models/decoder/_ctc_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..c379781f64d54a3b826d953d6cfd4de22051695f
--- /dev/null
+++ b/MLPY/Lib/site-packages/torchaudio/models/decoder/_ctc_decoder.py
@@ -0,0 +1,568 @@
+from __future__ import annotations
+
+import itertools as it
+
+from abc import abstractmethod
+from collections import namedtuple
+from typing import Dict, List, NamedTuple, Optional, Tuple, Union
+
+import torch
+
+from flashlight.lib.text.decoder import (
+ CriterionType as _CriterionType,
+ LexiconDecoder as _LexiconDecoder,
+ LexiconDecoderOptions as _LexiconDecoderOptions,
+ LexiconFreeDecoder as _LexiconFreeDecoder,
+ LexiconFreeDecoderOptions as _LexiconFreeDecoderOptions,
+ LM as _LM,
+ LMState as _LMState,
+ SmearingMode as _SmearingMode,
+ Trie as _Trie,
+ ZeroLM as _ZeroLM,
+)
+from flashlight.lib.text.dictionary import (
+ create_word_dict as _create_word_dict,
+ Dictionary as _Dictionary,
+ load_words as _load_words,
+)
+from torchaudio.utils import download_asset
+
+try:
+ from flashlight.lib.text.decoder.kenlm import KenLM as _KenLM
+except Exception:
+ try:
+ from flashlight.lib.text.decoder import KenLM as _KenLM
+ except Exception:
+ _KenLM = None
+
+__all__ = [
+ "CTCHypothesis",
+ "CTCDecoder",
+ "CTCDecoderLM",
+ "CTCDecoderLMState",
+ "ctc_decoder",
+ "download_pretrained_files",
+]
+
+_PretrainedFiles = namedtuple("PretrainedFiles", ["lexicon", "tokens", "lm"])
+
+
+def _construct_trie(tokens_dict, word_dict, lexicon, lm, silence):
+ vocab_size = tokens_dict.index_size()
+ trie = _Trie(vocab_size, silence)
+ start_state = lm.start(False)
+
+ for word, spellings in lexicon.items():
+ word_idx = word_dict.get_index(word)
+ _, score = lm.score(start_state, word_idx)
+ for spelling in spellings:
+ spelling_idx = [tokens_dict.get_index(token) for token in spelling]
+ trie.insert(spelling_idx, word_idx, score)
+ trie.smear(_SmearingMode.MAX)
+ return trie
+
+
+def _get_word_dict(lexicon, lm, lm_dict, tokens_dict, unk_word):
+ word_dict = None
+ if lm_dict is not None:
+ word_dict = _Dictionary(lm_dict)
+
+ if lexicon and word_dict is None:
+ word_dict = _create_word_dict(lexicon)
+ elif not lexicon and word_dict is None and type(lm) == str:
+ d = {tokens_dict.get_entry(i): [[tokens_dict.get_entry(i)]] for i in range(tokens_dict.index_size())}
+ d[unk_word] = [[unk_word]]
+ word_dict = _create_word_dict(d)
+
+ return word_dict
+
+
+class CTCHypothesis(NamedTuple):
+ r"""Represents hypothesis generated by CTC beam search decoder :class:`CTCDecoder`."""
+ tokens: torch.LongTensor
+ """Predicted sequence of token IDs. Shape `(L, )`, where `L` is the length of the output sequence"""
+
+ words: List[str]
+ """List of predicted words.
+
+ Note:
+ This attribute is only applicable if a lexicon is provided to the decoder. If
+ decoding without a lexicon, it will be blank. Please refer to :attr:`tokens` and
+ :func:`~torchaudio.models.decoder.CTCDecoder.idxs_to_tokens` instead.
+ """
+
+ score: float
+ """Score corresponding to hypothesis"""
+
+ timesteps: torch.IntTensor
+ """Timesteps corresponding to the tokens. Shape `(L, )`, where `L` is the length of the output sequence"""
+
+
+class CTCDecoderLMState(_LMState):
+ """Language model state."""
+
+ @property
+ def children(self) -> Dict[int, CTCDecoderLMState]:
+ """Map of indices to LM states"""
+ return super().children
+
+ def child(self, usr_index: int) -> CTCDecoderLMState:
+ """Returns child corresponding to usr_index, or creates and returns a new state if input index
+ is not found.
+
+ Args:
+ usr_index (int): index corresponding to child state
+
+ Returns:
+ CTCDecoderLMState: child state corresponding to usr_index
+ """
+ return super().child(usr_index)
+
+ def compare(self, state: CTCDecoderLMState) -> CTCDecoderLMState:
+ """Compare two language model states.
+
+ Args:
+ state (CTCDecoderLMState): LM state to compare against
+
+ Returns:
+ int: 0 if the states are the same, -1 if self is less, +1 if self is greater.
+ """
+ pass
+
+
+class CTCDecoderLM(_LM):
+ """Language model base class for creating custom language models to use with the decoder."""
+
+ @abstractmethod
+ def start(self, start_with_nothing: bool) -> CTCDecoderLMState:
+ """Initialize or reset the language model.
+
+ Args:
+ start_with_nothing (bool): whether or not to start sentence with sil token.
+
+ Returns:
+ CTCDecoderLMState: starting state
+ """
+ raise NotImplementedError
+
+ @abstractmethod
+ def score(self, state: CTCDecoderLMState, usr_token_idx: int) -> Tuple[CTCDecoderLMState, float]:
+ """Evaluate the language model based on the current LM state and new word.
+
+ Args:
+ state (CTCDecoderLMState): current LM state
+ usr_token_idx (int): index of the word
+
+ Returns:
+ (CTCDecoderLMState, float)
+ CTCDecoderLMState:
+ new LM state
+ float:
+ score
+ """
+ raise NotImplementedError
+
+ @abstractmethod
+ def finish(self, state: CTCDecoderLMState) -> Tuple[CTCDecoderLMState, float]:
+ """Evaluate end for language model based on current LM state.
+
+ Args:
+ state (CTCDecoderLMState): current LM state
+
+ Returns:
+ (CTCDecoderLMState, float)
+ CTCDecoderLMState:
+ new LM state
+ float:
+ score
+ """
+ raise NotImplementedError
+
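+# Minimal sketch (illustrative, not shipped here) of a custom LM implementing this
+# interface; it assigns a score of 0.0 to every token, similar in spirit to ZeroLM:
+#
+#   class _ZeroScoreLM(CTCDecoderLM):
+#       def start(self, start_with_nothing: bool) -> CTCDecoderLMState:
+#           return CTCDecoderLMState()
+#
+#       def score(self, state: CTCDecoderLMState, usr_token_idx: int):
+#           return state.child(usr_token_idx), 0.0
+#
+#       def finish(self, state: CTCDecoderLMState):
+#           return state, 0.0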
+
+class CTCDecoder:
+ """CTC beam search decoder from *Flashlight* :cite:`kahn2022flashlight`.
+
+ .. devices:: CPU
+
+ Note:
+ To build the decoder, please use the factory function :func:`ctc_decoder`.
+ """
+
+ def __init__(
+ self,
+ nbest: int,
+ lexicon: Optional[Dict],
+ word_dict: _Dictionary,
+ tokens_dict: _Dictionary,
+ lm: CTCDecoderLM,
+ decoder_options: Union[_LexiconDecoderOptions, _LexiconFreeDecoderOptions],
+ blank_token: str,
+ sil_token: str,
+ unk_word: str,
+ ) -> None:
+ """
+ Args:
+ nbest (int): number of best decodings to return
+ lexicon (Dict or None): lexicon mapping of words to spellings, or None for lexicon-free decoder
+ word_dict (_Dictionary): dictionary of words
+ tokens_dict (_Dictionary): dictionary of tokens
+ lm (CTCDecoderLM): language model. If using a lexicon, only word level LMs are currently supported
+ decoder_options (_LexiconDecoderOptions or _LexiconFreeDecoderOptions):
+ parameters used for beam search decoding
+ blank_token (str): token corresponding to blank
+ sil_token (str): token corresponding to silence
+ unk_word (str): word corresponding to unknown
+ """
+
+ self.nbest = nbest
+ self.word_dict = word_dict
+ self.tokens_dict = tokens_dict
+ self.blank = self.tokens_dict.get_index(blank_token)
+ silence = self.tokens_dict.get_index(sil_token)
+ transitions = []
+
+ if lexicon:
+ trie = _construct_trie(tokens_dict, word_dict, lexicon, lm, silence)
+ unk_word = word_dict.get_index(unk_word)
+ token_lm = False # use word level LM
+
+ self.decoder = _LexiconDecoder(
+ decoder_options,
+ trie,
+ lm,
+ silence,
+ self.blank,
+ unk_word,
+ transitions,
+ token_lm,
+ )
+ else:
+ self.decoder = _LexiconFreeDecoder(decoder_options, lm, silence, self.blank, transitions)
+ # https://github.com/pytorch/audio/issues/3218
+ # If lm is passed like rvalue reference, the lm object gets garbage collected,
+ # and later call to the lm fails.
+ # This ensures that lm object is not deleted as long as the decoder is alive.
+ # https://github.com/pybind/pybind11/discussions/4013
+ self.lm = lm
+
+ def _get_tokens(self, idxs: torch.IntTensor) -> torch.LongTensor:
+ idxs = (g[0] for g in it.groupby(idxs))
+ idxs = filter(lambda x: x != self.blank, idxs)
+ return torch.LongTensor(list(idxs))
+
+ def _get_timesteps(self, idxs: torch.IntTensor) -> torch.IntTensor:
+ """Returns frame numbers corresponding to non-blank tokens."""
+
+ timesteps = []
+ for i, idx in enumerate(idxs):
+ if idx == self.blank:
+ continue
+ if i == 0 or idx != idxs[i - 1]:
+ timesteps.append(i)
+ return torch.IntTensor(timesteps)
+
+ def decode_begin(self):
+ """Initialize the internal state of the decoder.
+
+ See :py:meth:`decode_step` for the usage.
+
+ .. note::
+
+ This method is required only when performing online decoding.
+ It is not necessary when performing batch decoding with :py:meth:`__call__`.
+ """
+ self.decoder.decode_begin()
+
+ def decode_end(self):
+ """Finalize the internal state of the decoder.
+
+ See :py:meth:`decode_step` for the usage.
+
+ .. note::
+
+ This method is required only when performing online decoding.
+ It is not necessary when performing batch decoding with :py:meth:`__call__`.
+ """
+ self.decoder.decode_end()
+
+ def decode_step(self, emissions: torch.FloatTensor):
+ """Perform incremental decoding on top of the curent internal state.
+
+ .. note::
+
+ This method is required only when performing online decoding.
+ It is not necessary when performing batch decoding with :py:meth:`__call__`.
+
+ Args:
+ emissions (torch.FloatTensor): CPU tensor of shape `(frame, num_tokens)` storing sequences of
+ probability distribution over labels; output of acoustic model.
+
+ Example:
+ >>> decoder = torchaudio.models.decoder.ctc_decoder(...)
+ >>> decoder.decode_begin()
+ >>> decoder.decode_step(emission1)
+ >>> decoder.decode_step(emission2)
+ >>> decoder.decode_end()
+ >>> result = decoder.get_final_hypothesis()
+ """
+ if emissions.dtype != torch.float32:
+ raise ValueError("emissions must be float32.")
+
+ if not emissions.is_cpu:
+ raise RuntimeError("emissions must be a CPU tensor.")
+
+ if not emissions.is_contiguous():
+ raise RuntimeError("emissions must be contiguous.")
+
+ if emissions.ndim != 2:
+ raise RuntimeError(f"emissions must be 2D. Found {emissions.shape}")
+
+ T, N = emissions.size()
+ self.decoder.decode_step(emissions.data_ptr(), T, N)
+
+ def _to_hypo(self, results) -> List[CTCHypothesis]:
+ return [
+ CTCHypothesis(
+ tokens=self._get_tokens(result.tokens),
+ words=[self.word_dict.get_entry(x) for x in result.words if x >= 0],
+ score=result.score,
+ timesteps=self._get_timesteps(result.tokens),
+ )
+ for result in results
+ ]
+
+ def get_final_hypothesis(self) -> List[CTCHypothesis]:
+ """Get the final hypothesis
+
+ Returns:
+ List[CTCHypothesis]:
+ List of sorted best hypotheses.
+
+ .. note::
+
+ This method is required only when performing online decoding.
+ It is not necessary when performing batch decoding with :py:meth:`__call__`.
+ """
+ results = self.decoder.get_all_final_hypothesis()
+ return self._to_hypo(results[: self.nbest])
+
+ def __call__(
+ self, emissions: torch.FloatTensor, lengths: Optional[torch.Tensor] = None
+ ) -> List[List[CTCHypothesis]]:
+ """
+ Performs batched offline decoding.
+
+ .. note::
+
+ This method performs offline decoding in one go. To perform incremental decoding,
+ please refer to :py:meth:`decode_step`.
+
+ Args:
+ emissions (torch.FloatTensor): CPU tensor of shape `(batch, frame, num_tokens)` storing sequences of
+ probability distribution over labels; output of acoustic model.
+ lengths (Tensor or None, optional): CPU tensor of shape `(batch, )` storing the valid length
+ in the time axis of the output Tensor in each batch.
+
+ Returns:
+ List[List[CTCHypothesis]]:
+ List of sorted best hypotheses for each audio sequence in the batch.
+ """
+
+ if emissions.dtype != torch.float32:
+ raise ValueError("emissions must be float32.")
+
+ if not emissions.is_cpu:
+ raise RuntimeError("emissions must be a CPU tensor.")
+
+ if not emissions.is_contiguous():
+ raise RuntimeError("emissions must be contiguous.")
+
+ if emissions.ndim != 3:
+ raise RuntimeError(f"emissions must be 3D. Found {emissions.shape}")
+
+ if lengths is not None and not lengths.is_cpu:
+ raise RuntimeError("lengths must be a CPU tensor.")
+
+ B, T, N = emissions.size()
+ if lengths is None:
+ lengths = torch.full((B,), T)
+
+ float_bytes = 4
+ hypos = []
+
+ for b in range(B):
+ emissions_ptr = emissions.data_ptr() + float_bytes * b * emissions.stride(0)
+ results = self.decoder.decode(emissions_ptr, lengths[b], N)
+ hypos.append(self._to_hypo(results[: self.nbest]))
+ return hypos
+
+ def idxs_to_tokens(self, idxs: torch.LongTensor) -> List:
+ """
+ Map raw token IDs into corresponding tokens
+
+ Args:
+ idxs (LongTensor): raw token IDs generated from decoder
+
+ Returns:
+ List: tokens corresponding to the input IDs
+ """
+ return [self.tokens_dict.get_entry(idx.item()) for idx in idxs]
+
+
+def ctc_decoder(
+ lexicon: Optional[str],
+ tokens: Union[str, List[str]],
+ lm: Union[str, CTCDecoderLM] = None,
+ lm_dict: Optional[str] = None,
+ nbest: int = 1,
+ beam_size: int = 50,
+ beam_size_token: Optional[int] = None,
+ beam_threshold: float = 50,
+ lm_weight: float = 2,
+ word_score: float = 0,
+ unk_score: float = float("-inf"),
+ sil_score: float = 0,
+ log_add: bool = False,
+ blank_token: str = "-",
+ sil_token: str = "|",
+ unk_word: str = "",
+) -> CTCDecoder:
+ """Builds an instance of :class:`CTCDecoder`.
+
+ Args:
+ lexicon (str or None): lexicon file containing the possible words and corresponding spellings.
+ Each line consists of a word and its space separated spelling. If `None`, uses lexicon-free
+ decoding.
+ tokens (str or List[str]): file or list containing valid tokens. If using a file, the expected
+ format is for tokens mapping to the same index to be on the same line
+ lm (str, CTCDecoderLM, or None, optional): either a path containing KenLM language model,
+ custom language model of type `CTCDecoderLM`, or `None` if not using a language model
+ lm_dict (str or None, optional): file consisting of the dictionary used for the LM, with a word
+ per line sorted by LM index. If decoding with a lexicon, entries in lm_dict must also occur
+ in the lexicon file. If `None`, dictionary for LM is constructed using the lexicon file.
+ (Default: None)
+ nbest (int, optional): number of best decodings to return (Default: 1)
+ beam_size (int, optional): max number of hypos to hold after each decode step (Default: 50)
+ beam_size_token (int, optional): max number of tokens to consider at each decode step.
+ If `None`, it is set to the total number of tokens (Default: None)
+ beam_threshold (float, optional): threshold for pruning hypotheses (Default: 50)
+ lm_weight (float, optional): weight of language model (Default: 2)
+ word_score (float, optional): word insertion score (Default: 0)
+ unk_score (float, optional): unknown word insertion score (Default: -inf)
+ sil_score (float, optional): silence insertion score (Default: 0)
+ log_add (bool, optional): whether or not to use logadd when merging hypotheses (Default: False)
+ blank_token (str, optional): token corresponding to blank (Default: "-")
+ sil_token (str, optional): token corresponding to silence (Default: "|")
+ unk_word (str, optional): word corresponding to unknown (Default: "")
+
+ Returns:
+ CTCDecoder: decoder
+
+ Example
+ >>> decoder = ctc_decoder(
+ >>> lexicon="lexicon.txt",
+ >>> tokens="tokens.txt",
+ >>> lm="kenlm.bin",
+ >>> )
+ >>> results = decoder(emissions) # List of shape (B, nbest) of Hypotheses
+ """
+ if lm_dict is not None and type(lm_dict) is not str:
+ raise ValueError("lm_dict must be None or str type.")
+
+ tokens_dict = _Dictionary(tokens)
+
+ # decoder options
+ if lexicon:
+ lexicon = _load_words(lexicon)
+ decoder_options = _LexiconDecoderOptions(
+ beam_size=beam_size,
+ beam_size_token=beam_size_token or tokens_dict.index_size(),
+ beam_threshold=beam_threshold,
+ lm_weight=lm_weight,
+ word_score=word_score,
+ unk_score=unk_score,
+ sil_score=sil_score,
+ log_add=log_add,
+ criterion_type=_CriterionType.CTC,
+ )
+ else:
+ decoder_options = _LexiconFreeDecoderOptions(
+ beam_size=beam_size,
+ beam_size_token=beam_size_token or tokens_dict.index_size(),
+ beam_threshold=beam_threshold,
+ lm_weight=lm_weight,
+ sil_score=sil_score,
+ log_add=log_add,
+ criterion_type=_CriterionType.CTC,
+ )
+
+ # construct word dict and language model
+ word_dict = _get_word_dict(lexicon, lm, lm_dict, tokens_dict, unk_word)
+
+ if type(lm) == str:
+ if _KenLM is None:
+ raise RuntimeError(
+ "flashlight-text is installed, but KenLM is not installed. "
+ "Please refer to https://github.com/kpu/kenlm#python-module for how to install it."
+ )
+ lm = _KenLM(lm, word_dict)
+ elif lm is None:
+ lm = _ZeroLM()
+
+ return CTCDecoder(
+ nbest=nbest,
+ lexicon=lexicon,
+ word_dict=word_dict,
+ tokens_dict=tokens_dict,
+ lm=lm,
+ decoder_options=decoder_options,
+ blank_token=blank_token,
+ sil_token=sil_token,
+ unk_word=unk_word,
+ )
+
+
+def _get_filenames(model: str) -> _PretrainedFiles:
+ if model not in ["librispeech", "librispeech-3-gram", "librispeech-4-gram"]:
+ raise ValueError(
+ f"{model} not supported. Must be one of ['librispeech-3-gram', 'librispeech-4-gram', 'librispeech']"
+ )
+
+ prefix = f"decoder-assets/{model}"
+ return _PretrainedFiles(
+ lexicon=f"{prefix}/lexicon.txt",
+ tokens=f"{prefix}/tokens.txt",
+ lm=f"{prefix}/lm.bin" if model != "librispeech" else None,
+ )
+
+
+def download_pretrained_files(model: str) -> _PretrainedFiles:
+ """
+ Retrieves pretrained data files used for :func:`ctc_decoder`.
+
+ Args:
+ model (str): pretrained language model to download.
+ Valid values are: ``"librispeech-3-gram"``, ``"librispeech-4-gram"`` and ``"librispeech"``.
+
+ Returns:
+ Object with the following attributes
+
+ * ``lm``: path corresponding to downloaded language model,
+ or ``None`` if the model is not associated with an lm
+ * ``lexicon``: path corresponding to downloaded lexicon file
+ * ``tokens``: path corresponding to downloaded tokens file
+ """
+
+ files = _get_filenames(model)
+ lexicon_file = download_asset(files.lexicon)
+ tokens_file = download_asset(files.tokens)
+ if files.lm is not None:
+ lm_file = download_asset(files.lm)
+ else:
+ lm_file = None
+
+ return _PretrainedFiles(
+ lexicon=lexicon_file,
+ tokens=tokens_file,
+ lm=lm_file,
+ )
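+
+
+# End-to-end sketch (illustrative): fetch the pretrained LibriSpeech files and build a
+# lexicon-based decoder from them, then decode acoustic-model emissions.
+#
+#   files = download_pretrained_files("librispeech-4-gram")
+#   decoder = ctc_decoder(lexicon=files.lexicon, tokens=files.tokens, lm=files.lm)
+#   hypotheses = decoder(emissions)   # emissions: float32 CPU tensor of shape (batch, frame, num_tokens)
+#   transcript = " ".join(hypotheses[0][0].words)   # best hypothesis for the first utterance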
diff --git a/MLPY/Lib/site-packages/torchaudio/models/decoder/_cuda_ctc_decoder.py b/MLPY/Lib/site-packages/torchaudio/models/decoder/_cuda_ctc_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..aebba02661cb25c5c14b42680fb77c7b5964e6e3
--- /dev/null
+++ b/MLPY/Lib/site-packages/torchaudio/models/decoder/_cuda_ctc_decoder.py
@@ -0,0 +1,187 @@
+from __future__ import annotations
+
+import math
+
+from typing import List, NamedTuple, Union
+
+import torch
+import torchaudio
+
+torchaudio._extension._load_lib("libctc_prefix_decoder")
+import torchaudio.lib.pybind11_prefixctc as cuctc
+
+
+__all__ = ["CUCTCHypothesis", "CUCTCDecoder", "cuda_ctc_decoder"]
+
+
+def _get_vocab_list(vocab_file):
+ vocab = []
+ with open(vocab_file, "r", encoding="utf-8") as f:
+ for line in f:
+ line = line.strip().split()
+ vocab.append(line[0])
+ return vocab
+
+
+class CUCTCHypothesis(NamedTuple):
+ r"""Represents hypothesis generated by CUCTC beam search decoder :class:`CUCTCDecoder`."""
+ tokens: List[int]
+ """Predicted sequence of token IDs. Shape `(L, )`, where `L` is the length of the output sequence"""
+
+ words: List[str]
+ """List of predicted tokens. Algin with modeling unit.
+ """
+
+ score: float
+ """Score corresponding to hypothesis"""
+
+
+_DEFAULT_BLANK_SKIP_THREASHOLD = 0.95
+
+
+class CUCTCDecoder:
+ """CUDA CTC beam search decoder.
+
+ .. devices:: CUDA
+
+ Note:
+ To build the decoder, please use the factory function :func:`cuda_ctc_decoder`.
+ """
+
+ def __init__(
+ self,
+ vocab_list: List[str],
+ blank_id: int = 0,
+ beam_size: int = 10,
+ nbest: int = 1,
+ blank_skip_threshold: float = _DEFAULT_BLANK_SKIP_THREASHOLD,
+ cuda_stream: torch.cuda.streams.Stream = None,
+ ):
+ """
+ Args:
+ blank_id (int): token ID corresponding to blank; only 0 is supported for now. (Default: 0)
+ vocab_list (List[str]): list of vocabulary tokens
+ beam_size (int, optional): max number of hypos to hold after each decode step (Default: 10)
+ nbest (int): number of best decodings to return
+ blank_skip_threshold (float):
+ skip frames if log_prob(blank) > log(blank_skip_threshold), to speed up decoding.
+ (Default: 0.95).
+ cuda_stream (torch.cuda.streams.Stream): using assigned cuda stream (Default: using default stream)
+
+ """
+ if cuda_stream:
+ if not isinstance(cuda_stream, torch.cuda.streams.Stream):
+ raise AssertionError("cuda_stream must be torch.cuda.streams.Stream")
+ cuda_stream_ = cuda_stream.cuda_stream if cuda_stream else torch.cuda.current_stream().cuda_stream
+ self.internal_data = cuctc.prefixCTC_alloc(cuda_stream_)
+ self.memory = torch.empty(0, dtype=torch.int8, device=torch.device("cuda"))
+ if blank_id != 0:
+ raise AssertionError("blank_id must be 0")
+ self.blank_id = blank_id
+ self.vocab_list = vocab_list
+ self.space_id = 0
+ self.nbest = nbest
+ if not (blank_skip_threshold >= 0 and blank_skip_threshold <= 1):
+ raise AssertionError("blank_skip_threshold must be between 0 and 1")
+ self.blank_skip_threshold = math.log(blank_skip_threshold)
+ self.beam_size = min(beam_size, len(vocab_list)) # beam size must be smaller than vocab size
+
+ def __del__(self):
+ if cuctc is not None:
+ cuctc.prefixCTC_free(self.internal_data)
+
+ def __call__(self, log_prob: torch.Tensor, encoder_out_lens: torch.Tensor):
+ """
+ Args:
+ log_prob (torch.FloatTensor): GPU tensor of shape `(batch, frame, num_tokens)` storing sequences of
+ probability distribution over labels; log_softmax(output of acoustic model).
+ encoder_out_lens (torch.Tensor of dtype torch.int32): GPU tensor of shape `(batch, )` storing the valid
+ length in the time axis of the output Tensor in each batch.
+
+ Returns:
+ List[List[CUCTCHypothesis]]:
+ List of sorted best hypotheses for each audio sequence in the batch.
+ """
+ if not encoder_out_lens.dtype == torch.int32:
+ raise AssertionError("encoder_out_lens must be torch.int32")
+ if not log_prob.dtype == torch.float32:
+ raise AssertionError("log_prob must be torch.float32")
+ if not (log_prob.is_cuda and encoder_out_lens.is_cuda):
+ raise AssertionError("inputs must be cuda tensors")
+ if not (log_prob.is_contiguous() and encoder_out_lens.is_contiguous()):
+ raise AssertionError("input tensors must be contiguous")
+ required_size, score_hyps = cuctc.ctc_beam_search_decoder_batch_gpu_v2(
+ self.internal_data,
+ self.memory.data_ptr(),
+ self.memory.size(0),
+ log_prob.data_ptr(),
+ encoder_out_lens.data_ptr(),
+ log_prob.size(),
+ log_prob.stride(),
+ self.beam_size,
+ self.blank_id,
+ self.space_id,
+ self.blank_skip_threshold,
+ )
+ if required_size > 0:
+ self.memory = torch.empty(required_size, dtype=torch.int8, device=log_prob.device).contiguous()
+ _, score_hyps = cuctc.ctc_beam_search_decoder_batch_gpu_v2(
+ self.internal_data,
+ self.memory.data_ptr(),
+ self.memory.size(0),
+ log_prob.data_ptr(),
+ encoder_out_lens.data_ptr(),
+ log_prob.size(),
+ log_prob.stride(),
+ self.beam_size,
+ self.blank_id,
+ self.space_id,
+ self.blank_skip_threshold,
+ )
+ batch_size = len(score_hyps)
+ hypos = []
+ for i in range(batch_size):
+ hypos.append(
+ [
+ CUCTCHypothesis(
+ tokens=score_hyps[i][j][1],
+ words=[self.vocab_list[word_id] for word_id in score_hyps[i][j][1]],
+ score=score_hyps[i][j][0],
+ )
+ for j in range(self.nbest)
+ ]
+ )
+ return hypos
+
+
+def cuda_ctc_decoder(
+ tokens: Union[str, List[str]],
+ nbest: int = 1,
+ beam_size: int = 10,
+ blank_skip_threshold: float = _DEFAULT_BLANK_SKIP_THREASHOLD,
+) -> CUCTCDecoder:
+ """Builds an instance of :class:`CUCTCDecoder`.
+
+ Args:
+ tokens (str or List[str]): File or list containing valid tokens.
+ If using a file, the expected format is for tokens mapping to the same index to be on the same line
+ beam_size (int, optional): The maximum number of hypos to hold after each decode step (Default: 10)
+ nbest (int): The number of best decodings to return
+ blank_skip_threshold (float): skip frames if log_prob(blank) > log(blank_skip_threshold), to speed up decoding
+ (Default: 0.95).
+
+ Returns:
+ CUCTCDecoder: decoder
+
+ Example
+ >>> decoder = cuda_ctc_decoder(
+ >>> tokens="tokens.txt",
+ >>> blank_skip_threshold=0.95,
+ >>> )
+ >>> results = decoder(log_probs, encoder_out_lens) # List of shape (B, nbest) of Hypotheses
+ """
+ if type(tokens) == str:
+ tokens = _get_vocab_list(tokens)
+
+ return CUCTCDecoder(vocab_list=tokens, beam_size=beam_size, nbest=nbest, blank_skip_threshold=blank_skip_threshold)
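+
+
+# Usage sketch (illustrative): the decoder consumes CUDA log-probabilities and int32 lengths.
+#
+#   decoder = cuda_ctc_decoder("tokens.txt", nbest=1, beam_size=10)
+#   log_probs = torch.nn.functional.log_softmax(acoustic_out, dim=-1)   # CUDA float32, (B, T, num_tokens)
+#   lengths = torch.full((log_probs.size(0),), log_probs.size(1), dtype=torch.int32, device=log_probs.device)
+#   hypos = decoder(log_probs, lengths)
+#   tokens = hypos[0][0].words   # predicted token strings for the first utterance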
diff --git a/MLPY/Lib/site-packages/torchaudio/models/deepspeech.py b/MLPY/Lib/site-packages/torchaudio/models/deepspeech.py
new file mode 100644
index 0000000000000000000000000000000000000000..0a6a0faa006a3fa6868ccb7e39e68118d8dbe277
--- /dev/null
+++ b/MLPY/Lib/site-packages/torchaudio/models/deepspeech.py
@@ -0,0 +1,84 @@
+import torch
+
+__all__ = ["DeepSpeech"]
+
+
+class FullyConnected(torch.nn.Module):
+ """
+ Args:
+ n_feature: Number of input features
+ n_hidden: Internal hidden unit size.
+ """
+
+ def __init__(self, n_feature: int, n_hidden: int, dropout: float, relu_max_clip: int = 20) -> None:
+ super(FullyConnected, self).__init__()
+ self.fc = torch.nn.Linear(n_feature, n_hidden, bias=True)
+ self.relu_max_clip = relu_max_clip
+ self.dropout = dropout
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.fc(x)
+ x = torch.nn.functional.relu(x)
+ x = torch.nn.functional.hardtanh(x, 0, self.relu_max_clip)
+ if self.dropout:
+ x = torch.nn.functional.dropout(x, self.dropout, self.training)
+ return x
+
+
+class DeepSpeech(torch.nn.Module):
+ """DeepSpeech architecture introduced in
+ *Deep Speech: Scaling up end-to-end speech recognition* :cite:`hannun2014deep`.
+
+ Args:
+ n_feature: Number of input features
+ n_hidden: Internal hidden unit size.
+ n_class: Number of output classes
+ """
+
+ def __init__(
+ self,
+ n_feature: int,
+ n_hidden: int = 2048,
+ n_class: int = 40,
+ dropout: float = 0.0,
+ ) -> None:
+ super(DeepSpeech, self).__init__()
+ self.n_hidden = n_hidden
+ self.fc1 = FullyConnected(n_feature, n_hidden, dropout)
+ self.fc2 = FullyConnected(n_hidden, n_hidden, dropout)
+ self.fc3 = FullyConnected(n_hidden, n_hidden, dropout)
+ self.bi_rnn = torch.nn.RNN(n_hidden, n_hidden, num_layers=1, nonlinearity="relu", bidirectional=True)
+ self.fc4 = FullyConnected(n_hidden, n_hidden, dropout)
+ self.out = torch.nn.Linear(n_hidden, n_class)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+ x (torch.Tensor): Tensor of dimension (batch, channel, time, feature).
+ Returns:
+ Tensor: Predictor tensor of dimension (batch, time, class).
+ """
+ # N x C x T x F
+ x = self.fc1(x)
+ # N x C x T x H
+ x = self.fc2(x)
+ # N x C x T x H
+ x = self.fc3(x)
+ # N x C x T x H
+ x = x.squeeze(1)
+ # N x T x H
+ x = x.transpose(0, 1)
+ # T x N x H
+ x, _ = self.bi_rnn(x)
+ # The fifth (non-recurrent) layer takes both the forward and backward units as inputs
+ x = x[:, :, : self.n_hidden] + x[:, :, self.n_hidden :]
+ # T x N x H
+ x = self.fc4(x)
+ # T x N x H
+ x = self.out(x)
+ # T x N x n_class
+ x = x.permute(1, 0, 2)
+ # N x T x n_class
+ x = torch.nn.functional.log_softmax(x, dim=2)
+ # N x T x n_class
+ return x
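+
+
+# Shape sketch (illustrative): the model expects feature frames with an explicit
+# channel dimension of 1.
+#
+#   model = DeepSpeech(n_feature=80, n_class=29)
+#   feats = torch.rand(8, 1, 100, 80)   # (batch, channel, time, feature)
+#   log_probs = model(feats)            # -> (8, 100, 29) log-probabilities over classes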
diff --git a/MLPY/Lib/site-packages/torchaudio/models/emformer.py b/MLPY/Lib/site-packages/torchaudio/models/emformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..fa678869c07126a9a5556d35f40ca3324b3fe6b4
--- /dev/null
+++ b/MLPY/Lib/site-packages/torchaudio/models/emformer.py
@@ -0,0 +1,884 @@
+import math
+from typing import List, Optional, Tuple
+
+import torch
+
+
+__all__ = ["Emformer"]
+
+
+def _lengths_to_padding_mask(lengths: torch.Tensor) -> torch.Tensor:
+ batch_size = lengths.shape[0]
+ max_length = int(torch.max(lengths).item())
+ padding_mask = torch.arange(max_length, device=lengths.device, dtype=lengths.dtype).expand(
+ batch_size, max_length
+ ) >= lengths.unsqueeze(1)
+ return padding_mask
+
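+# Worked example (illustrative): lengths = [2, 4, 3] gives max_length 4 and
+#   padding_mask = [[False, False, True,  True ],
+#                   [False, False, False, False],
+#                   [False, False, False, True ]]
+# where True marks padded (invalid) positions.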
+
+def _gen_padding_mask(
+ utterance: torch.Tensor,
+ right_context: torch.Tensor,
+ summary: torch.Tensor,
+ lengths: torch.Tensor,
+ mems: torch.Tensor,
+ left_context_key: Optional[torch.Tensor] = None,
+) -> Optional[torch.Tensor]:
+ T = right_context.size(0) + utterance.size(0) + summary.size(0)
+ B = right_context.size(1)
+ if B == 1:
+ padding_mask = None
+ else:
+ right_context_blocks_length = T - torch.max(lengths).int() - summary.size(0)
+ left_context_blocks_length = left_context_key.size(0) if left_context_key is not None else 0
+ klengths = lengths + mems.size(0) + right_context_blocks_length + left_context_blocks_length
+ padding_mask = _lengths_to_padding_mask(lengths=klengths)
+ return padding_mask
+
+
+def _get_activation_module(activation: str) -> torch.nn.Module:
+ if activation == "relu":
+ return torch.nn.ReLU()
+ elif activation == "gelu":
+ return torch.nn.GELU()
+ elif activation == "silu":
+ return torch.nn.SiLU()
+ else:
+ raise ValueError(f"Unsupported activation {activation}")
+
+
+def _get_weight_init_gains(weight_init_scale_strategy: Optional[str], num_layers: int) -> List[Optional[float]]:
+ if weight_init_scale_strategy is None:
+ return [None for _ in range(num_layers)]
+ elif weight_init_scale_strategy == "depthwise":
+ return [1.0 / math.sqrt(layer_idx + 1) for layer_idx in range(num_layers)]
+ elif weight_init_scale_strategy == "constant":
+ return [1.0 / math.sqrt(2) for layer_idx in range(num_layers)]
+ else:
+ raise ValueError(f"Unsupported weight_init_scale_strategy value {weight_init_scale_strategy}")
+
+
+def _gen_attention_mask_block(
+ col_widths: List[int], col_mask: List[bool], num_rows: int, device: torch.device
+) -> torch.Tensor:
+ if len(col_widths) != len(col_mask):
+ raise ValueError("Length of col_widths must match that of col_mask")
+
+ mask_block = [
+ torch.ones(num_rows, col_width, device=device)
+ if is_ones_col
+ else torch.zeros(num_rows, col_width, device=device)
+ for col_width, is_ones_col in zip(col_widths, col_mask)
+ ]
+ return torch.cat(mask_block, dim=1)
+
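+# Worked example (illustrative): col_widths=[2, 3, 1], col_mask=[True, False, True],
+# num_rows=2 produces the (2, 6) block
+#   [[1., 1., 0., 0., 0., 1.],
+#    [1., 1., 0., 0., 0., 1.]]
+# which serves as one building block of the Emformer attention mask.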
+
+class _EmformerAttention(torch.nn.Module):
+ r"""Emformer layer attention module.
+
+ Args:
+ input_dim (int): input dimension.
+ num_heads (int): number of attention heads in each Emformer layer.
+ dropout (float, optional): dropout probability. (Default: 0.0)
+ weight_init_gain (float or None, optional): scale factor to apply when initializing
+ attention module parameters. (Default: ``None``)
+ tanh_on_mem (bool, optional): if ``True``, applies tanh to memory elements. (Default: ``False``)
+ negative_inf (float, optional): value to use for negative infinity in attention weights. (Default: -1e8)
+ """
+
+ def __init__(
+ self,
+ input_dim: int,
+ num_heads: int,
+ dropout: float = 0.0,
+ weight_init_gain: Optional[float] = None,
+ tanh_on_mem: bool = False,
+ negative_inf: float = -1e8,
+ ):
+ super().__init__()
+
+ if input_dim % num_heads != 0:
+ raise ValueError(f"input_dim ({input_dim}) is not a multiple of num_heads ({num_heads}).")
+
+ self.input_dim = input_dim
+ self.num_heads = num_heads
+ self.dropout = dropout
+ self.tanh_on_mem = tanh_on_mem
+ self.negative_inf = negative_inf
+
+ self.scaling = (self.input_dim // self.num_heads) ** -0.5
+
+ self.emb_to_key_value = torch.nn.Linear(input_dim, 2 * input_dim, bias=True)
+ self.emb_to_query = torch.nn.Linear(input_dim, input_dim, bias=True)
+ self.out_proj = torch.nn.Linear(input_dim, input_dim, bias=True)
+
+ if weight_init_gain:
+ torch.nn.init.xavier_uniform_(self.emb_to_key_value.weight, gain=weight_init_gain)
+ torch.nn.init.xavier_uniform_(self.emb_to_query.weight, gain=weight_init_gain)
+
+ def _gen_key_value(self, input: torch.Tensor, mems: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ T, _, _ = input.shape
+ summary_length = mems.size(0) + 1
+ right_ctx_utterance_block = input[: T - summary_length]
+ mems_right_ctx_utterance_block = torch.cat([mems, right_ctx_utterance_block])
+ key, value = self.emb_to_key_value(mems_right_ctx_utterance_block).chunk(chunks=2, dim=2)
+ return key, value
+
+ def _gen_attention_probs(
+ self,
+ attention_weights: torch.Tensor,
+ attention_mask: torch.Tensor,
+ padding_mask: Optional[torch.Tensor],
+ ) -> torch.Tensor:
+ attention_weights_float = attention_weights.float()
+ attention_weights_float = attention_weights_float.masked_fill(attention_mask.unsqueeze(0), self.negative_inf)
+ T = attention_weights.size(1)
+ B = attention_weights.size(0) // self.num_heads
+ if padding_mask is not None:
+ attention_weights_float = attention_weights_float.view(B, self.num_heads, T, -1)
+ attention_weights_float = attention_weights_float.masked_fill(
+ padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), self.negative_inf
+ )
+ attention_weights_float = attention_weights_float.view(B * self.num_heads, T, -1)
+ attention_probs = torch.nn.functional.softmax(attention_weights_float, dim=-1).type_as(attention_weights)
+ return torch.nn.functional.dropout(attention_probs, p=float(self.dropout), training=self.training)
+
+ def _forward_impl(
+ self,
+ utterance: torch.Tensor,
+ lengths: torch.Tensor,
+ right_context: torch.Tensor,
+ summary: torch.Tensor,
+ mems: torch.Tensor,
+ attention_mask: torch.Tensor,
+ left_context_key: Optional[torch.Tensor] = None,
+ left_context_val: Optional[torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ B = utterance.size(1)
+ T = right_context.size(0) + utterance.size(0) + summary.size(0)
+
+ # Compute query with [right context, utterance, summary].
+ query = self.emb_to_query(torch.cat([right_context, utterance, summary]))
+
+ # Compute key and value with [mems, right context, utterance].
+ key, value = self.emb_to_key_value(torch.cat([mems, right_context, utterance])).chunk(chunks=2, dim=2)
+
+ if left_context_key is not None and left_context_val is not None:
+ right_context_blocks_length = T - torch.max(lengths).int() - summary.size(0)
+ key = torch.cat(
+ [
+ key[: mems.size(0) + right_context_blocks_length],
+ left_context_key,
+ key[mems.size(0) + right_context_blocks_length :],
+ ],
+ )
+ value = torch.cat(
+ [
+ value[: mems.size(0) + right_context_blocks_length],
+ left_context_val,
+ value[mems.size(0) + right_context_blocks_length :],
+ ],
+ )
+
+ # Compute attention weights from query, key, and value.
+ reshaped_query, reshaped_key, reshaped_value = [
+ tensor.contiguous().view(-1, B * self.num_heads, self.input_dim // self.num_heads).transpose(0, 1)
+ for tensor in [query, key, value]
+ ]
+ attention_weights = torch.bmm(reshaped_query * self.scaling, reshaped_key.transpose(1, 2))
+
+ # Compute padding mask.
+ padding_mask = _gen_padding_mask(utterance, right_context, summary, lengths, mems, left_context_key)
+
+ # Compute attention probabilities.
+ attention_probs = self._gen_attention_probs(attention_weights, attention_mask, padding_mask)
+
+ # Compute attention.
+ attention = torch.bmm(attention_probs, reshaped_value)
+ if attention.shape != (
+ B * self.num_heads,
+ T,
+ self.input_dim // self.num_heads,
+ ):
+ raise AssertionError("Computed attention has incorrect dimensions")
+ attention = attention.transpose(0, 1).contiguous().view(T, B, self.input_dim)
+
+ # Apply output projection.
+ output_right_context_mems = self.out_proj(attention)
+
+ summary_length = summary.size(0)
+ output_right_context = output_right_context_mems[: T - summary_length]
+ output_mems = output_right_context_mems[T - summary_length :]
+ if self.tanh_on_mem:
+ output_mems = torch.tanh(output_mems)
+ else:
+ output_mems = torch.clamp(output_mems, min=-10, max=10)
+
+ return output_right_context, output_mems, key, value
+
+ def forward(
+ self,
+ utterance: torch.Tensor,
+ lengths: torch.Tensor,
+ right_context: torch.Tensor,
+ summary: torch.Tensor,
+ mems: torch.Tensor,
+ attention_mask: torch.Tensor,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ r"""Forward pass for training.
+
+ B: batch size;
+ D: feature dimension of each frame;
+ T: number of utterance frames;
+ R: number of right context frames;
+ S: number of summary elements;
+ M: number of memory elements.
+
+ Args:
+ utterance (torch.Tensor): utterance frames, with shape `(T, B, D)`.
+ lengths (torch.Tensor): with shape `(B,)` and i-th element representing
+ number of valid frames for i-th batch element in ``utterance``.
+ right_context (torch.Tensor): right context frames, with shape `(R, B, D)`.
+ summary (torch.Tensor): summary elements, with shape `(S, B, D)`.
+ mems (torch.Tensor): memory elements, with shape `(M, B, D)`.
+ attention_mask (torch.Tensor): attention mask for underlying attention module.
+
+ Returns:
+ (Tensor, Tensor):
+ Tensor
+ output frames corresponding to utterance and right_context, with shape `(T + R, B, D)`.
+ Tensor
+ updated memory elements, with shape `(M, B, D)`.
+ """
+ output, output_mems, _, _ = self._forward_impl(utterance, lengths, right_context, summary, mems, attention_mask)
+ return output, output_mems[:-1]
+
+ @torch.jit.export
+ def infer(
+ self,
+ utterance: torch.Tensor,
+ lengths: torch.Tensor,
+ right_context: torch.Tensor,
+ summary: torch.Tensor,
+ mems: torch.Tensor,
+ left_context_key: torch.Tensor,
+ left_context_val: torch.Tensor,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ r"""Forward pass for inference.
+
+ B: batch size;
+ D: feature dimension of each frame;
+ T: number of utterance frames;
+ R: number of right context frames;
+ S: number of summary elements;
+ M: number of memory elements.
+
+ Args:
+ utterance (torch.Tensor): utterance frames, with shape `(T, B, D)`.
+ lengths (torch.Tensor): with shape `(B,)` and i-th element representing
+ number of valid frames for i-th batch element in ``utterance``.
+ right_context (torch.Tensor): right context frames, with shape `(R, B, D)`.
+ summary (torch.Tensor): summary elements, with shape `(S, B, D)`.
+ mems (torch.Tensor): memory elements, with shape `(M, B, D)`.
+ left_context_key (torch.Tensor): left context attention key computed from preceding invocation.
+ left_context_val (torch.Tensor): left context attention value computed from preceding invocation.
+
+ Returns:
+ (Tensor, Tensor, Tensor, and Tensor):
+ Tensor
+ output frames corresponding to utterance and right_context, with shape `(T + R, B, D)`.
+ Tensor
+ updated memory elements, with shape `(M, B, D)`.
+ Tensor
+ attention key computed for left context and utterance.
+ Tensor
+ attention value computed for left context and utterance.
+ """
+ query_dim = right_context.size(0) + utterance.size(0) + summary.size(0)
+ key_dim = right_context.size(0) + utterance.size(0) + mems.size(0) + left_context_key.size(0)
+ attention_mask = torch.zeros(query_dim, key_dim).to(dtype=torch.bool, device=utterance.device)
+ attention_mask[-1, : mems.size(0)] = True
+ output, output_mems, key, value = self._forward_impl(
+ utterance,
+ lengths,
+ right_context,
+ summary,
+ mems,
+ attention_mask,
+ left_context_key=left_context_key,
+ left_context_val=left_context_val,
+ )
+ return (
+ output,
+ output_mems,
+ key[mems.size(0) + right_context.size(0) :],
+ value[mems.size(0) + right_context.size(0) :],
+ )
+
+
+class _EmformerLayer(torch.nn.Module):
+ r"""Emformer layer that constitutes Emformer.
+
+ Args:
+ input_dim (int): input dimension.
+ num_heads (int): number of attention heads.
+ ffn_dim (int): hidden layer dimension of feedforward network.
+ segment_length (int): length of each input segment.
+ dropout (float, optional): dropout probability. (Default: 0.0)
+ activation (str, optional): activation function to use in feedforward network.
+ Must be one of ("relu", "gelu", "silu"). (Default: "relu")
+ left_context_length (int, optional): length of left context. (Default: 0)
+ max_memory_size (int, optional): maximum number of memory elements to use. (Default: 0)
+ weight_init_gain (float or None, optional): scale factor to apply when initializing
+ attention module parameters. (Default: ``None``)
+ tanh_on_mem (bool, optional): if ``True``, applies tanh to memory elements. (Default: ``False``)
+ negative_inf (float, optional): value to use for negative infinity in attention weights. (Default: -1e8)
+ """
+
+ def __init__(
+ self,
+ input_dim: int,
+ num_heads: int,
+ ffn_dim: int,
+ segment_length: int,
+ dropout: float = 0.0,
+ activation: str = "relu",
+ left_context_length: int = 0,
+ max_memory_size: int = 0,
+ weight_init_gain: Optional[float] = None,
+ tanh_on_mem: bool = False,
+ negative_inf: float = -1e8,
+ ):
+ super().__init__()
+
+ self.attention = _EmformerAttention(
+ input_dim=input_dim,
+ num_heads=num_heads,
+ dropout=dropout,
+ weight_init_gain=weight_init_gain,
+ tanh_on_mem=tanh_on_mem,
+ negative_inf=negative_inf,
+ )
+ self.dropout = torch.nn.Dropout(dropout)
+ self.memory_op = torch.nn.AvgPool1d(kernel_size=segment_length, stride=segment_length, ceil_mode=True)
+
+ activation_module = _get_activation_module(activation)
+ self.pos_ff = torch.nn.Sequential(
+ torch.nn.LayerNorm(input_dim),
+ torch.nn.Linear(input_dim, ffn_dim),
+ activation_module,
+ torch.nn.Dropout(dropout),
+ torch.nn.Linear(ffn_dim, input_dim),
+ torch.nn.Dropout(dropout),
+ )
+ self.layer_norm_input = torch.nn.LayerNorm(input_dim)
+ self.layer_norm_output = torch.nn.LayerNorm(input_dim)
+
+ self.left_context_length = left_context_length
+ self.segment_length = segment_length
+ self.max_memory_size = max_memory_size
+ self.input_dim = input_dim
+
+ self.use_mem = max_memory_size > 0
+
+ def _init_state(self, batch_size: int, device: Optional[torch.device]) -> List[torch.Tensor]:
+ empty_memory = torch.zeros(self.max_memory_size, batch_size, self.input_dim, device=device)
+ left_context_key = torch.zeros(self.left_context_length, batch_size, self.input_dim, device=device)
+ left_context_val = torch.zeros(self.left_context_length, batch_size, self.input_dim, device=device)
+ past_length = torch.zeros(1, batch_size, dtype=torch.int32, device=device)
+ return [empty_memory, left_context_key, left_context_val, past_length]
+
+ def _unpack_state(self, state: List[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ past_length = state[3][0][0].item()
+ past_left_context_length = min(self.left_context_length, past_length)
+ past_mem_length = min(self.max_memory_size, math.ceil(past_length / self.segment_length))
+ pre_mems = state[0][self.max_memory_size - past_mem_length :]
+ lc_key = state[1][self.left_context_length - past_left_context_length :]
+ lc_val = state[2][self.left_context_length - past_left_context_length :]
+ return pre_mems, lc_key, lc_val
+
+ def _pack_state(
+ self,
+ next_k: torch.Tensor,
+ next_v: torch.Tensor,
+ update_length: int,
+ mems: torch.Tensor,
+ state: List[torch.Tensor],
+ ) -> List[torch.Tensor]:
+ new_k = torch.cat([state[1], next_k])
+ new_v = torch.cat([state[2], next_v])
+ state[0] = torch.cat([state[0], mems])[-self.max_memory_size :]
+ state[1] = new_k[new_k.shape[0] - self.left_context_length :]
+ state[2] = new_v[new_v.shape[0] - self.left_context_length :]
+ state[3] = state[3] + update_length
+ return state
+
+ def _process_attention_output(
+ self,
+ rc_output: torch.Tensor,
+ utterance: torch.Tensor,
+ right_context: torch.Tensor,
+ ) -> torch.Tensor:
+ result = self.dropout(rc_output) + torch.cat([right_context, utterance])
+ result = self.pos_ff(result) + result
+ result = self.layer_norm_output(result)
+ return result
+
+ def _apply_pre_attention_layer_norm(
+ self, utterance: torch.Tensor, right_context: torch.Tensor
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ layer_norm_input = self.layer_norm_input(torch.cat([right_context, utterance]))
+ return (
+ layer_norm_input[right_context.size(0) :],
+ layer_norm_input[: right_context.size(0)],
+ )
+
+ def _apply_post_attention_ffn(
+ self, rc_output: torch.Tensor, utterance: torch.Tensor, right_context: torch.Tensor
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ rc_output = self._process_attention_output(rc_output, utterance, right_context)
+ return rc_output[right_context.size(0) :], rc_output[: right_context.size(0)]
+
+ def _apply_attention_forward(
+ self,
+ utterance: torch.Tensor,
+ lengths: torch.Tensor,
+ right_context: torch.Tensor,
+ mems: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ if attention_mask is None:
+ raise ValueError("attention_mask must be not None when for_inference is False")
+
+ if self.use_mem:
+ summary = self.memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)
+ else:
+ summary = torch.empty(0).to(dtype=utterance.dtype, device=utterance.device)
+ rc_output, next_m = self.attention(
+ utterance=utterance,
+ lengths=lengths,
+ right_context=right_context,
+ summary=summary,
+ mems=mems,
+ attention_mask=attention_mask,
+ )
+ return rc_output, next_m
+
+ def _apply_attention_infer(
+ self,
+ utterance: torch.Tensor,
+ lengths: torch.Tensor,
+ right_context: torch.Tensor,
+ mems: torch.Tensor,
+ state: Optional[List[torch.Tensor]],
+ ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
+ if state is None:
+ state = self._init_state(utterance.size(1), device=utterance.device)
+ pre_mems, lc_key, lc_val = self._unpack_state(state)
+ if self.use_mem:
+ summary = self.memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)
+ summary = summary[:1]
+ else:
+ summary = torch.empty(0).to(dtype=utterance.dtype, device=utterance.device)
+ rc_output, next_m, next_k, next_v = self.attention.infer(
+ utterance=utterance,
+ lengths=lengths,
+ right_context=right_context,
+ summary=summary,
+ mems=pre_mems,
+ left_context_key=lc_key,
+ left_context_val=lc_val,
+ )
+ state = self._pack_state(next_k, next_v, utterance.size(0), mems, state)
+ return rc_output, next_m, state
+
+ def forward(
+ self,
+ utterance: torch.Tensor,
+ lengths: torch.Tensor,
+ right_context: torch.Tensor,
+ mems: torch.Tensor,
+ attention_mask: torch.Tensor,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ r"""Forward pass for training.
+
+ B: batch size;
+ D: feature dimension of each frame;
+ T: number of utterance frames;
+ R: number of right context frames;
+ M: number of memory elements.
+
+ Args:
+ utterance (torch.Tensor): utterance frames, with shape `(T, B, D)`.
+ lengths (torch.Tensor): with shape `(B,)` and i-th element representing
+ number of valid frames for i-th batch element in ``utterance``.
+ right_context (torch.Tensor): right context frames, with shape `(R, B, D)`.
+ mems (torch.Tensor): memory elements, with shape `(M, B, D)`.
+ attention_mask (torch.Tensor): attention mask for underlying attention module.
+
+ Returns:
+ (Tensor, Tensor, Tensor):
+ Tensor
+ encoded utterance frames, with shape `(T, B, D)`.
+ Tensor
+ updated right context frames, with shape `(R, B, D)`.
+ Tensor
+ updated memory elements, with shape `(M, B, D)`.
+ """
+ (
+ layer_norm_utterance,
+ layer_norm_right_context,
+ ) = self._apply_pre_attention_layer_norm(utterance, right_context)
+ rc_output, output_mems = self._apply_attention_forward(
+ layer_norm_utterance,
+ lengths,
+ layer_norm_right_context,
+ mems,
+ attention_mask,
+ )
+ output_utterance, output_right_context = self._apply_post_attention_ffn(rc_output, utterance, right_context)
+ return output_utterance, output_right_context, output_mems
+
+ @torch.jit.export
+ def infer(
+ self,
+ utterance: torch.Tensor,
+ lengths: torch.Tensor,
+ right_context: torch.Tensor,
+ state: Optional[List[torch.Tensor]],
+ mems: torch.Tensor,
+ ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor], torch.Tensor]:
+ r"""Forward pass for inference.
+
+ B: batch size;
+ D: feature dimension of each frame;
+ T: number of utterance frames;
+ R: number of right context frames;
+ M: number of memory elements.
+
+ Args:
+ utterance (torch.Tensor): utterance frames, with shape `(T, B, D)`.
+ lengths (torch.Tensor): with shape `(B,)` and i-th element representing
+ number of valid frames for i-th batch element in ``utterance``.
+ right_context (torch.Tensor): right context frames, with shape `(R, B, D)`.
+ state (List[torch.Tensor] or None): list of tensors representing layer internal state
+ generated in preceding invocation of ``infer``.
+ mems (torch.Tensor): memory elements, with shape `(M, B, D)`.
+
+ Returns:
+ (Tensor, Tensor, List[torch.Tensor], Tensor):
+ Tensor
+ encoded utterance frames, with shape `(T, B, D)`.
+ Tensor
+ updated right context frames, with shape `(R, B, D)`.
+ List[Tensor]
+ list of tensors representing layer internal state
+ generated in current invocation of ``infer``.
+ Tensor
+ updated memory elements, with shape `(M, B, D)`.
+ """
+ (
+ layer_norm_utterance,
+ layer_norm_right_context,
+ ) = self._apply_pre_attention_layer_norm(utterance, right_context)
+ rc_output, output_mems, output_state = self._apply_attention_infer(
+ layer_norm_utterance, lengths, layer_norm_right_context, mems, state
+ )
+ output_utterance, output_right_context = self._apply_post_attention_ffn(rc_output, utterance, right_context)
+ return output_utterance, output_right_context, output_state, output_mems
+
+
+class _EmformerImpl(torch.nn.Module):
+ def __init__(
+ self,
+ emformer_layers: torch.nn.ModuleList,
+ segment_length: int,
+ left_context_length: int = 0,
+ right_context_length: int = 0,
+ max_memory_size: int = 0,
+ ):
+ super().__init__()
+
+ self.use_mem = max_memory_size > 0
+ self.memory_op = torch.nn.AvgPool1d(
+ kernel_size=segment_length,
+ stride=segment_length,
+ ceil_mode=True,
+ )
+ self.emformer_layers = emformer_layers
+ self.left_context_length = left_context_length
+ self.right_context_length = right_context_length
+ self.segment_length = segment_length
+ self.max_memory_size = max_memory_size
+
+ def _gen_right_context(self, input: torch.Tensor) -> torch.Tensor:
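+        # Collect the right-context block of every segment into one tensor (the Emformer
+        # "hard copy" of future frames). Illustrative numbers only: with segment_length=4,
+        # right_context_length=1 and T=13 (12 utterance frames + 1 trailing right-context frame),
+        # num_segs is 3 and the result stacks input[4:5], input[8:9] and input[12:13].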
+ T = input.shape[0]
+ num_segs = math.ceil((T - self.right_context_length) / self.segment_length)
+ right_context_blocks = []
+ for seg_idx in range(num_segs - 1):
+ start = (seg_idx + 1) * self.segment_length
+ end = start + self.right_context_length
+ right_context_blocks.append(input[start:end])
+ right_context_blocks.append(input[T - self.right_context_length :])
+ return torch.cat(right_context_blocks)
+
+ def _gen_attention_mask_col_widths(self, seg_idx: int, utterance_length: int) -> List[int]:
+ num_segs = math.ceil(utterance_length / self.segment_length)
+ rc = self.right_context_length
+ lc = self.left_context_length
+ rc_start = seg_idx * rc
+ rc_end = rc_start + rc
+ seg_start = max(seg_idx * self.segment_length - lc, 0)
+ seg_end = min((seg_idx + 1) * self.segment_length, utterance_length)
+ rc_length = self.right_context_length * num_segs
+
+ if self.use_mem:
+ m_start = max(seg_idx - self.max_memory_size, 0)
+ mem_length = num_segs - 1
+ col_widths = [
+ m_start, # before memory
+ seg_idx - m_start, # memory
+ mem_length - seg_idx, # after memory
+ rc_start, # before right context
+ rc, # right context
+ rc_length - rc_end, # after right context
+ seg_start, # before query segment
+ seg_end - seg_start, # query segment
+ utterance_length - seg_end, # after query segment
+ ]
+ else:
+ col_widths = [
+ rc_start, # before right context
+ rc, # right context
+ rc_length - rc_end, # after right context
+ seg_start, # before query segment
+ seg_end - seg_start, # query segment
+ utterance_length - seg_end, # after query segment
+ ]
+
+ return col_widths
+
+ def _gen_attention_mask(self, input: torch.Tensor) -> torch.Tensor:
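+        # Assemble the attention mask row-by-row for the hard-copied right-context frames, the
+        # utterance (query) frames and, when memory is enabled, the per-segment summary vectors.
+        # The block helper marks attended column spans with ones, so the trailing ``1 - ...``
+        # yields a boolean mask that is True at positions attention must not reach.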
+ utterance_length = input.size(0)
+ num_segs = math.ceil(utterance_length / self.segment_length)
+
+ rc_mask = []
+ query_mask = []
+ summary_mask = []
+
+ if self.use_mem:
+ num_cols = 9
+ # memory, right context, query segment
+ rc_q_cols_mask = [idx in [1, 4, 7] for idx in range(num_cols)]
+ # right context, query segment
+ s_cols_mask = [idx in [4, 7] for idx in range(num_cols)]
+ masks_to_concat = [rc_mask, query_mask, summary_mask]
+ else:
+ num_cols = 6
+ # right context, query segment
+ rc_q_cols_mask = [idx in [1, 4] for idx in range(num_cols)]
+ s_cols_mask = None
+ masks_to_concat = [rc_mask, query_mask]
+
+ for seg_idx in range(num_segs):
+ col_widths = self._gen_attention_mask_col_widths(seg_idx, utterance_length)
+
+ rc_mask_block = _gen_attention_mask_block(
+ col_widths, rc_q_cols_mask, self.right_context_length, input.device
+ )
+ rc_mask.append(rc_mask_block)
+
+ query_mask_block = _gen_attention_mask_block(
+ col_widths,
+ rc_q_cols_mask,
+ min(
+ self.segment_length,
+ utterance_length - seg_idx * self.segment_length,
+ ),
+ input.device,
+ )
+ query_mask.append(query_mask_block)
+
+ if s_cols_mask is not None:
+ summary_mask_block = _gen_attention_mask_block(col_widths, s_cols_mask, 1, input.device)
+ summary_mask.append(summary_mask_block)
+
+ attention_mask = (1 - torch.cat([torch.cat(mask) for mask in masks_to_concat])).to(torch.bool)
+ return attention_mask
+
+ def forward(self, input: torch.Tensor, lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ r"""Forward pass for training and non-streaming inference.
+
+ B: batch size;
+ T: max number of input frames in batch;
+ D: feature dimension of each frame.
+
+ Args:
+ input (torch.Tensor): utterance frames right-padded with right context frames, with
+ shape `(B, T + right_context_length, D)`.
+ lengths (torch.Tensor): with shape `(B,)` and i-th element representing
+ number of valid utterance frames for i-th batch element in ``input``.
+
+ Returns:
+ (Tensor, Tensor):
+ Tensor
+ output frames, with shape `(B, T, D)`.
+ Tensor
+ output lengths, with shape `(B,)` and i-th element representing
+ number of valid frames for i-th batch element in output frames.
+ """
+ input = input.permute(1, 0, 2)
+ right_context = self._gen_right_context(input)
+ utterance = input[: input.size(0) - self.right_context_length]
+ attention_mask = self._gen_attention_mask(utterance)
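+        # When memory is enabled, the memory bank holds one average-pooled summary per segment;
+        # the final segment's summary is dropped ([:-1]) because memory entries only ever
+        # describe segments preceding the one currently being attended from.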
+ mems = (
+ self.memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)[:-1]
+ if self.use_mem
+ else torch.empty(0).to(dtype=input.dtype, device=input.device)
+ )
+ output = utterance
+ for layer in self.emformer_layers:
+ output, right_context, mems = layer(output, lengths, right_context, mems, attention_mask)
+ return output.permute(1, 0, 2), lengths
+
+ @torch.jit.export
+ def infer(
+ self,
+ input: torch.Tensor,
+ lengths: torch.Tensor,
+ states: Optional[List[List[torch.Tensor]]] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor, List[List[torch.Tensor]]]:
+ r"""Forward pass for streaming inference.
+
+ B: batch size;
+ D: feature dimension of each frame.
+
+ Args:
+ input (torch.Tensor): utterance frames right-padded with right context frames, with
+ shape `(B, segment_length + right_context_length, D)`.
+ lengths (torch.Tensor): with shape `(B,)` and i-th element representing
+ number of valid frames for i-th batch element in ``input``.
+ states (List[List[torch.Tensor]] or None, optional): list of lists of tensors
+ representing internal state generated in preceding invocation of ``infer``. (Default: ``None``)
+
+ Returns:
+ (Tensor, Tensor, List[List[Tensor]]):
+ Tensor
+ output frames, with shape `(B, segment_length, D)`.
+ Tensor
+ output lengths, with shape `(B,)` and i-th element representing
+ number of valid frames for i-th batch element in output frames.
+ List[List[Tensor]]
+ output states; list of lists of tensors representing internal state
+ generated in current invocation of ``infer``.
+ """
+ if input.size(1) != self.segment_length + self.right_context_length:
+ raise ValueError(
+ "Per configured segment_length and right_context_length"
+ f", expected size of {self.segment_length + self.right_context_length} for dimension 1 of input"
+ f", but got {input.size(1)}."
+ )
+ input = input.permute(1, 0, 2)
+ right_context_start_idx = input.size(0) - self.right_context_length
+ right_context = input[right_context_start_idx:]
+ utterance = input[:right_context_start_idx]
+ output_lengths = torch.clamp(lengths - self.right_context_length, min=0)
+ mems = (
+ self.memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)
+ if self.use_mem
+ else torch.empty(0).to(dtype=input.dtype, device=input.device)
+ )
+ output = utterance
+ output_states: List[List[torch.Tensor]] = []
+ for layer_idx, layer in enumerate(self.emformer_layers):
+ output, right_context, output_state, mems = layer.infer(
+ output,
+ output_lengths,
+ right_context,
+ None if states is None else states[layer_idx],
+ mems,
+ )
+ output_states.append(output_state)
+
+ return output.permute(1, 0, 2), output_lengths, output_states
+
+
+class Emformer(_EmformerImpl):
+ r"""Emformer architecture introduced in
+ *Emformer: Efficient Memory Transformer Based Acoustic Model for Low Latency Streaming Speech Recognition*
+ :cite:`shi2021emformer`.
+
+ See Also:
+ * :func:`~torchaudio.models.emformer_rnnt_model`,
+ :func:`~torchaudio.models.emformer_rnnt_base`: factory functions.
+ * :class:`torchaudio.pipelines.RNNTBundle`: ASR pipelines with pretrained model.
+
+ Args:
+ input_dim (int): input dimension.
+ num_heads (int): number of attention heads in each Emformer layer.
+ ffn_dim (int): hidden layer dimension of each Emformer layer's feedforward network.
+ num_layers (int): number of Emformer layers to instantiate.
+ segment_length (int): length of each input segment.
+ dropout (float, optional): dropout probability. (Default: 0.0)
+ activation (str, optional): activation function to use in each Emformer layer's
+ feedforward network. Must be one of ("relu", "gelu", "silu"). (Default: "relu")
+ left_context_length (int, optional): length of left context. (Default: 0)
+ right_context_length (int, optional): length of right context. (Default: 0)
+ max_memory_size (int, optional): maximum number of memory elements to use. (Default: 0)
+ weight_init_scale_strategy (str or None, optional): per-layer weight initialization scaling
+ strategy. Must be one of ("depthwise", "constant", ``None``). (Default: "depthwise")
+ tanh_on_mem (bool, optional): if ``True``, applies tanh to memory elements. (Default: ``False``)
+ negative_inf (float, optional): value to use for negative infinity in attention weights. (Default: -1e8)
+
+ Examples:
+ >>> emformer = Emformer(512, 8, 2048, 20, 4, right_context_length=1)
+ >>> input = torch.rand(128, 400, 512) # batch, num_frames, feature_dim
+ >>> lengths = torch.randint(1, 200, (128,)) # batch
+ >>> output, lengths = emformer(input, lengths)
+ >>> input = torch.rand(128, 5, 512)
+ >>> lengths = torch.ones(128) * 5
+ >>> output, lengths, states = emformer.infer(input, lengths, None)
+ """
+
+ def __init__(
+ self,
+ input_dim: int,
+ num_heads: int,
+ ffn_dim: int,
+ num_layers: int,
+ segment_length: int,
+ dropout: float = 0.0,
+ activation: str = "relu",
+ left_context_length: int = 0,
+ right_context_length: int = 0,
+ max_memory_size: int = 0,
+ weight_init_scale_strategy: Optional[str] = "depthwise",
+ tanh_on_mem: bool = False,
+ negative_inf: float = -1e8,
+ ):
+ weight_init_gains = _get_weight_init_gains(weight_init_scale_strategy, num_layers)
+ emformer_layers = torch.nn.ModuleList(
+ [
+ _EmformerLayer(
+ input_dim,
+ num_heads,
+ ffn_dim,
+ segment_length,
+ dropout=dropout,
+ activation=activation,
+ left_context_length=left_context_length,
+ max_memory_size=max_memory_size,
+ weight_init_gain=weight_init_gains[layer_idx],
+ tanh_on_mem=tanh_on_mem,
+ negative_inf=negative_inf,
+ )
+ for layer_idx in range(num_layers)
+ ]
+ )
+ super().__init__(
+ emformer_layers,
+ segment_length,
+ left_context_length=left_context_length,
+ right_context_length=right_context_length,
+ max_memory_size=max_memory_size,
+ )
diff --git a/MLPY/Lib/site-packages/torchaudio/models/rnnt.py b/MLPY/Lib/site-packages/torchaudio/models/rnnt.py
new file mode 100644
index 0000000000000000000000000000000000000000..659c7b93442095ad3d7c86e38e328094ce552d0c
--- /dev/null
+++ b/MLPY/Lib/site-packages/torchaudio/models/rnnt.py
@@ -0,0 +1,816 @@
+from abc import ABC, abstractmethod
+from typing import List, Optional, Tuple
+
+import torch
+from torchaudio.models import Emformer
+
+
+__all__ = ["RNNT", "emformer_rnnt_base", "emformer_rnnt_model"]
+
+
+class _TimeReduction(torch.nn.Module):
+    r"""Coalesces frames along the time dimension into fewer frames
+    with higher feature dimensionality.
+
+ Args:
+ stride (int): number of frames to merge for each output frame.
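+
+    Example (illustrative sketch; shapes are arbitrary):
+        >>> time_reduction = _TimeReduction(stride=4)
+        >>> input = torch.rand(3, 10, 80)  # (B, T, D); the trailing T % stride frames are dropped
+        >>> lengths = torch.tensor([10, 8, 6])
+        >>> output, output_lengths = time_reduction(input, lengths)
+        >>> output.shape
+        torch.Size([3, 2, 320])
+        >>> output_lengths
+        tensor([2, 2, 1])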
+ """
+
+ def __init__(self, stride: int) -> None:
+ super().__init__()
+ self.stride = stride
+
+ def forward(self, input: torch.Tensor, lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ r"""Forward pass.
+
+ B: batch size;
+ T: maximum input sequence length in batch;
+ D: feature dimension of each input sequence frame.
+
+ Args:
+ input (torch.Tensor): input sequences, with shape `(B, T, D)`.
+ lengths (torch.Tensor): with shape `(B,)` and i-th element representing
+ number of valid frames for i-th batch element in ``input``.
+
+ Returns:
+ (torch.Tensor, torch.Tensor):
+ torch.Tensor
+ output sequences, with shape
+ `(B, T // stride, D * stride)`
+ torch.Tensor
+ output lengths, with shape `(B,)` and i-th element representing
+ number of valid frames for i-th batch element in output sequences.
+ """
+ B, T, D = input.shape
+ num_frames = T - (T % self.stride)
+ input = input[:, :num_frames, :]
+ lengths = lengths.div(self.stride, rounding_mode="trunc")
+ T_max = num_frames // self.stride
+
+ output = input.reshape(B, T_max, D * self.stride)
+ output = output.contiguous()
+ return output, lengths
+
+
+class _CustomLSTM(torch.nn.Module):
+    r"""Custom long short-term memory (LSTM) block that applies layer normalization
+ to internal nodes.
+
+ Args:
+ input_dim (int): input dimension.
+ hidden_dim (int): hidden dimension.
+ layer_norm (bool, optional): if ``True``, enables layer normalization. (Default: ``False``)
+ layer_norm_epsilon (float, optional): value of epsilon to use in
+ layer normalization layers (Default: 1e-5)
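+
+    Example (illustrative sketch; sizes are arbitrary):
+        >>> lstm = _CustomLSTM(input_dim=16, hidden_dim=32, layer_norm=True)
+        >>> input = torch.rand(5, 2, 16)  # (T, B, D)
+        >>> output, state = lstm(input, None)
+        >>> output.shape
+        torch.Size([5, 2, 32])
+        >>> len(state)  # [h, c]
+        2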
+ """
+
+ def __init__(
+ self,
+ input_dim: int,
+ hidden_dim: int,
+ layer_norm: bool = False,
+ layer_norm_epsilon: float = 1e-5,
+ ) -> None:
+ super().__init__()
+ self.x2g = torch.nn.Linear(input_dim, 4 * hidden_dim, bias=(not layer_norm))
+ self.p2g = torch.nn.Linear(hidden_dim, 4 * hidden_dim, bias=False)
+ if layer_norm:
+ self.c_norm = torch.nn.LayerNorm(hidden_dim, eps=layer_norm_epsilon)
+ self.g_norm = torch.nn.LayerNorm(4 * hidden_dim, eps=layer_norm_epsilon)
+ else:
+ self.c_norm = torch.nn.Identity()
+ self.g_norm = torch.nn.Identity()
+
+ self.hidden_dim = hidden_dim
+
+ def forward(
+ self, input: torch.Tensor, state: Optional[List[torch.Tensor]]
+ ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
+ r"""Forward pass.
+
+ B: batch size;
+ T: maximum sequence length in batch;
+ D: feature dimension of each input sequence element.
+
+ Args:
+ input (torch.Tensor): with shape `(T, B, D)`.
+ state (List[torch.Tensor] or None): list of tensors
+ representing internal state generated in preceding invocation
+ of ``forward``.
+
+ Returns:
+ (torch.Tensor, List[torch.Tensor]):
+ torch.Tensor
+ output, with shape `(T, B, hidden_dim)`.
+ List[torch.Tensor]
+ list of tensors representing internal state generated
+ in current invocation of ``forward``.
+ """
+ if state is None:
+ B = input.size(1)
+ h = torch.zeros(B, self.hidden_dim, device=input.device, dtype=input.dtype)
+ c = torch.zeros(B, self.hidden_dim, device=input.device, dtype=input.dtype)
+ else:
+ h, c = state
+
+ gated_input = self.x2g(input)
+ outputs = []
+ for gates in gated_input.unbind(0):
+ gates = gates + self.p2g(h)
+ gates = self.g_norm(gates)
+ input_gate, forget_gate, cell_gate, output_gate = gates.chunk(4, 1)
+ input_gate = input_gate.sigmoid()
+ forget_gate = forget_gate.sigmoid()
+ cell_gate = cell_gate.tanh()
+ output_gate = output_gate.sigmoid()
+ c = forget_gate * c + input_gate * cell_gate
+ c = self.c_norm(c)
+ h = output_gate * c.tanh()
+ outputs.append(h)
+
+ output = torch.stack(outputs, dim=0)
+ state = [h, c]
+
+ return output, state
+
+
+class _Transcriber(ABC):
+ @abstractmethod
+ def forward(self, input: torch.Tensor, lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ pass
+
+ @abstractmethod
+ def infer(
+ self,
+ input: torch.Tensor,
+ lengths: torch.Tensor,
+ states: Optional[List[List[torch.Tensor]]],
+ ) -> Tuple[torch.Tensor, torch.Tensor, List[List[torch.Tensor]]]:
+ pass
+
+
+class _EmformerEncoder(torch.nn.Module, _Transcriber):
+ r"""Emformer-based recurrent neural network transducer (RNN-T) encoder (transcription network).
+
+ Args:
+ input_dim (int): feature dimension of each input sequence element.
+ output_dim (int): feature dimension of each output sequence element.
+ segment_length (int): length of input segment expressed as number of frames.
+ right_context_length (int): length of right context expressed as number of frames.
+ time_reduction_input_dim (int): dimension to scale each element in input sequences to
+ prior to applying time reduction block.
+ time_reduction_stride (int): factor by which to reduce length of input sequence.
+ transformer_num_heads (int): number of attention heads in each Emformer layer.
+ transformer_ffn_dim (int): hidden layer dimension of each Emformer layer's feedforward network.
+ transformer_num_layers (int): number of Emformer layers to instantiate.
+ transformer_left_context_length (int): length of left context.
+ transformer_dropout (float, optional): transformer dropout probability. (Default: 0.0)
+ transformer_activation (str, optional): activation function to use in each Emformer layer's
+ feedforward network. Must be one of ("relu", "gelu", "silu"). (Default: "relu")
+ transformer_max_memory_size (int, optional): maximum number of memory elements to use. (Default: 0)
+ transformer_weight_init_scale_strategy (str, optional): per-layer weight initialization scaling
+ strategy. Must be one of ("depthwise", "constant", ``None``). (Default: "depthwise")
+ transformer_tanh_on_mem (bool, optional): if ``True``, applies tanh to memory elements. (Default: ``False``)
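+
+    Example (illustrative sketch; the hyperparameters below are arbitrary, not a pretrained configuration):
+        >>> encoder = _EmformerEncoder(
+        ...     input_dim=80,
+        ...     output_dim=256,
+        ...     segment_length=16,
+        ...     right_context_length=4,
+        ...     time_reduction_input_dim=128,
+        ...     time_reduction_stride=4,
+        ...     transformer_num_heads=4,
+        ...     transformer_ffn_dim=512,
+        ...     transformer_num_layers=2,
+        ...     transformer_left_context_length=30,
+        ... )
+        >>> input = torch.rand(2, 132, 80)  # 128 utterance frames + 4 right-context frames
+        >>> lengths = torch.tensor([128, 128])  # valid utterance frames, excluding right context
+        >>> output, output_lengths = encoder(input, lengths)
+        >>> output.shape
+        torch.Size([2, 32, 256])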
+ """
+
+ def __init__(
+ self,
+ *,
+ input_dim: int,
+ output_dim: int,
+ segment_length: int,
+ right_context_length: int,
+ time_reduction_input_dim: int,
+ time_reduction_stride: int,
+ transformer_num_heads: int,
+ transformer_ffn_dim: int,
+ transformer_num_layers: int,
+ transformer_left_context_length: int,
+ transformer_dropout: float = 0.0,
+ transformer_activation: str = "relu",
+ transformer_max_memory_size: int = 0,
+ transformer_weight_init_scale_strategy: str = "depthwise",
+ transformer_tanh_on_mem: bool = False,
+ ) -> None:
+ super().__init__()
+ self.input_linear = torch.nn.Linear(
+ input_dim,
+ time_reduction_input_dim,
+ bias=False,
+ )
+ self.time_reduction = _TimeReduction(time_reduction_stride)
+ transformer_input_dim = time_reduction_input_dim * time_reduction_stride
+ self.transformer = Emformer(
+ transformer_input_dim,
+ transformer_num_heads,
+ transformer_ffn_dim,
+ transformer_num_layers,
+ segment_length // time_reduction_stride,
+ dropout=transformer_dropout,
+ activation=transformer_activation,
+ left_context_length=transformer_left_context_length,
+ right_context_length=right_context_length // time_reduction_stride,
+ max_memory_size=transformer_max_memory_size,
+ weight_init_scale_strategy=transformer_weight_init_scale_strategy,
+ tanh_on_mem=transformer_tanh_on_mem,
+ )
+ self.output_linear = torch.nn.Linear(transformer_input_dim, output_dim)
+ self.layer_norm = torch.nn.LayerNorm(output_dim)
+
+ def forward(self, input: torch.Tensor, lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ r"""Forward pass for training.
+
+ B: batch size;
+ T: maximum input sequence length in batch;
+ D: feature dimension of each input sequence frame (input_dim).
+
+ Args:
+ input (torch.Tensor): input frame sequences right-padded with right context, with
+ shape `(B, T + right context length, D)`.
+ lengths (torch.Tensor): with shape `(B,)` and i-th element representing
+ number of valid frames for i-th batch element in ``input``.
+
+ Returns:
+ (torch.Tensor, torch.Tensor):
+ torch.Tensor
+ output frame sequences, with
+ shape `(B, T // time_reduction_stride, output_dim)`.
+ torch.Tensor
+                    output lengths, with shape `(B,)` and i-th element representing
+ number of valid elements for i-th batch element in output frame sequences.
+ """
+ input_linear_out = self.input_linear(input)
+ time_reduction_out, time_reduction_lengths = self.time_reduction(input_linear_out, lengths)
+ transformer_out, transformer_lengths = self.transformer(time_reduction_out, time_reduction_lengths)
+ output_linear_out = self.output_linear(transformer_out)
+ layer_norm_out = self.layer_norm(output_linear_out)
+ return layer_norm_out, transformer_lengths
+
+ @torch.jit.export
+ def infer(
+ self,
+ input: torch.Tensor,
+ lengths: torch.Tensor,
+ states: Optional[List[List[torch.Tensor]]],
+ ) -> Tuple[torch.Tensor, torch.Tensor, List[List[torch.Tensor]]]:
+ r"""Forward pass for inference.
+
+ B: batch size;
+ T: maximum input sequence segment length in batch;
+ D: feature dimension of each input sequence frame (input_dim).
+
+ Args:
+ input (torch.Tensor): input frame sequence segments right-padded with right context, with
+ shape `(B, T + right context length, D)`.
+ lengths (torch.Tensor): with shape `(B,)` and i-th element representing
+ number of valid frames for i-th batch element in ``input``.
+            states (List[List[torch.Tensor]] or None): list of lists of tensors
+ representing internal state generated in preceding invocation
+ of ``infer``.
+
+ Returns:
+ (torch.Tensor, torch.Tensor, List[List[torch.Tensor]]):
+ torch.Tensor
+ output frame sequences, with
+ shape `(B, T // time_reduction_stride, output_dim)`.
+ torch.Tensor
+                    output lengths, with shape `(B,)` and i-th element representing
+ number of valid elements for i-th batch element in output.
+ List[List[torch.Tensor]]
+ output states; list of lists of tensors
+ representing internal state generated in current invocation
+ of ``infer``.
+ """
+ input_linear_out = self.input_linear(input)
+ time_reduction_out, time_reduction_lengths = self.time_reduction(input_linear_out, lengths)
+ (
+ transformer_out,
+ transformer_lengths,
+ transformer_states,
+ ) = self.transformer.infer(time_reduction_out, time_reduction_lengths, states)
+ output_linear_out = self.output_linear(transformer_out)
+ layer_norm_out = self.layer_norm(output_linear_out)
+ return layer_norm_out, transformer_lengths, transformer_states
+
+
+class _Predictor(torch.nn.Module):
+ r"""Recurrent neural network transducer (RNN-T) prediction network.
+
+ Args:
+ num_symbols (int): size of target token lexicon.
+ output_dim (int): feature dimension of each output sequence element.
+ symbol_embedding_dim (int): dimension of each target token embedding.
+ num_lstm_layers (int): number of LSTM layers to instantiate.
+ lstm_hidden_dim (int): output dimension of each LSTM layer.
+ lstm_layer_norm (bool, optional): if ``True``, enables layer normalization
+ for LSTM layers. (Default: ``False``)
+ lstm_layer_norm_epsilon (float, optional): value of epsilon to use in
+ LSTM layer normalization layers. (Default: 1e-5)
+ lstm_dropout (float, optional): LSTM dropout probability. (Default: 0.0)
+
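+    Example (illustrative sketch; sizes are arbitrary):
+        >>> predictor = _Predictor(
+        ...     num_symbols=100,
+        ...     output_dim=256,
+        ...     symbol_embedding_dim=64,
+        ...     num_lstm_layers=2,
+        ...     lstm_hidden_dim=64,
+        ... )
+        >>> targets = torch.randint(0, 100, (3, 7))  # (B, U)
+        >>> target_lengths = torch.tensor([7, 5, 4])
+        >>> output, lengths, states = predictor(targets, target_lengths)
+        >>> output.shape
+        torch.Size([3, 7, 256])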
+ """
+
+ def __init__(
+ self,
+ num_symbols: int,
+ output_dim: int,
+ symbol_embedding_dim: int,
+ num_lstm_layers: int,
+ lstm_hidden_dim: int,
+ lstm_layer_norm: bool = False,
+ lstm_layer_norm_epsilon: float = 1e-5,
+ lstm_dropout: float = 0.0,
+ ) -> None:
+ super().__init__()
+ self.embedding = torch.nn.Embedding(num_symbols, symbol_embedding_dim)
+ self.input_layer_norm = torch.nn.LayerNorm(symbol_embedding_dim)
+ self.lstm_layers = torch.nn.ModuleList(
+ [
+ _CustomLSTM(
+ symbol_embedding_dim if idx == 0 else lstm_hidden_dim,
+ lstm_hidden_dim,
+ layer_norm=lstm_layer_norm,
+ layer_norm_epsilon=lstm_layer_norm_epsilon,
+ )
+ for idx in range(num_lstm_layers)
+ ]
+ )
+ self.dropout = torch.nn.Dropout(p=lstm_dropout)
+ self.linear = torch.nn.Linear(lstm_hidden_dim, output_dim)
+ self.output_layer_norm = torch.nn.LayerNorm(output_dim)
+
+ self.lstm_dropout = lstm_dropout
+
+ def forward(
+ self,
+ input: torch.Tensor,
+ lengths: torch.Tensor,
+ state: Optional[List[List[torch.Tensor]]] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor, List[List[torch.Tensor]]]:
+ r"""Forward pass.
+
+ B: batch size;
+ U: maximum sequence length in batch;
+ D: feature dimension of each input sequence element.
+
+ Args:
+ input (torch.Tensor): target sequences, with shape `(B, U)` and each element
+ mapping to a target symbol, i.e. in range `[0, num_symbols)`.
+ lengths (torch.Tensor): with shape `(B,)` and i-th element representing
+ number of valid frames for i-th batch element in ``input``.
+ state (List[List[torch.Tensor]] or None, optional): list of lists of tensors
+ representing internal state generated in preceding invocation
+ of ``forward``. (Default: ``None``)
+
+ Returns:
+ (torch.Tensor, torch.Tensor, List[List[torch.Tensor]]):
+ torch.Tensor
+ output encoding sequences, with shape `(B, U, output_dim)`
+ torch.Tensor
+ output lengths, with shape `(B,)` and i-th element representing
+ number of valid elements for i-th batch element in output encoding sequences.
+ List[List[torch.Tensor]]
+ output states; list of lists of tensors
+ representing internal state generated in current invocation of ``forward``.
+ """
+ input_tb = input.permute(1, 0)
+ embedding_out = self.embedding(input_tb)
+ input_layer_norm_out = self.input_layer_norm(embedding_out)
+
+ lstm_out = input_layer_norm_out
+ state_out: List[List[torch.Tensor]] = []
+ for layer_idx, lstm in enumerate(self.lstm_layers):
+ lstm_out, lstm_state_out = lstm(lstm_out, None if state is None else state[layer_idx])
+ lstm_out = self.dropout(lstm_out)
+ state_out.append(lstm_state_out)
+
+ linear_out = self.linear(lstm_out)
+ output_layer_norm_out = self.output_layer_norm(linear_out)
+ return output_layer_norm_out.permute(1, 0, 2), lengths, state_out
+
+
+class _Joiner(torch.nn.Module):
+ r"""Recurrent neural network transducer (RNN-T) joint network.
+
+ Args:
+ input_dim (int): source and target input dimension.
+ output_dim (int): output dimension.
+ activation (str, optional): activation function to use in the joiner.
+ Must be one of ("relu", "tanh"). (Default: "relu")
+
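+    Example (illustrative sketch; sizes are arbitrary):
+        >>> joiner = _Joiner(input_dim=256, output_dim=100)
+        >>> source_encodings = torch.rand(3, 20, 256)  # (B, T, D)
+        >>> target_encodings = torch.rand(3, 7, 256)  # (B, U, D)
+        >>> source_lengths = torch.tensor([20, 18, 15])
+        >>> target_lengths = torch.tensor([7, 6, 4])
+        >>> output, _, _ = joiner(source_encodings, source_lengths, target_encodings, target_lengths)
+        >>> output.shape  # broadcast sum over (T, U) followed by the linear projection
+        torch.Size([3, 20, 7, 100])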
+ """
+
+ def __init__(self, input_dim: int, output_dim: int, activation: str = "relu") -> None:
+ super().__init__()
+ self.linear = torch.nn.Linear(input_dim, output_dim, bias=True)
+ if activation == "relu":
+ self.activation = torch.nn.ReLU()
+ elif activation == "tanh":
+ self.activation = torch.nn.Tanh()
+ else:
+ raise ValueError(f"Unsupported activation {activation}")
+
+ def forward(
+ self,
+ source_encodings: torch.Tensor,
+ source_lengths: torch.Tensor,
+ target_encodings: torch.Tensor,
+ target_lengths: torch.Tensor,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ r"""Forward pass for training.
+
+ B: batch size;
+ T: maximum source sequence length in batch;
+ U: maximum target sequence length in batch;
+ D: dimension of each source and target sequence encoding.
+
+ Args:
+ source_encodings (torch.Tensor): source encoding sequences, with
+ shape `(B, T, D)`.
+ source_lengths (torch.Tensor): with shape `(B,)` and i-th element representing
+ valid sequence length of i-th batch element in ``source_encodings``.
+ target_encodings (torch.Tensor): target encoding sequences, with shape `(B, U, D)`.
+ target_lengths (torch.Tensor): with shape `(B,)` and i-th element representing
+ valid sequence length of i-th batch element in ``target_encodings``.
+
+ Returns:
+ (torch.Tensor, torch.Tensor, torch.Tensor):
+ torch.Tensor
+ joint network output, with shape `(B, T, U, output_dim)`.
+ torch.Tensor
+ output source lengths, with shape `(B,)` and i-th element representing
+ number of valid elements along dim 1 for i-th batch element in joint network output.
+ torch.Tensor
+ output target lengths, with shape `(B,)` and i-th element representing
+ number of valid elements along dim 2 for i-th batch element in joint network output.
+ """
+ joint_encodings = source_encodings.unsqueeze(2).contiguous() + target_encodings.unsqueeze(1).contiguous()
+ activation_out = self.activation(joint_encodings)
+ output = self.linear(activation_out)
+ return output, source_lengths, target_lengths
+
+
+class RNNT(torch.nn.Module):
+ r"""torchaudio.models.RNNT()
+
+ Recurrent neural network transducer (RNN-T) model.
+
+ Note:
+ To build the model, please use one of the factory functions.
+
+ See Also:
+ :class:`torchaudio.pipelines.RNNTBundle`: ASR pipeline with pre-trained models.
+
+ Args:
+ transcriber (torch.nn.Module): transcription network.
+ predictor (torch.nn.Module): prediction network.
+ joiner (torch.nn.Module): joint network.
+ """
+
+ def __init__(self, transcriber: _Transcriber, predictor: _Predictor, joiner: _Joiner) -> None:
+ super().__init__()
+ self.transcriber = transcriber
+ self.predictor = predictor
+ self.joiner = joiner
+
+ def forward(
+ self,
+ sources: torch.Tensor,
+ source_lengths: torch.Tensor,
+ targets: torch.Tensor,
+ target_lengths: torch.Tensor,
+ predictor_state: Optional[List[List[torch.Tensor]]] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[List[torch.Tensor]]]:
+ r"""Forward pass for training.
+
+ B: batch size;
+ T: maximum source sequence length in batch;
+ U: maximum target sequence length in batch;
+ D: feature dimension of each source sequence element.
+
+ Args:
+ sources (torch.Tensor): source frame sequences right-padded with right context, with
+ shape `(B, T, D)`.
+ source_lengths (torch.Tensor): with shape `(B,)` and i-th element representing
+ number of valid frames for i-th batch element in ``sources``.
+ targets (torch.Tensor): target sequences, with shape `(B, U)` and each element
+ mapping to a target symbol.
+ target_lengths (torch.Tensor): with shape `(B,)` and i-th element representing
+ number of valid frames for i-th batch element in ``targets``.
+ predictor_state (List[List[torch.Tensor]] or None, optional): list of lists of tensors
+ representing prediction network internal state generated in preceding invocation
+ of ``forward``. (Default: ``None``)
+
+ Returns:
+ (torch.Tensor, torch.Tensor, torch.Tensor, List[List[torch.Tensor]]):
+ torch.Tensor
+ joint network output, with shape
+ `(B, max output source length, max output target length, output_dim (number of target symbols))`.
+ torch.Tensor
+ output source lengths, with shape `(B,)` and i-th element representing
+ number of valid elements along dim 1 for i-th batch element in joint network output.
+ torch.Tensor
+ output target lengths, with shape `(B,)` and i-th element representing
+ number of valid elements along dim 2 for i-th batch element in joint network output.
+ List[List[torch.Tensor]]
+ output states; list of lists of tensors
+ representing prediction network internal state generated in current invocation
+ of ``forward``.
+ """
+ source_encodings, source_lengths = self.transcriber(
+ input=sources,
+ lengths=source_lengths,
+ )
+ target_encodings, target_lengths, predictor_state = self.predictor(
+ input=targets,
+ lengths=target_lengths,
+ state=predictor_state,
+ )
+ output, source_lengths, target_lengths = self.joiner(
+ source_encodings=source_encodings,
+ source_lengths=source_lengths,
+ target_encodings=target_encodings,
+ target_lengths=target_lengths,
+ )
+
+ return (
+ output,
+ source_lengths,
+ target_lengths,
+ predictor_state,
+ )
+
+ @torch.jit.export
+ def transcribe_streaming(
+ self,
+ sources: torch.Tensor,
+ source_lengths: torch.Tensor,
+ state: Optional[List[List[torch.Tensor]]],
+ ) -> Tuple[torch.Tensor, torch.Tensor, List[List[torch.Tensor]]]:
+ r"""Applies transcription network to sources in streaming mode.
+
+ B: batch size;
+ T: maximum source sequence segment length in batch;
+ D: feature dimension of each source sequence frame.
+
+ Args:
+ sources (torch.Tensor): source frame sequence segments right-padded with right context, with
+ shape `(B, T + right context length, D)`.
+ source_lengths (torch.Tensor): with shape `(B,)` and i-th element representing
+ number of valid frames for i-th batch element in ``sources``.
+ state (List[List[torch.Tensor]] or None): list of lists of tensors
+ representing transcription network internal state generated in preceding invocation
+ of ``transcribe_streaming``.
+
+ Returns:
+ (torch.Tensor, torch.Tensor, List[List[torch.Tensor]]):
+ torch.Tensor
+ output frame sequences, with
+ shape `(B, T // time_reduction_stride, output_dim)`.
+ torch.Tensor
+ output lengths, with shape `(B,)` and i-th element representing
+ number of valid elements for i-th batch element in output.
+ List[List[torch.Tensor]]
+ output states; list of lists of tensors
+ representing transcription network internal state generated in current invocation
+ of ``transcribe_streaming``.
+ """
+ return self.transcriber.infer(sources, source_lengths, state)
+
+ @torch.jit.export
+ def transcribe(
+ self,
+ sources: torch.Tensor,
+ source_lengths: torch.Tensor,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ r"""Applies transcription network to sources in non-streaming mode.
+
+ B: batch size;
+ T: maximum source sequence length in batch;
+ D: feature dimension of each source sequence frame.
+
+ Args:
+ sources (torch.Tensor): source frame sequences right-padded with right context, with
+ shape `(B, T + right context length, D)`.
+ source_lengths (torch.Tensor): with shape `(B,)` and i-th element representing
+ number of valid frames for i-th batch element in ``sources``.
+
+ Returns:
+ (torch.Tensor, torch.Tensor):
+ torch.Tensor
+ output frame sequences, with
+ shape `(B, T // time_reduction_stride, output_dim)`.
+ torch.Tensor
+ output lengths, with shape `(B,)` and i-th element representing
+ number of valid elements for i-th batch element in output frame sequences.
+ """
+ return self.transcriber(sources, source_lengths)
+
+ @torch.jit.export
+ def predict(
+ self,
+ targets: torch.Tensor,
+ target_lengths: torch.Tensor,
+ state: Optional[List[List[torch.Tensor]]],
+ ) -> Tuple[torch.Tensor, torch.Tensor, List[List[torch.Tensor]]]:
+ r"""Applies prediction network to targets.
+
+ B: batch size;
+ U: maximum target sequence length in batch;
+ D: feature dimension of each target sequence frame.
+
+ Args:
+ targets (torch.Tensor): target sequences, with shape `(B, U)` and each element
+ mapping to a target symbol, i.e. in range `[0, num_symbols)`.
+ target_lengths (torch.Tensor): with shape `(B,)` and i-th element representing
+ number of valid frames for i-th batch element in ``targets``.
+ state (List[List[torch.Tensor]] or None): list of lists of tensors
+ representing internal state generated in preceding invocation
+ of ``predict``.
+
+ Returns:
+ (torch.Tensor, torch.Tensor, List[List[torch.Tensor]]):
+ torch.Tensor
+ output frame sequences, with shape `(B, U, output_dim)`.
+ torch.Tensor
+ output lengths, with shape `(B,)` and i-th element representing
+ number of valid elements for i-th batch element in output.
+ List[List[torch.Tensor]]
+ output states; list of lists of tensors
+ representing internal state generated in current invocation of ``predict``.
+ """
+ return self.predictor(input=targets, lengths=target_lengths, state=state)
+
+ @torch.jit.export
+ def join(
+ self,
+ source_encodings: torch.Tensor,
+ source_lengths: torch.Tensor,
+ target_encodings: torch.Tensor,
+ target_lengths: torch.Tensor,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ r"""Applies joint network to source and target encodings.
+
+ B: batch size;
+ T: maximum source sequence length in batch;
+ U: maximum target sequence length in batch;
+ D: dimension of each source and target sequence encoding.
+
+ Args:
+ source_encodings (torch.Tensor): source encoding sequences, with
+ shape `(B, T, D)`.
+ source_lengths (torch.Tensor): with shape `(B,)` and i-th element representing
+ valid sequence length of i-th batch element in ``source_encodings``.
+ target_encodings (torch.Tensor): target encoding sequences, with shape `(B, U, D)`.
+ target_lengths (torch.Tensor): with shape `(B,)` and i-th element representing
+ valid sequence length of i-th batch element in ``target_encodings``.
+
+ Returns:
+ (torch.Tensor, torch.Tensor, torch.Tensor):
+ torch.Tensor
+ joint network output, with shape `(B, T, U, output_dim)`.
+ torch.Tensor
+ output source lengths, with shape `(B,)` and i-th element representing
+ number of valid elements along dim 1 for i-th batch element in joint network output.
+ torch.Tensor
+ output target lengths, with shape `(B,)` and i-th element representing
+ number of valid elements along dim 2 for i-th batch element in joint network output.
+ """
+ output, source_lengths, target_lengths = self.joiner(
+ source_encodings=source_encodings,
+ source_lengths=source_lengths,
+ target_encodings=target_encodings,
+ target_lengths=target_lengths,
+ )
+ return output, source_lengths, target_lengths
+
+
+def emformer_rnnt_model(
+ *,
+ input_dim: int,
+ encoding_dim: int,
+ num_symbols: int,
+ segment_length: int,
+ right_context_length: int,
+ time_reduction_input_dim: int,
+ time_reduction_stride: int,
+ transformer_num_heads: int,
+ transformer_ffn_dim: int,
+ transformer_num_layers: int,
+ transformer_dropout: float,
+ transformer_activation: str,
+ transformer_left_context_length: int,
+ transformer_max_memory_size: int,
+ transformer_weight_init_scale_strategy: str,
+ transformer_tanh_on_mem: bool,
+ symbol_embedding_dim: int,
+ num_lstm_layers: int,
+ lstm_layer_norm: bool,
+ lstm_layer_norm_epsilon: float,
+ lstm_dropout: float,
+) -> RNNT:
+ r"""Builds Emformer-based :class:`~torchaudio.models.RNNT`.
+
+ Note:
+ For non-streaming inference, the expectation is for `transcribe` to be called on input
+ sequences right-concatenated with `right_context_length` frames.
+
+ For streaming inference, the expectation is for `transcribe_streaming` to be called
+ on input chunks comprising `segment_length` frames right-concatenated with `right_context_length`
+ frames.
+
+ Args:
+ input_dim (int): dimension of input sequence frames passed to transcription network.
+ encoding_dim (int): dimension of transcription- and prediction-network-generated encodings
+ passed to joint network.
+ num_symbols (int): cardinality of set of target tokens.
+ segment_length (int): length of input segment expressed as number of frames.
+ right_context_length (int): length of right context expressed as number of frames.
+ time_reduction_input_dim (int): dimension to scale each element in input sequences to
+ prior to applying time reduction block.
+ time_reduction_stride (int): factor by which to reduce length of input sequence.
+ transformer_num_heads (int): number of attention heads in each Emformer layer.
+ transformer_ffn_dim (int): hidden layer dimension of each Emformer layer's feedforward network.
+ transformer_num_layers (int): number of Emformer layers to instantiate.
+ transformer_left_context_length (int): length of left context considered by Emformer.
+ transformer_dropout (float): Emformer dropout probability.
+ transformer_activation (str): activation function to use in each Emformer layer's
+ feedforward network. Must be one of ("relu", "gelu", "silu").
+ transformer_max_memory_size (int): maximum number of memory elements to use.
+ transformer_weight_init_scale_strategy (str): per-layer weight initialization scaling
+ strategy. Must be one of ("depthwise", "constant", ``None``).
+ transformer_tanh_on_mem (bool): if ``True``, applies tanh to memory elements.
+ symbol_embedding_dim (int): dimension of each target token embedding.
+ num_lstm_layers (int): number of LSTM layers to instantiate.
+ lstm_layer_norm (bool): if ``True``, enables layer normalization for LSTM layers.
+ lstm_layer_norm_epsilon (float): value of epsilon to use in LSTM layer normalization layers.
+ lstm_dropout (float): LSTM dropout probability.
+
+ Returns:
+ RNNT:
+ Emformer RNN-T model.
+ """
+ encoder = _EmformerEncoder(
+ input_dim=input_dim,
+ output_dim=encoding_dim,
+ segment_length=segment_length,
+ right_context_length=right_context_length,
+ time_reduction_input_dim=time_reduction_input_dim,
+ time_reduction_stride=time_reduction_stride,
+ transformer_num_heads=transformer_num_heads,
+ transformer_ffn_dim=transformer_ffn_dim,
+ transformer_num_layers=transformer_num_layers,
+ transformer_dropout=transformer_dropout,
+ transformer_activation=transformer_activation,
+ transformer_left_context_length=transformer_left_context_length,
+ transformer_max_memory_size=transformer_max_memory_size,
+ transformer_weight_init_scale_strategy=transformer_weight_init_scale_strategy,
+ transformer_tanh_on_mem=transformer_tanh_on_mem,
+ )
+ predictor = _Predictor(
+ num_symbols,
+ encoding_dim,
+ symbol_embedding_dim=symbol_embedding_dim,
+ num_lstm_layers=num_lstm_layers,
+ lstm_hidden_dim=symbol_embedding_dim,
+ lstm_layer_norm=lstm_layer_norm,
+ lstm_layer_norm_epsilon=lstm_layer_norm_epsilon,
+ lstm_dropout=lstm_dropout,
+ )
+ joiner = _Joiner(encoding_dim, num_symbols)
+ return RNNT(encoder, predictor, joiner)
+
+
+def emformer_rnnt_base(num_symbols: int) -> RNNT:
+ r"""Builds basic version of Emformer-based :class:`~torchaudio.models.RNNT`.
+
+ Args:
+ num_symbols (int): The size of target token lexicon.
+
+ Returns:
+ RNNT:
+ Emformer RNN-T model.
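+
+    Example (illustrative sketch; batch size, frame counts and vocabulary size are arbitrary):
+        >>> rnnt = emformer_rnnt_base(num_symbols=4097)
+        >>> sources = torch.rand(2, 100, 80)  # 96 utterance frames + 4 right-context frames
+        >>> source_lengths = torch.tensor([96, 96])
+        >>> targets = torch.randint(0, 4097, (2, 10))
+        >>> target_lengths = torch.tensor([10, 8])
+        >>> output, src_lengths, tgt_lengths, _ = rnnt(sources, source_lengths, targets, target_lengths)
+        >>> output.shape  # (B, T // time_reduction_stride, U, num_symbols)
+        torch.Size([2, 24, 10, 4097])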
+ """
+ return emformer_rnnt_model(
+ input_dim=80,
+ encoding_dim=1024,
+ num_symbols=num_symbols,
+ segment_length=16,
+ right_context_length=4,
+ time_reduction_input_dim=128,
+ time_reduction_stride=4,
+ transformer_num_heads=8,
+ transformer_ffn_dim=2048,
+ transformer_num_layers=20,
+ transformer_dropout=0.1,
+ transformer_activation="gelu",
+ transformer_left_context_length=30,
+ transformer_max_memory_size=0,
+ transformer_weight_init_scale_strategy="depthwise",
+ transformer_tanh_on_mem=True,
+ symbol_embedding_dim=512,
+ num_lstm_layers=3,
+ lstm_layer_norm=True,
+ lstm_layer_norm_epsilon=1e-3,
+ lstm_dropout=0.3,
+ )
diff --git a/MLPY/Lib/site-packages/torchaudio/models/rnnt_decoder.py b/MLPY/Lib/site-packages/torchaudio/models/rnnt_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..2fe03715513a077984f5c3d4fcf95e0fa653b5f7
--- /dev/null
+++ b/MLPY/Lib/site-packages/torchaudio/models/rnnt_decoder.py
@@ -0,0 +1,339 @@
+from typing import Callable, Dict, List, Optional, Tuple
+
+import torch
+from torchaudio.models import RNNT
+
+
+__all__ = ["Hypothesis", "RNNTBeamSearch"]
+
+
+Hypothesis = Tuple[List[int], torch.Tensor, List[List[torch.Tensor]], float]
+Hypothesis.__doc__ = """Hypothesis generated by RNN-T beam search decoder,
+ represented as tuple of (tokens, prediction network output, prediction network state, score).
+ """
+
+
+def _get_hypo_tokens(hypo: Hypothesis) -> List[int]:
+ return hypo[0]
+
+
+def _get_hypo_predictor_out(hypo: Hypothesis) -> torch.Tensor:
+ return hypo[1]
+
+
+def _get_hypo_state(hypo: Hypothesis) -> List[List[torch.Tensor]]:
+ return hypo[2]
+
+
+def _get_hypo_score(hypo: Hypothesis) -> float:
+ return hypo[3]
+
+
+def _get_hypo_key(hypo: Hypothesis) -> str:
+ return str(hypo[0])
+
+
+def _batch_state(hypos: List[Hypothesis]) -> List[List[torch.Tensor]]:
+ states: List[List[torch.Tensor]] = []
+ for i in range(len(_get_hypo_state(hypos[0]))):
+ batched_state_components: List[torch.Tensor] = []
+ for j in range(len(_get_hypo_state(hypos[0])[i])):
+ batched_state_components.append(torch.cat([_get_hypo_state(hypo)[i][j] for hypo in hypos]))
+ states.append(batched_state_components)
+ return states
+
+
+def _slice_state(states: List[List[torch.Tensor]], idx: int, device: torch.device) -> List[List[torch.Tensor]]:
+ idx_tensor = torch.tensor([idx], device=device)
+ return [[state.index_select(0, idx_tensor) for state in state_tuple] for state_tuple in states]
+
+
+def _default_hypo_sort_key(hypo: Hypothesis) -> float:
+ return _get_hypo_score(hypo) / (len(_get_hypo_tokens(hypo)) + 1)
+
+
+def _compute_updated_scores(
+ hypos: List[Hypothesis],
+ next_token_probs: torch.Tensor,
+ beam_width: int,
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
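+    # Add each hypothesis score to the log-probabilities of all non-blank tokens, flatten the
+    # (hypothesis, token) grid and take the top ``beam_width`` entries; each flat index is then
+    # decomposed into a hypothesis index (integer division) and a token index (modulo).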
+ hypo_scores = torch.tensor([_get_hypo_score(h) for h in hypos]).unsqueeze(1)
+ nonblank_scores = hypo_scores + next_token_probs[:, :-1] # [beam_width, num_tokens - 1]
+ nonblank_nbest_scores, nonblank_nbest_idx = nonblank_scores.reshape(-1).topk(beam_width)
+ nonblank_nbest_hypo_idx = nonblank_nbest_idx.div(nonblank_scores.shape[1], rounding_mode="trunc")
+ nonblank_nbest_token = nonblank_nbest_idx % nonblank_scores.shape[1]
+ return nonblank_nbest_scores, nonblank_nbest_hypo_idx, nonblank_nbest_token
+
+
+def _remove_hypo(hypo: Hypothesis, hypo_list: List[Hypothesis]) -> None:
+ for i, elem in enumerate(hypo_list):
+ if _get_hypo_key(hypo) == _get_hypo_key(elem):
+ del hypo_list[i]
+ break
+
+
+class RNNTBeamSearch(torch.nn.Module):
+ r"""Beam search decoder for RNN-T model.
+
+ See Also:
+ * :class:`torchaudio.pipelines.RNNTBundle`: ASR pipeline with pretrained model.
+
+ Args:
+ model (RNNT): RNN-T model to use.
+ blank (int): index of blank token in vocabulary.
+ temperature (float, optional): temperature to apply to joint network output.
+ Larger values yield more uniform samples. (Default: 1.0)
+ hypo_sort_key (Callable[[Hypothesis], float] or None, optional): callable that computes a score
+ for a given hypothesis to rank hypotheses by. If ``None``, defaults to callable that returns
+ hypothesis score normalized by token sequence length. (Default: None)
+ step_max_tokens (int, optional): maximum number of tokens to emit per input time step. (Default: 100)
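+
+    Example (illustrative sketch; the model, blank index and input sizes are arbitrary):
+        >>> from torchaudio.models import emformer_rnnt_base
+        >>> model = emformer_rnnt_base(num_symbols=4097)
+        >>> decoder = RNNTBeamSearch(model, blank=4096)
+        >>> input = torch.rand(100, 80)  # (T, D): 96 utterance frames + 4 right-context frames
+        >>> length = torch.tensor(96)
+        >>> hypotheses = decoder(input, length, beam_width=5)
+        >>> best_tokens = hypotheses[0][0]  # token sequence of the highest-ranked hypothesis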
+ """
+
+ def __init__(
+ self,
+ model: RNNT,
+ blank: int,
+ temperature: float = 1.0,
+ hypo_sort_key: Optional[Callable[[Hypothesis], float]] = None,
+ step_max_tokens: int = 100,
+ ) -> None:
+ super().__init__()
+ self.model = model
+ self.blank = blank
+ self.temperature = temperature
+
+ if hypo_sort_key is None:
+ self.hypo_sort_key = _default_hypo_sort_key
+ else:
+ self.hypo_sort_key = hypo_sort_key
+
+ self.step_max_tokens = step_max_tokens
+
+ def _init_b_hypos(self, device: torch.device) -> List[Hypothesis]:
+ token = self.blank
+ state = None
+
+ one_tensor = torch.tensor([1], device=device)
+ pred_out, _, pred_state = self.model.predict(torch.tensor([[token]], device=device), one_tensor, state)
+ init_hypo = (
+ [token],
+ pred_out[0].detach(),
+ pred_state,
+ 0.0,
+ )
+ return [init_hypo]
+
+ def _gen_next_token_probs(
+ self, enc_out: torch.Tensor, hypos: List[Hypothesis], device: torch.device
+ ) -> torch.Tensor:
+ one_tensor = torch.tensor([1], device=device)
+ predictor_out = torch.stack([_get_hypo_predictor_out(h) for h in hypos], dim=0)
+ joined_out, _, _ = self.model.join(
+ enc_out,
+ one_tensor,
+ predictor_out,
+ torch.tensor([1] * len(hypos), device=device),
+ ) # [beam_width, 1, 1, num_tokens]
+ joined_out = torch.nn.functional.log_softmax(joined_out / self.temperature, dim=3)
+ return joined_out[:, 0, 0]
+
+ def _gen_b_hypos(
+ self,
+ b_hypos: List[Hypothesis],
+ a_hypos: List[Hypothesis],
+ next_token_probs: torch.Tensor,
+ key_to_b_hypo: Dict[str, Hypothesis],
+ ) -> List[Hypothesis]:
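+        # Extend every hypothesis in A with the blank symbol. If an identical token sequence is
+        # already present in B, merge the two paths by log-add-exp of their scores instead of
+        # keeping a duplicate; B is then re-sorted by score (ascending).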
+ for i in range(len(a_hypos)):
+ h_a = a_hypos[i]
+ append_blank_score = _get_hypo_score(h_a) + next_token_probs[i, -1]
+ if _get_hypo_key(h_a) in key_to_b_hypo:
+ h_b = key_to_b_hypo[_get_hypo_key(h_a)]
+ _remove_hypo(h_b, b_hypos)
+ score = float(torch.tensor(_get_hypo_score(h_b)).logaddexp(append_blank_score))
+ else:
+ score = float(append_blank_score)
+ h_b = (
+ _get_hypo_tokens(h_a),
+ _get_hypo_predictor_out(h_a),
+ _get_hypo_state(h_a),
+ score,
+ )
+ b_hypos.append(h_b)
+ key_to_b_hypo[_get_hypo_key(h_b)] = h_b
+ _, sorted_idx = torch.tensor([_get_hypo_score(hypo) for hypo in b_hypos]).sort()
+ return [b_hypos[idx] for idx in sorted_idx]
+
+ def _gen_a_hypos(
+ self,
+ a_hypos: List[Hypothesis],
+ b_hypos: List[Hypothesis],
+ next_token_probs: torch.Tensor,
+ t: int,
+ beam_width: int,
+ device: torch.device,
+ ) -> List[Hypothesis]:
+ (
+ nonblank_nbest_scores,
+ nonblank_nbest_hypo_idx,
+ nonblank_nbest_token,
+ ) = _compute_updated_scores(a_hypos, next_token_probs, beam_width)
+
+ if len(b_hypos) < beam_width:
+ b_nbest_score = -float("inf")
+ else:
+ b_nbest_score = _get_hypo_score(b_hypos[-beam_width])
+
+ base_hypos: List[Hypothesis] = []
+ new_tokens: List[int] = []
+ new_scores: List[float] = []
+ for i in range(beam_width):
+ score = float(nonblank_nbest_scores[i])
+ if score > b_nbest_score:
+ a_hypo_idx = int(nonblank_nbest_hypo_idx[i])
+ base_hypos.append(a_hypos[a_hypo_idx])
+ new_tokens.append(int(nonblank_nbest_token[i]))
+ new_scores.append(score)
+
+ if base_hypos:
+ new_hypos = self._gen_new_hypos(base_hypos, new_tokens, new_scores, t, device)
+ else:
+ new_hypos: List[Hypothesis] = []
+
+ return new_hypos
+
+ def _gen_new_hypos(
+ self,
+ base_hypos: List[Hypothesis],
+ tokens: List[int],
+ scores: List[float],
+ t: int,
+ device: torch.device,
+ ) -> List[Hypothesis]:
+ tgt_tokens = torch.tensor([[token] for token in tokens], device=device)
+ states = _batch_state(base_hypos)
+ pred_out, _, pred_states = self.model.predict(
+ tgt_tokens,
+ torch.tensor([1] * len(base_hypos), device=device),
+ states,
+ )
+ new_hypos: List[Hypothesis] = []
+ for i, h_a in enumerate(base_hypos):
+ new_tokens = _get_hypo_tokens(h_a) + [tokens[i]]
+ new_hypos.append((new_tokens, pred_out[i].detach(), _slice_state(pred_states, i, device), scores[i]))
+ return new_hypos
+
+ def _search(
+ self,
+ enc_out: torch.Tensor,
+ hypo: Optional[List[Hypothesis]],
+ beam_width: int,
+ ) -> List[Hypothesis]:
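+        # Time-synchronous RNN-T beam search: for each encoder frame t, ``a_hypos`` holds
+        # candidates that may still emit another non-blank token at this frame, while ``b_hypos``
+        # collects candidates that have emitted blank and are done with frame t.
+        # ``step_max_tokens`` caps the number of non-blank expansions per frame.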
+ n_time_steps = enc_out.shape[1]
+ device = enc_out.device
+
+ a_hypos: List[Hypothesis] = []
+ b_hypos = self._init_b_hypos(device) if hypo is None else hypo
+ for t in range(n_time_steps):
+ a_hypos = b_hypos
+ b_hypos = torch.jit.annotate(List[Hypothesis], [])
+ key_to_b_hypo: Dict[str, Hypothesis] = {}
+ symbols_current_t = 0
+
+ while a_hypos:
+ next_token_probs = self._gen_next_token_probs(enc_out[:, t : t + 1], a_hypos, device)
+ next_token_probs = next_token_probs.cpu()
+ b_hypos = self._gen_b_hypos(b_hypos, a_hypos, next_token_probs, key_to_b_hypo)
+
+ if symbols_current_t == self.step_max_tokens:
+ break
+
+ a_hypos = self._gen_a_hypos(
+ a_hypos,
+ b_hypos,
+ next_token_probs,
+ t,
+ beam_width,
+ device,
+ )
+ if a_hypos:
+ symbols_current_t += 1
+
+ _, sorted_idx = torch.tensor([self.hypo_sort_key(hyp) for hyp in b_hypos]).topk(beam_width)
+ b_hypos = [b_hypos[idx] for idx in sorted_idx]
+
+ return b_hypos
+
+ def forward(self, input: torch.Tensor, length: torch.Tensor, beam_width: int) -> List[Hypothesis]:
+ r"""Performs beam search for the given input sequence.
+
+ T: number of frames;
+ D: feature dimension of each frame.
+
+ Args:
+ input (torch.Tensor): sequence of input frames, with shape (T, D) or (1, T, D).
+ length (torch.Tensor): number of valid frames in input
+ sequence, with shape () or (1,).
+ beam_width (int): beam size to use during search.
+
+ Returns:
+ List[Hypothesis]: top-``beam_width`` hypotheses found by beam search.
+ """
+ if input.dim() != 2 and not (input.dim() == 3 and input.shape[0] == 1):
+ raise ValueError("input must be of shape (T, D) or (1, T, D)")
+ if input.dim() == 2:
+ input = input.unsqueeze(0)
+
+ if length.shape != () and length.shape != (1,):
+ raise ValueError("length must be of shape () or (1,)")
+ if length.dim() == 0:
+ length = length.unsqueeze(0)
+
+ enc_out, _ = self.model.transcribe(input, length)
+ return self._search(enc_out, None, beam_width)
+
+ @torch.jit.export
+ def infer(
+ self,
+ input: torch.Tensor,
+ length: torch.Tensor,
+ beam_width: int,
+ state: Optional[List[List[torch.Tensor]]] = None,
+ hypothesis: Optional[List[Hypothesis]] = None,
+ ) -> Tuple[List[Hypothesis], List[List[torch.Tensor]]]:
+ r"""Performs beam search for the given input sequence in streaming mode.
+
+ T: number of frames;
+ D: feature dimension of each frame.
+
+ Args:
+ input (torch.Tensor): sequence of input frames, with shape (T, D) or (1, T, D).
+ length (torch.Tensor): number of valid frames in input
+ sequence, with shape () or (1,).
+ beam_width (int): beam size to use during search.
+ state (List[List[torch.Tensor]] or None, optional): list of lists of tensors
+ representing transcription network internal state generated in preceding
+ invocation. (Default: ``None``)
+ hypothesis (List[Hypothesis] or None): hypotheses from preceding invocation to seed
+ search with. (Default: ``None``)
+
+ Returns:
+ (List[Hypothesis], List[List[torch.Tensor]]):
+ List[Hypothesis]
+ top-``beam_width`` hypotheses found by beam search.
+ List[List[torch.Tensor]]
+ list of lists of tensors representing transcription network
+ internal state generated in current invocation.
+ """
+ if input.dim() != 2 and not (input.dim() == 3 and input.shape[0] == 1):
+ raise ValueError("input must be of shape (T, D) or (1, T, D)")
+ if input.dim() == 2:
+ input = input.unsqueeze(0)
+
+ if length.shape != () and length.shape != (1,):
+ raise ValueError("length must be of shape () or (1,)")
+ if length.dim() == 0:
+ length = length.unsqueeze(0)
+
+ enc_out, _, state = self.model.transcribe_streaming(input, length, state)
+ return self._search(enc_out, hypothesis, beam_width), state
diff --git a/MLPY/Lib/site-packages/torchaudio/models/squim/__init__.py b/MLPY/Lib/site-packages/torchaudio/models/squim/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d52102f153c973cebd9215f27481a9ec1b415139
--- /dev/null
+++ b/MLPY/Lib/site-packages/torchaudio/models/squim/__init__.py
@@ -0,0 +1,11 @@
+from .objective import squim_objective_base, squim_objective_model, SquimObjective
+from .subjective import squim_subjective_base, squim_subjective_model, SquimSubjective
+
+__all__ = [
+ "squim_objective_base",
+ "squim_objective_model",
+ "squim_subjective_base",
+ "squim_subjective_model",
+ "SquimObjective",
+ "SquimSubjective",
+]
diff --git a/MLPY/Lib/site-packages/torchaudio/models/squim/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torchaudio/models/squim/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fa892cf12ffe33a7598bc90621bab76341110022
Binary files /dev/null and b/MLPY/Lib/site-packages/torchaudio/models/squim/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torchaudio/models/squim/__pycache__/objective.cpython-39.pyc b/MLPY/Lib/site-packages/torchaudio/models/squim/__pycache__/objective.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1dbdc4448d1cd94385a510c2cd87ac21a5e0f8ba
Binary files /dev/null and b/MLPY/Lib/site-packages/torchaudio/models/squim/__pycache__/objective.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torchaudio/models/squim/__pycache__/subjective.cpython-39.pyc b/MLPY/Lib/site-packages/torchaudio/models/squim/__pycache__/subjective.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7128d74b4ac9e09bff4f5925d0eb46c12f5ff814
Binary files /dev/null and b/MLPY/Lib/site-packages/torchaudio/models/squim/__pycache__/subjective.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torchaudio/models/squim/objective.py b/MLPY/Lib/site-packages/torchaudio/models/squim/objective.py
new file mode 100644
index 0000000000000000000000000000000000000000..83155e7f3fb8c1cc5592fc8d2ed75fe9e03cdb28
--- /dev/null
+++ b/MLPY/Lib/site-packages/torchaudio/models/squim/objective.py
@@ -0,0 +1,326 @@
+import math
+from typing import List, Optional, Tuple
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def transform_wb_pesq_range(x: float) -> float:
+ """The metric defined by ITU-T P.862 is often called 'PESQ score', which is defined
+ for narrow-band signals and has a value range of [-0.5, 4.5] exactly. Here, we use the metric
+ defined by ITU-T P.862.2, commonly known as 'wide-band PESQ' and will be referred to as "PESQ score".
+
+ Args:
+ x (float): Narrow-band PESQ score.
+
+ Returns:
+ (float): Wide-band PESQ score.
+ """
+ return 0.999 + (4.999 - 0.999) / (1 + math.exp(-1.3669 * x + 3.8224))
+
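+# A quick illustration of the mapping above (values rounded, for orientation only):
+#
+#     >>> round(transform_wb_pesq_range(4.5), 2)   # narrow-band maximum
+#     4.64
+#     >>> round(transform_wb_pesq_range(-0.5), 2)  # narrow-band minimum
+#     1.04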
+
+PESQRange: Tuple[float, float] = (
+ 1.0, # P.862.2 uses a different input filter than P.862, and the lower bound of
+ # the raw score is not -0.5 anymore. It's hard to figure out the true lower bound.
+ # We are using 1.0 as a reasonable approximation.
+ transform_wb_pesq_range(4.5),
+)
+
+
+class RangeSigmoid(nn.Module):
+ def __init__(self, val_range: Tuple[float, float] = (0.0, 1.0)) -> None:
+ super(RangeSigmoid, self).__init__()
+ assert isinstance(val_range, tuple) and len(val_range) == 2
+ self.val_range: Tuple[float, float] = val_range
+ self.sigmoid: nn.modules.Module = nn.Sigmoid()
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ out = self.sigmoid(x) * (self.val_range[1] - self.val_range[0]) + self.val_range[0]
+ return out
+
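+# ``RangeSigmoid`` rescales a standard sigmoid into ``val_range``; e.g. with
+# ``val_range=PESQRange`` the PESQ branch output stays inside the wide-band PESQ range.
+# A tiny sketch (illustrative):
+#
+#     >>> RangeSigmoid((0.0, 2.0))(torch.zeros(1))  # sigmoid(0) == 0.5 -> midpoint of the range
+#     tensor([1.])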
+
+class Encoder(nn.Module):
+ """Encoder module that transform 1D waveform to 2D representations.
+
+ Args:
+ feat_dim (int, optional): The feature dimension after Encoder module. (Default: 512)
+ win_len (int, optional): kernel size in the Conv1D layer. (Default: 32)
+ """
+
+ def __init__(self, feat_dim: int = 512, win_len: int = 32) -> None:
+ super(Encoder, self).__init__()
+
+ self.conv1d = nn.Conv1d(1, feat_dim, win_len, stride=win_len // 2, bias=False)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """Apply waveforms to convolutional layer and ReLU layer.
+
+ Args:
+ x (torch.Tensor): Input waveforms. Tensor with dimensions `(batch, time)`.
+
+ Returns:
+ (torch.Tensor): Feature Tensor with dimensions `(batch, channel, frame)`.
+ """
+ out = x.unsqueeze(dim=1)
+ out = F.relu(self.conv1d(out))
+ return out
+
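+# A shape sketch for the encoder (illustrative): the Conv1d uses stride ``win_len // 2``,
+# so with the defaults a one-second 16 kHz waveform becomes (16000 - 32) // 16 + 1 = 999 frames.
+#
+#     >>> enc = Encoder(feat_dim=512, win_len=32)
+#     >>> enc(torch.randn(1, 16000)).shape
+#     torch.Size([1, 512, 999])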
+
+class SingleRNN(nn.Module):
+ def __init__(self, rnn_type: str, input_size: int, hidden_size: int, dropout: float = 0.0) -> None:
+ super(SingleRNN, self).__init__()
+
+ self.rnn_type = rnn_type
+ self.input_size = input_size
+ self.hidden_size = hidden_size
+
+ self.rnn: nn.modules.Module = getattr(nn, rnn_type)(
+ input_size,
+ hidden_size,
+ 1,
+ dropout=dropout,
+ batch_first=True,
+ bidirectional=True,
+ )
+
+ self.proj = nn.Linear(hidden_size * 2, input_size)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ # input shape: batch, seq, dim
+ out, _ = self.rnn(x)
+ out = self.proj(out)
+ return out
+
+
+class DPRNN(nn.Module):
+ """*Dual-path recurrent neural networks (DPRNN)* :cite:`luo2020dual`.
+
+ Args:
+ feat_dim (int, optional): The feature dimension after Encoder module. (Default: 64)
+ hidden_dim (int, optional): Hidden dimension in the RNN layer of DPRNN. (Default: 128)
+ num_blocks (int, optional): Number of DPRNN layers. (Default: 6)
+ rnn_type (str, optional): Type of RNN in DPRNN. Valid options are ["RNN", "LSTM", "GRU"]. (Default: "LSTM")
+ d_model (int, optional): The number of expected features in the input. (Default: 256)
+ chunk_size (int, optional): Chunk size of input for DPRNN. (Default: 100)
+ chunk_stride (int, optional): Stride of chunk input for DPRNN. (Default: 50)
+ """
+
+ def __init__(
+ self,
+ feat_dim: int = 64,
+ hidden_dim: int = 128,
+ num_blocks: int = 6,
+ rnn_type: str = "LSTM",
+ d_model: int = 256,
+ chunk_size: int = 100,
+ chunk_stride: int = 50,
+ ) -> None:
+ super(DPRNN, self).__init__()
+
+ self.num_blocks = num_blocks
+
+ self.row_rnn = nn.ModuleList([])
+ self.col_rnn = nn.ModuleList([])
+ self.row_norm = nn.ModuleList([])
+ self.col_norm = nn.ModuleList([])
+ for _ in range(num_blocks):
+ self.row_rnn.append(SingleRNN(rnn_type, feat_dim, hidden_dim))
+ self.col_rnn.append(SingleRNN(rnn_type, feat_dim, hidden_dim))
+ self.row_norm.append(nn.GroupNorm(1, feat_dim, eps=1e-8))
+ self.col_norm.append(nn.GroupNorm(1, feat_dim, eps=1e-8))
+ self.conv = nn.Sequential(
+ nn.Conv2d(feat_dim, d_model, 1),
+ nn.PReLU(),
+ )
+ self.chunk_size = chunk_size
+ self.chunk_stride = chunk_stride
+
+ def pad_chunk(self, x: torch.Tensor) -> Tuple[torch.Tensor, int]:
+ # input shape: (B, N, T)
+ seq_len = x.shape[-1]
+
+ rest = self.chunk_size - (self.chunk_stride + seq_len % self.chunk_size) % self.chunk_size
+ out = F.pad(x, [self.chunk_stride, rest + self.chunk_stride])
+
+ return out, rest
+
+ def chunking(self, x: torch.Tensor) -> Tuple[torch.Tensor, int]:
+ out, rest = self.pad_chunk(x)
+ batch_size, feat_dim, seq_len = out.shape
+
+ segments1 = out[:, :, : -self.chunk_stride].contiguous().view(batch_size, feat_dim, -1, self.chunk_size)
+ segments2 = out[:, :, self.chunk_stride :].contiguous().view(batch_size, feat_dim, -1, self.chunk_size)
+ out = torch.cat([segments1, segments2], dim=3)
+ out = out.view(batch_size, feat_dim, -1, self.chunk_size).transpose(2, 3).contiguous()
+
+ return out, rest
+
+ def merging(self, x: torch.Tensor, rest: int) -> torch.Tensor:
+ batch_size, dim, _, _ = x.shape
+ out = x.transpose(2, 3).contiguous().view(batch_size, dim, -1, self.chunk_size * 2)
+ out1 = out[:, :, :, : self.chunk_size].contiguous().view(batch_size, dim, -1)[:, :, self.chunk_stride :]
+ out2 = out[:, :, :, self.chunk_size :].contiguous().view(batch_size, dim, -1)[:, :, : -self.chunk_stride]
+ out = out1 + out2
+ if rest > 0:
+ out = out[:, :, :-rest]
+ out = out.contiguous()
+ return out
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x, rest = self.chunking(x)
+ batch_size, _, dim1, dim2 = x.shape
+ out = x
+ for row_rnn, row_norm, col_rnn, col_norm in zip(self.row_rnn, self.row_norm, self.col_rnn, self.col_norm):
+ row_in = out.permute(0, 3, 2, 1).contiguous().view(batch_size * dim2, dim1, -1).contiguous()
+ row_out = row_rnn(row_in)
+ row_out = row_out.view(batch_size, dim2, dim1, -1).permute(0, 3, 2, 1).contiguous()
+ row_out = row_norm(row_out)
+ out = out + row_out
+
+ col_in = out.permute(0, 2, 3, 1).contiguous().view(batch_size * dim1, dim2, -1).contiguous()
+ col_out = col_rnn(col_in)
+ col_out = col_out.view(batch_size, dim1, dim2, -1).permute(0, 3, 1, 2).contiguous()
+ col_out = col_norm(col_out)
+ out = out + col_out
+ out = self.conv(out)
+ out = self.merging(out, rest)
+ out = out.transpose(1, 2).contiguous()
+ return out
+
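+# Dual-path processing, as implemented above: ``chunking`` folds the (padded) sequence
+# into overlapping chunks, each block applies an intra-chunk ("row") bidirectional RNN
+# followed by an inter-chunk ("column") RNN with residual connections, and ``merging``
+# overlap-adds the chunks back into a single sequence. The final output has shape
+# (batch, frames, d_model) after the last transpose.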
+
+class AutoPool(nn.Module):
+ def __init__(self, pool_dim: int = 1) -> None:
+ super(AutoPool, self).__init__()
+ self.pool_dim: int = pool_dim
+ self.softmax: nn.modules.Module = nn.Softmax(dim=pool_dim)
+ self.register_parameter("alpha", nn.Parameter(torch.ones(1)))
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ weight = self.softmax(torch.mul(x, self.alpha))
+ out = torch.sum(torch.mul(x, weight), dim=self.pool_dim)
+ return out
+
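+# ``AutoPool`` is an adaptive pooling layer: the learned scalar ``alpha`` scales the
+# inputs before the softmax weighting, so the layer can move between roughly mean-like
+# pooling (small ``alpha``) and max-like pooling (large ``alpha``) along ``pool_dim``.
+# For example, a (batch, time, d_model) input is pooled to (batch, d_model).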
+
+class SquimObjective(nn.Module):
+ """Speech Quality and Intelligibility Measures (SQUIM) model that predicts **objective** metric scores
+ for speech enhancement (e.g., STOI, PESQ, and SI-SDR).
+
+ Args:
+ encoder (torch.nn.Module): Encoder module to transform 1D waveform to 2D feature representation.
+ dprnn (torch.nn.Module): DPRNN module to model sequential feature.
+ branches (torch.nn.ModuleList): Transformer branches, each of which estimates one objective metric score.
+ """
+
+ def __init__(
+ self,
+ encoder: nn.Module,
+ dprnn: nn.Module,
+ branches: nn.ModuleList,
+ ):
+ super(SquimObjective, self).__init__()
+ self.encoder = encoder
+ self.dprnn = dprnn
+ self.branches = branches
+
+ def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
+ """
+ Args:
+ x (torch.Tensor): Input waveforms. Tensor with dimensions `(batch, time)`.
+
+ Returns:
+ List[torch.Tensor]: List of score Tensors. Each Tensor has dimension `(batch,)`.
+ """
+ if x.ndim != 2:
+ raise ValueError(f"The input must be a 2D Tensor. Found dimension {x.ndim}.")
+ x = x / (torch.mean(x**2, dim=1, keepdim=True) ** 0.5 * 20)
+ out = self.encoder(x)
+ out = self.dprnn(out)
+ scores = []
+ for branch in self.branches:
+ scores.append(branch(out).squeeze(dim=1))
+ return scores
+
+
+def _create_branch(d_model: int, nhead: int, metric: str) -> nn.modules.Module:
+ """Create branch module after DPRNN model for predicting metric score.
+
+ Args:
+ d_model (int): The number of expected features in the input.
+ nhead (int): Number of heads in the multi-head attention model.
+ metric (str): The metric name to predict.
+
+ Returns:
+ (nn.Module): Returned module to predict corresponding metric score.
+ """
+ layer1 = nn.TransformerEncoderLayer(d_model, nhead, d_model * 4, dropout=0.0, batch_first=True)
+ layer2 = AutoPool()
+ if metric == "stoi":
+ layer3 = nn.Sequential(
+ nn.Linear(d_model, d_model),
+ nn.PReLU(),
+ nn.Linear(d_model, 1),
+ RangeSigmoid(),
+ )
+ elif metric == "pesq":
+ layer3 = nn.Sequential(
+ nn.Linear(d_model, d_model),
+ nn.PReLU(),
+ nn.Linear(d_model, 1),
+ RangeSigmoid(val_range=PESQRange),
+ )
+ else:
+ layer3: nn.modules.Module = nn.Sequential(nn.Linear(d_model, d_model), nn.PReLU(), nn.Linear(d_model, 1))
+ return nn.Sequential(layer1, layer2, layer3)
+
+
+def squim_objective_model(
+ feat_dim: int,
+ win_len: int,
+ d_model: int,
+ nhead: int,
+ hidden_dim: int,
+ num_blocks: int,
+ rnn_type: str,
+ chunk_size: int,
+ chunk_stride: Optional[int] = None,
+) -> SquimObjective:
+ """Build a custome :class:`torchaudio.prototype.models.SquimObjective` model.
+
+ Args:
+ feat_dim (int): The feature dimension after the Encoder module.
+ win_len (int): Kernel size in the Encoder module.
+ d_model (int): The number of expected features in the input.
+ nhead (int): Number of heads in the multi-head attention model.
+ hidden_dim (int): Hidden dimension in the RNN layer of DPRNN.
+ num_blocks (int): Number of DPRNN layers.
+ rnn_type (str): Type of RNN in DPRNN. Valid options are ["RNN", "LSTM", "GRU"].
+ chunk_size (int): Chunk size of input for DPRNN.
+ chunk_stride (int or None, optional): Stride of chunk input for DPRNN.
+ """
+ if chunk_stride is None:
+ chunk_stride = chunk_size // 2
+ encoder = Encoder(feat_dim, win_len)
+ dprnn = DPRNN(feat_dim, hidden_dim, num_blocks, rnn_type, d_model, chunk_size, chunk_stride)
+ branches = nn.ModuleList(
+ [
+ _create_branch(d_model, nhead, "stoi"),
+ _create_branch(d_model, nhead, "pesq"),
+ _create_branch(d_model, nhead, "sisdr"),
+ ]
+ )
+ return SquimObjective(encoder, dprnn, branches)
+
+
+def squim_objective_base() -> SquimObjective:
+ """Build :class:`torchaudio.prototype.models.SquimObjective` model with default arguments."""
+ return squim_objective_model(
+ feat_dim=256,
+ win_len=64,
+ d_model=256,
+ nhead=4,
+ hidden_dim=256,
+ num_blocks=2,
+ rnn_type="LSTM",
+ chunk_size=71,
+ )
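+
+
+# A minimal end-to-end sketch (randomly initialized weights, shapes illustrative;
+# trained weights would normally come from a released checkpoint or bundle):
+#
+#     >>> model = squim_objective_base().eval()
+#     >>> waveform = torch.randn(1, 16000)                       # (batch, time), 16 kHz assumed
+#     >>> with torch.no_grad():
+#     ...     stoi_hyp, pesq_hyp, si_sdr_hyp = model(waveform)   # each of shape (batch,)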
diff --git a/MLPY/Lib/site-packages/torchaudio/models/squim/subjective.py b/MLPY/Lib/site-packages/torchaudio/models/squim/subjective.py
new file mode 100644
index 0000000000000000000000000000000000000000..c3cc8ba3fc60e73351049ec1317ef3ddb050f70e
--- /dev/null
+++ b/MLPY/Lib/site-packages/torchaudio/models/squim/subjective.py
@@ -0,0 +1,150 @@
+from typing import Tuple
+
+import torch
+import torch.nn as nn
+import torchaudio
+
+
+class AttPool(nn.Module):
+ """Attention-Pooling module that estimates the attention score.
+
+ Args:
+ input_dim (int): Input feature dimension.
+ att_dim (int): Attention Tensor dimension.
+ """
+
+ def __init__(self, input_dim: int, att_dim: int):
+ super(AttPool, self).__init__()
+
+ self.linear1 = nn.Linear(input_dim, 1)
+ self.linear2 = nn.Linear(input_dim, att_dim)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """Apply attention and pooling.
+
+ Args:
+ x (torch.Tensor): Input Tensor with dimensions `(batch, time, feature_dim)`.
+
+ Returns:
+ (torch.Tensor): Attention score with dimensions `(batch, att_dim)`.
+ """
+
+ att = self.linear1(x) # (batch, time, 1)
+ att = att.transpose(2, 1) # (batch, 1, time)
+ att = nn.functional.softmax(att, dim=2)
+ x = torch.matmul(att, x).squeeze(1) # (batch, input_dim)
+ x = self.linear2(x) # (batch, att_dim)
+ return x
+
+
+class Predictor(nn.Module):
+ """Prediction module that apply pooling and attention, then predict subjective metric scores.
+
+ Args:
+ input_dim (int): Input feature dimension.
+ att_dim (int): Attention Tensor dimension.
+ """
+
+ def __init__(self, input_dim: int, att_dim: int):
+ super(Predictor, self).__init__()
+ self.att_pool_layer = AttPool(input_dim, att_dim)
+ self.att_dim = att_dim
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """Predict subjective evaluation metric score.
+
+ Args:
+ x (torch.Tensor): Input Tensor with dimensions `(batch, time, feature_dim)`.
+
+ Returns:
+ (torch.Tensor): Subjective metric score. Tensor with dimensions `(batch,)`.
+ """
+ x = self.att_pool_layer(x)
+ x = nn.functional.softmax(x, dim=1)
+ B = torch.linspace(0, 4, steps=self.att_dim, device=x.device)
+ x = (x * B).sum(dim=1)
+ return x
+
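+# The predictor treats the ``att_dim`` pooled outputs as logits over evenly spaced
+# score-difference bins: after the softmax, the weighted sum with ``linspace(0, 4)`` is the
+# expected difference from the clean reference. ``SquimSubjective.forward`` below returns
+# ``5 - difference`` so the final score falls in a MOS-like [1, 5] range.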
+
+class SquimSubjective(nn.Module):
+ """Speech Quality and Intelligibility Measures (SQUIM) model that predicts **subjective** metric scores
+ for speech enhancement (e.g., Mean Opinion Score (MOS)). The model is adapted from *NORESQA-MOS*
+ :cite:`manocha2022speech`, which predicts MOS scores given the input speech and a non-matching reference.
+
+ Args:
+ ssl_model (torch.nn.Module): The self-supervised learning model for feature extraction.
+ projector (torch.nn.Module): Projection layer that projects SSL feature to a lower dimension.
+ predictor (torch.nn.Module): Predict the subjective scores.
+ """
+
+ def __init__(self, ssl_model: nn.Module, projector: nn.Module, predictor: nn.Module):
+ super(SquimSubjective, self).__init__()
+ self.ssl_model = ssl_model
+ self.projector = projector
+ self.predictor = predictor
+
+ def _align_shapes(self, waveform: torch.Tensor, reference: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Cut or pad the reference Tensor to make it aligned with waveform Tensor.
+
+ Args:
+ waveform (torch.Tensor): Input waveform for evaluation. Tensor with dimensions `(batch, time)`.
+ reference (torch.Tensor): Non-matching clean reference. Tensor with dimensions `(batch, time_ref)`.
+
+ Returns:
+ (torch.Tensor, torch.Tensor): The aligned waveform and reference Tensors
+ with same dimensions `(batch, time)`.
+ """
+ T_waveform = waveform.shape[-1]
+ T_reference = reference.shape[-1]
+ if T_reference < T_waveform:
+ num_padding = T_waveform // T_reference + 1
+ reference = torch.cat([reference for _ in range(num_padding)], dim=1)
+ return waveform, reference[:, :T_waveform]
+
+ def forward(self, waveform: torch.Tensor, reference: torch.Tensor):
+ """Predict subjective evaluation metric score.
+
+ Args:
+ waveform (torch.Tensor): Input waveform for evaluation. Tensor with dimensions `(batch, time)`.
+ reference (torch.Tensor): Non-matching clean reference. Tensor with dimensions `(batch, time_ref)`.
+
+ Returns:
+ (torch.Tensor): Subjective metric score. Tensor with dimensions `(batch,)`.
+ """
+ waveform, reference = self._align_shapes(waveform, reference)
+ waveform = self.projector(self.ssl_model.extract_features(waveform)[0][-1])
+ reference = self.projector(self.ssl_model.extract_features(reference)[0][-1])
+ concat = torch.cat((reference, waveform), dim=2)
+ score_diff = self.predictor(concat) # Score difference compared to the reference
+ return 5 - score_diff
+
+
+def squim_subjective_model(
+ ssl_type: str,
+ feat_dim: int,
+ proj_dim: int,
+ att_dim: int,
+) -> SquimSubjective:
+ """Build a custome :class:`torchaudio.prototype.models.SquimSubjective` model.
+
+ Args:
+ ssl_type (str): Type of self-supervised learning (SSL) models.
+ Must be one of ["wav2vec2_base", "wav2vec2_large"].
+ feat_dim (int): Feature dimension of the SSL feature representation.
+ proj_dim (int): Output dimension of projection layer.
+ att_dim (int): Dimension of attention scores.
+ """
+ ssl_model = getattr(torchaudio.models, ssl_type)()
+ projector = nn.Linear(feat_dim, proj_dim)
+ predictor = Predictor(proj_dim * 2, att_dim)
+ return SquimSubjective(ssl_model, projector, predictor)
+
+
+def squim_subjective_base() -> SquimSubjective:
+ """Build :class:`torchaudio.prototype.models.SquimSubjective` model with default arguments."""
+ return squim_subjective_model(
+ ssl_type="wav2vec2_base",
+ feat_dim=768,
+ proj_dim=32,
+ att_dim=5,
+ )
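+
+
+# A minimal usage sketch (randomly initialized weights; the reference can be any clean
+# speech clip and does not need to match the content of the test utterance):
+#
+#     >>> model = squim_subjective_base().eval()
+#     >>> test_wav = torch.randn(1, 16000)          # (batch, time)
+#     >>> reference = torch.randn(1, 16000)         # (batch, time_ref)
+#     >>> with torch.no_grad():
+#     ...     mos_hyp = model(test_wav, reference)  # shape (batch,), MOS-like scale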
diff --git a/MLPY/Lib/site-packages/torchaudio/models/tacotron2.py b/MLPY/Lib/site-packages/torchaudio/models/tacotron2.py
new file mode 100644
index 0000000000000000000000000000000000000000..ad4f9b21a66e69c3e0fdb8bb80e80cbcbe2ef429
--- /dev/null
+++ b/MLPY/Lib/site-packages/torchaudio/models/tacotron2.py
@@ -0,0 +1,1046 @@
+# *****************************************************************************
+# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in the
+# documentation and/or other materials provided with the distribution.
+# * Neither the name of the NVIDIA CORPORATION nor the
+# names of its contributors may be used to endorse or promote products
+# derived from this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
+# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
+# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
+# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+#
+# *****************************************************************************
+
+import warnings
+from typing import List, Optional, Tuple, Union
+
+import torch
+from torch import nn, Tensor
+from torch.nn import functional as F
+
+
+__all__ = [
+ "Tacotron2",
+]
+
+
+def _get_linear_layer(in_dim: int, out_dim: int, bias: bool = True, w_init_gain: str = "linear") -> torch.nn.Linear:
+ r"""Linear layer with xavier uniform initialization.
+
+ Args:
+ in_dim (int): Size of each input sample.
+ out_dim (int): Size of each output sample.
+ bias (bool, optional): If set to ``False``, the layer will not learn an additive bias. (Default: ``True``)
+ w_init_gain (str, optional): Parameter passed to ``torch.nn.init.calculate_gain``
+ for setting the gain parameter of ``xavier_uniform_``. (Default: ``linear``)
+
+ Returns:
+ (torch.nn.Linear): The corresponding linear layer.
+ """
+ linear = torch.nn.Linear(in_dim, out_dim, bias=bias)
+ torch.nn.init.xavier_uniform_(linear.weight, gain=torch.nn.init.calculate_gain(w_init_gain))
+ return linear
+
+
+def _get_conv1d_layer(
+ in_channels: int,
+ out_channels: int,
+ kernel_size: int = 1,
+ stride: int = 1,
+ padding: Optional[Union[str, int, Tuple[int]]] = None,
+ dilation: int = 1,
+ bias: bool = True,
+ w_init_gain: str = "linear",
+) -> torch.nn.Conv1d:
+ r"""1D convolution with xavier uniform initialization.
+
+ Args:
+ in_channels (int): Number of channels in the input image.
+ out_channels (int): Number of channels produced by the convolution.
+ kernel_size (int, optional): Size of the convolving kernel. (Default: ``1``)
+ stride (int, optional): Stride of the convolution. (Default: ``1``)
+ padding (str, int or tuple, optional): Padding added to both sides of the input.
+ (Default: ``dilation * (kernel_size - 1) / 2``)
+ dilation (int, optional): Spacing between kernel elements. (Default: ``1``)
+ bias (bool, optional): If ``True``, adds a learnable bias to the output. (Default: ``True``)
+ w_init_gain (str, optional): Parameter passed to ``torch.nn.init.calculate_gain``
+ for setting the gain parameter of ``xavier_uniform_``. (Default: ``linear``)
+
+ Returns:
+ (torch.nn.Conv1d): The corresponding Conv1D layer.
+ """
+ if padding is None:
+ if kernel_size % 2 != 1:
+ raise ValueError("kernel_size must be odd")
+ padding = int(dilation * (kernel_size - 1) / 2)
+
+ conv1d = torch.nn.Conv1d(
+ in_channels,
+ out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding,
+ dilation=dilation,
+ bias=bias,
+ )
+
+ torch.nn.init.xavier_uniform_(conv1d.weight, gain=torch.nn.init.calculate_gain(w_init_gain))
+
+ return conv1d
+
+
+def _get_mask_from_lengths(lengths: Tensor) -> Tensor:
+ r"""Returns a binary mask based on ``lengths``. The ``i``-th row and ``j``-th column of the mask
+ is ``1`` if ``j`` is smaller than ``i``-th element of ``lengths.
+
+ Args:
+ lengths (Tensor): The length of each element in the batch, with shape (n_batch, ).
+
+ Returns:
+ mask (Tensor): The binary mask, with shape (n_batch, max of ``lengths``).
+ """
+ max_len = torch.max(lengths).item()
+ ids = torch.arange(0, max_len, device=lengths.device, dtype=lengths.dtype)
+ mask = (ids < lengths.unsqueeze(1)).byte()
+ mask = torch.le(mask, 0)
+ return mask
+
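+# A small example of the padding mask (``True`` marks padded positions; output shown approximately):
+#
+#     >>> _get_mask_from_lengths(torch.tensor([1, 3]))
+#     tensor([[False,  True,  True],
+#             [False, False, False]])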
+
+class _LocationLayer(nn.Module):
+ r"""Location layer used in the Attention model.
+
+ Args:
+ attention_n_filter (int): Number of filters for attention model.
+ attention_kernel_size (int): Kernel size for attention model.
+ attention_hidden_dim (int): Dimension of attention hidden representation.
+ """
+
+ def __init__(
+ self,
+ attention_n_filter: int,
+ attention_kernel_size: int,
+ attention_hidden_dim: int,
+ ):
+ super().__init__()
+ padding = int((attention_kernel_size - 1) / 2)
+ self.location_conv = _get_conv1d_layer(
+ 2,
+ attention_n_filter,
+ kernel_size=attention_kernel_size,
+ padding=padding,
+ bias=False,
+ stride=1,
+ dilation=1,
+ )
+ self.location_dense = _get_linear_layer(
+ attention_n_filter, attention_hidden_dim, bias=False, w_init_gain="tanh"
+ )
+
+ def forward(self, attention_weights_cat: Tensor) -> Tensor:
+ r"""Location layer used in the Attention model.
+
+ Args:
+ attention_weights_cat (Tensor): Cumulative and previous attention weights
+ with shape (n_batch, 2, max of ``text_lengths``).
+
+ Returns:
+ processed_attention (Tensor): Cumulative and previous attention weights
+ with shape (n_batch, ``attention_hidden_dim``).
+ """
+ # (n_batch, attention_n_filter, text_lengths.max())
+ processed_attention = self.location_conv(attention_weights_cat)
+ processed_attention = processed_attention.transpose(1, 2)
+ # (n_batch, text_lengths.max(), attention_hidden_dim)
+ processed_attention = self.location_dense(processed_attention)
+ return processed_attention
+
+
+class _Attention(nn.Module):
+ r"""Locally sensitive attention model.
+
+ Args:
+ attention_rnn_dim (int): Number of hidden units for RNN.
+ encoder_embedding_dim (int): Number of embedding dimensions in the Encoder.
+ attention_hidden_dim (int): Dimension of attention hidden representation.
+ attention_location_n_filter (int): Number of filters for Attention model.
+ attention_location_kernel_size (int): Kernel size for Attention model.
+ """
+
+ def __init__(
+ self,
+ attention_rnn_dim: int,
+ encoder_embedding_dim: int,
+ attention_hidden_dim: int,
+ attention_location_n_filter: int,
+ attention_location_kernel_size: int,
+ ) -> None:
+ super().__init__()
+ self.query_layer = _get_linear_layer(attention_rnn_dim, attention_hidden_dim, bias=False, w_init_gain="tanh")
+ self.memory_layer = _get_linear_layer(
+ encoder_embedding_dim, attention_hidden_dim, bias=False, w_init_gain="tanh"
+ )
+ self.v = _get_linear_layer(attention_hidden_dim, 1, bias=False)
+ self.location_layer = _LocationLayer(
+ attention_location_n_filter,
+ attention_location_kernel_size,
+ attention_hidden_dim,
+ )
+ self.score_mask_value = -float("inf")
+
+ def _get_alignment_energies(self, query: Tensor, processed_memory: Tensor, attention_weights_cat: Tensor) -> Tensor:
+ r"""Get the alignment vector.
+
+ Args:
+ query (Tensor): Decoder output with shape (n_batch, n_mels * n_frames_per_step).
+ processed_memory (Tensor): Processed Encoder outputs
+ with shape (n_batch, max of ``text_lengths``, attention_hidden_dim).
+ attention_weights_cat (Tensor): Cumulative and previous attention weights
+ with shape (n_batch, 2, max of ``text_lengths``).
+
+ Returns:
+ alignment (Tensor): Attention weights with shape (n_batch, max of ``text_lengths``).
+ """
+
+ processed_query = self.query_layer(query.unsqueeze(1))
+ processed_attention_weights = self.location_layer(attention_weights_cat)
+ energies = self.v(torch.tanh(processed_query + processed_attention_weights + processed_memory))
+
+ alignment = energies.squeeze(2)
+ return alignment
+
+ def forward(
+ self,
+ attention_hidden_state: Tensor,
+ memory: Tensor,
+ processed_memory: Tensor,
+ attention_weights_cat: Tensor,
+ mask: Tensor,
+ ) -> Tuple[Tensor, Tensor]:
+ r"""Pass the input through the Attention model.
+
+ Args:
+ attention_hidden_state (Tensor): Attention rnn last output with shape (n_batch, ``attention_rnn_dim``).
+ memory (Tensor): Encoder outputs with shape (n_batch, max of ``text_lengths``, ``encoder_embedding_dim``).
+ processed_memory (Tensor): Processed Encoder outputs
+ with shape (n_batch, max of ``text_lengths``, ``attention_hidden_dim``).
+ attention_weights_cat (Tensor): Previous and cumulative attention weights
+ with shape (n_batch, current_num_frames * 2, max of ``text_lengths``).
+ mask (Tensor): Binary mask for padded data with shape (n_batch, current_num_frames).
+
+ Returns:
+ attention_context (Tensor): Context vector with shape (n_batch, ``encoder_embedding_dim``).
+ attention_weights (Tensor): Attention weights with shape (n_batch, max of ``text_lengths``).
+ """
+ alignment = self._get_alignment_energies(attention_hidden_state, processed_memory, attention_weights_cat)
+
+ alignment = alignment.masked_fill(mask, self.score_mask_value)
+
+ attention_weights = F.softmax(alignment, dim=1)
+ attention_context = torch.bmm(attention_weights.unsqueeze(1), memory)
+ attention_context = attention_context.squeeze(1)
+
+ return attention_context, attention_weights
+
+
+class _Prenet(nn.Module):
+ r"""Prenet Module. It is consists of ``len(output_size)`` linear layers.
+
+ Args:
+ in_dim (int): The size of each input sample.
+ out_sizes (list): The output dimension of each linear layer.
+ """
+
+ def __init__(self, in_dim: int, out_sizes: List[int]) -> None:
+ super().__init__()
+ in_sizes = [in_dim] + out_sizes[:-1]
+ self.layers = nn.ModuleList(
+ [_get_linear_layer(in_size, out_size, bias=False) for (in_size, out_size) in zip(in_sizes, out_sizes)]
+ )
+
+ def forward(self, x: Tensor) -> Tensor:
+ r"""Pass the input through Prenet.
+
+ Args:
+ x (Tensor): The input sequence to Prenet with shape (n_batch, in_dim).
+
+ Return:
+ x (Tensor): Tensor with shape (n_batch, ``out_sizes[-1]``)
+ """
+
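+ # Note: dropout is applied with ``training=True`` on purpose. As in the original
+ # Tacotron 2 recipe, prenet dropout stays active at inference time as well, which
+ # introduces variation in the generated spectrograms.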
+ for linear in self.layers:
+ x = F.dropout(F.relu(linear(x)), p=0.5, training=True)
+ return x
+
+
+class _Postnet(nn.Module):
+ r"""Postnet Module.
+
+ Args:
+ n_mels (int): Number of mel bins.
+ postnet_embedding_dim (int): Postnet embedding dimension.
+ postnet_kernel_size (int): Postnet kernel size.
+ postnet_n_convolution (int): Number of postnet convolutions.
+ """
+
+ def __init__(
+ self,
+ n_mels: int,
+ postnet_embedding_dim: int,
+ postnet_kernel_size: int,
+ postnet_n_convolution: int,
+ ):
+ super().__init__()
+ self.convolutions = nn.ModuleList()
+
+ for i in range(postnet_n_convolution):
+ in_channels = n_mels if i == 0 else postnet_embedding_dim
+ out_channels = n_mels if i == (postnet_n_convolution - 1) else postnet_embedding_dim
+ init_gain = "linear" if i == (postnet_n_convolution - 1) else "tanh"
+ num_features = n_mels if i == (postnet_n_convolution - 1) else postnet_embedding_dim
+ self.convolutions.append(
+ nn.Sequential(
+ _get_conv1d_layer(
+ in_channels,
+ out_channels,
+ kernel_size=postnet_kernel_size,
+ stride=1,
+ padding=int((postnet_kernel_size - 1) / 2),
+ dilation=1,
+ w_init_gain=init_gain,
+ ),
+ nn.BatchNorm1d(num_features),
+ )
+ )
+
+ self.n_convs = len(self.convolutions)
+
+ def forward(self, x: Tensor) -> Tensor:
+ r"""Pass the input through Postnet.
+
+ Args:
+ x (Tensor): The input sequence with shape (n_batch, ``n_mels``, max of ``mel_specgram_lengths``).
+
+ Return:
+ x (Tensor): Tensor with shape (n_batch, ``n_mels``, max of ``mel_specgram_lengths``).
+ """
+
+ for i, conv in enumerate(self.convolutions):
+ if i < self.n_convs - 1:
+ x = F.dropout(torch.tanh(conv(x)), 0.5, training=self.training)
+ else:
+ x = F.dropout(conv(x), 0.5, training=self.training)
+
+ return x
+
+
+class _Encoder(nn.Module):
+ r"""Encoder Module.
+
+ Args:
+ encoder_embedding_dim (int): Number of embedding dimensions in the encoder.
+ encoder_n_convolution (int): Number of convolution layers in the encoder.
+ encoder_kernel_size (int): The kernel size in the encoder.
+
+ Examples
+ >>> encoder = _Encoder(512, 3, 5)
+ >>> input = torch.rand(10, 512, 30)
+ >>> input_lengths = torch.full((10,), 30, dtype=torch.int64)
+ >>> output = encoder(input, input_lengths) # shape: (10, 30, 512)
+ """
+
+ def __init__(
+ self,
+ encoder_embedding_dim: int,
+ encoder_n_convolution: int,
+ encoder_kernel_size: int,
+ ) -> None:
+ super().__init__()
+
+ self.convolutions = nn.ModuleList()
+ for _ in range(encoder_n_convolution):
+ conv_layer = nn.Sequential(
+ _get_conv1d_layer(
+ encoder_embedding_dim,
+ encoder_embedding_dim,
+ kernel_size=encoder_kernel_size,
+ stride=1,
+ padding=int((encoder_kernel_size - 1) / 2),
+ dilation=1,
+ w_init_gain="relu",
+ ),
+ nn.BatchNorm1d(encoder_embedding_dim),
+ )
+ self.convolutions.append(conv_layer)
+
+ self.lstm = nn.LSTM(
+ encoder_embedding_dim,
+ int(encoder_embedding_dim / 2),
+ 1,
+ batch_first=True,
+ bidirectional=True,
+ )
+ self.lstm.flatten_parameters()
+
+ def forward(self, x: Tensor, input_lengths: Tensor) -> Tensor:
+ r"""Pass the input through the Encoder.
+
+ Args:
+ x (Tensor): The input sequences with shape (n_batch, encoder_embedding_dim, n_seq).
+ input_lengths (Tensor): The length of each input sequence with shape (n_batch, ).
+
+ Return:
+ x (Tensor): A tensor with shape (n_batch, n_seq, encoder_embedding_dim).
+ """
+
+ for conv in self.convolutions:
+ x = F.dropout(F.relu(conv(x)), 0.5, self.training)
+
+ x = x.transpose(1, 2)
+
+ input_lengths = input_lengths.cpu()
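+ # ``pack_padded_sequence`` is used with its default ``enforce_sorted=True``,
+ # so the batch is expected to be sorted by decreasing ``input_lengths``.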
+ x = nn.utils.rnn.pack_padded_sequence(x, input_lengths, batch_first=True)
+
+ outputs, _ = self.lstm(x)
+ outputs, _ = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True)
+
+ return outputs
+
+
+class _Decoder(nn.Module):
+ r"""Decoder with Attention model.
+
+ Args:
+ n_mels (int): number of mel bins
+ n_frames_per_step (int): number of frames processed per step, only 1 is supported
+ encoder_embedding_dim (int): the number of embedding dimensions in the encoder.
+ decoder_rnn_dim (int): number of units in decoder LSTM
+ decoder_max_step (int): maximum number of output mel spectrograms
+ decoder_dropout (float): dropout probability for decoder LSTM
+ decoder_early_stopping (bool): stop decoding when all samples are finished
+ attention_rnn_dim (int): number of units in attention LSTM
+ attention_hidden_dim (int): dimension of attention hidden representation
+ attention_location_n_filter (int): number of filters for attention model
+ attention_location_kernel_size (int): kernel size for attention model
+ attention_dropout (float): dropout probability for attention LSTM
+ prenet_dim (int): number of ReLU units in prenet layers
+ gate_threshold (float): probability threshold for stop token
+ """
+
+ def __init__(
+ self,
+ n_mels: int,
+ n_frames_per_step: int,
+ encoder_embedding_dim: int,
+ decoder_rnn_dim: int,
+ decoder_max_step: int,
+ decoder_dropout: float,
+ decoder_early_stopping: bool,
+ attention_rnn_dim: int,
+ attention_hidden_dim: int,
+ attention_location_n_filter: int,
+ attention_location_kernel_size: int,
+ attention_dropout: float,
+ prenet_dim: int,
+ gate_threshold: float,
+ ) -> None:
+
+ super().__init__()
+ self.n_mels = n_mels
+ self.n_frames_per_step = n_frames_per_step
+ self.encoder_embedding_dim = encoder_embedding_dim
+ self.attention_rnn_dim = attention_rnn_dim
+ self.decoder_rnn_dim = decoder_rnn_dim
+ self.prenet_dim = prenet_dim
+ self.decoder_max_step = decoder_max_step
+ self.gate_threshold = gate_threshold
+ self.attention_dropout = attention_dropout
+ self.decoder_dropout = decoder_dropout
+ self.decoder_early_stopping = decoder_early_stopping
+
+ self.prenet = _Prenet(n_mels * n_frames_per_step, [prenet_dim, prenet_dim])
+
+ self.attention_rnn = nn.LSTMCell(prenet_dim + encoder_embedding_dim, attention_rnn_dim)
+
+ self.attention_layer = _Attention(
+ attention_rnn_dim,
+ encoder_embedding_dim,
+ attention_hidden_dim,
+ attention_location_n_filter,
+ attention_location_kernel_size,
+ )
+
+ self.decoder_rnn = nn.LSTMCell(attention_rnn_dim + encoder_embedding_dim, decoder_rnn_dim, True)
+
+ self.linear_projection = _get_linear_layer(decoder_rnn_dim + encoder_embedding_dim, n_mels * n_frames_per_step)
+
+ self.gate_layer = _get_linear_layer(
+ decoder_rnn_dim + encoder_embedding_dim, 1, bias=True, w_init_gain="sigmoid"
+ )
+
+ def _get_initial_frame(self, memory: Tensor) -> Tensor:
+ r"""Gets all zeros frames to use as the first decoder input.
+
+ Args:
+ memory (Tensor): Encoder outputs with shape (n_batch, max of ``text_lengths``, ``encoder_embedding_dim``).
+
+ Returns:
+ decoder_input (Tensor): all-zeros frames with shape
+ (n_batch, ``n_mels * n_frames_per_step``).
+ """
+
+ n_batch = memory.size(0)
+ dtype = memory.dtype
+ device = memory.device
+ decoder_input = torch.zeros(n_batch, self.n_mels * self.n_frames_per_step, dtype=dtype, device=device)
+ return decoder_input
+
+ def _initialize_decoder_states(
+ self, memory: Tensor
+ ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
+ r"""Initializes attention rnn states, decoder rnn states, attention
+ weights, attention cumulative weights, attention context, stores memory
+ and stores processed memory.
+
+ Args:
+ memory (Tensor): Encoder outputs with shape (n_batch, max of ``text_lengths``, ``encoder_embedding_dim``).
+
+ Returns:
+ attention_hidden (Tensor): Hidden state of the attention LSTM with shape (n_batch, ``attention_rnn_dim``).
+ attention_cell (Tensor): Cell state of the attention LSTM with shape (n_batch, ``attention_rnn_dim``).
+ decoder_hidden (Tensor): Hidden state of the decoder LSTM with shape (n_batch, ``decoder_rnn_dim``).
+ decoder_cell (Tensor): Cell state of the decoder LSTM with shape (n_batch, ``decoder_rnn_dim``).
+ attention_weights (Tensor): Attention weights with shape (n_batch, max of ``text_lengths``).
+ attention_weights_cum (Tensor): Cumulated attention weights with shape (n_batch, max of ``text_lengths``).
+ attention_context (Tensor): Context vector with shape (n_batch, ``encoder_embedding_dim``).
+ processed_memory (Tensor): Processed encoder outputs
+ with shape (n_batch, max of ``text_lengths``, ``attention_hidden_dim``).
+ """
+ n_batch = memory.size(0)
+ max_time = memory.size(1)
+ dtype = memory.dtype
+ device = memory.device
+
+ attention_hidden = torch.zeros(n_batch, self.attention_rnn_dim, dtype=dtype, device=device)
+ attention_cell = torch.zeros(n_batch, self.attention_rnn_dim, dtype=dtype, device=device)
+
+ decoder_hidden = torch.zeros(n_batch, self.decoder_rnn_dim, dtype=dtype, device=device)
+ decoder_cell = torch.zeros(n_batch, self.decoder_rnn_dim, dtype=dtype, device=device)
+
+ attention_weights = torch.zeros(n_batch, max_time, dtype=dtype, device=device)
+ attention_weights_cum = torch.zeros(n_batch, max_time, dtype=dtype, device=device)
+ attention_context = torch.zeros(n_batch, self.encoder_embedding_dim, dtype=dtype, device=device)
+
+ processed_memory = self.attention_layer.memory_layer(memory)
+
+ return (
+ attention_hidden,
+ attention_cell,
+ decoder_hidden,
+ decoder_cell,
+ attention_weights,
+ attention_weights_cum,
+ attention_context,
+ processed_memory,
+ )
+
+ def _parse_decoder_inputs(self, decoder_inputs: Tensor) -> Tensor:
+ r"""Prepares decoder inputs.
+
+ Args:
+ decoder_inputs (Tensor): Inputs used for teacher-forced training, i.e. mel-specs,
+ with shape (n_batch, ``n_mels``, max of ``mel_specgram_lengths``)
+
+ Returns:
+ inputs (Tensor): Processed decoder inputs with shape (max of ``mel_specgram_lengths``, n_batch, ``n_mels``).
+ """
+ # (n_batch, n_mels, mel_specgram_lengths.max()) -> (n_batch, mel_specgram_lengths.max(), n_mels)
+ decoder_inputs = decoder_inputs.transpose(1, 2)
+ decoder_inputs = decoder_inputs.view(
+ decoder_inputs.size(0),
+ int(decoder_inputs.size(1) / self.n_frames_per_step),
+ -1,
+ )
+ # (n_batch, mel_specgram_lengths.max(), n_mels) -> (mel_specgram_lengths.max(), n_batch, n_mels)
+ decoder_inputs = decoder_inputs.transpose(0, 1)
+ return decoder_inputs
+
+ def _parse_decoder_outputs(
+ self, mel_specgram: Tensor, gate_outputs: Tensor, alignments: Tensor
+ ) -> Tuple[Tensor, Tensor, Tensor]:
+ r"""Prepares decoder outputs for output
+
+ Args:
+ mel_specgram (Tensor): mel spectrogram with shape (max of ``mel_specgram_lengths``, n_batch, ``n_mels``)
+ gate_outputs (Tensor): predicted stop token with shape (max of ``mel_specgram_lengths``, n_batch)
+ alignments (Tensor): sequence of attention weights from the decoder
+ with shape (max of ``mel_specgram_lengths``, n_batch, max of ``text_lengths``)
+
+ Returns:
+ mel_specgram (Tensor): mel spectrogram with shape (n_batch, ``n_mels``, max of ``mel_specgram_lengths``)
+ gate_outputs (Tensor): predicted stop token with shape (n_batch, max of ``mel_specgram_lengths``)
+ alignments (Tensor): sequence of attention weights from the decoder
+ with shape (n_batch, max of ``mel_specgram_lengths``, max of ``text_lengths``)
+ """
+ # (mel_specgram_lengths.max(), n_batch, text_lengths.max())
+ # -> (n_batch, mel_specgram_lengths.max(), text_lengths.max())
+ alignments = alignments.transpose(0, 1).contiguous()
+ # (mel_specgram_lengths.max(), n_batch) -> (n_batch, mel_specgram_lengths.max())
+ gate_outputs = gate_outputs.transpose(0, 1).contiguous()
+ # (mel_specgram_lengths.max(), n_batch, n_mels) -> (n_batch, mel_specgram_lengths.max(), n_mels)
+ mel_specgram = mel_specgram.transpose(0, 1).contiguous()
+ # decouple frames per step
+ shape = (mel_specgram.shape[0], -1, self.n_mels)
+ mel_specgram = mel_specgram.view(*shape)
+ # (n_batch, mel_specgram_lengths.max(), n_mels) -> (n_batch, n_mels, T_out)
+ mel_specgram = mel_specgram.transpose(1, 2)
+
+ return mel_specgram, gate_outputs, alignments
+
+ def decode(
+ self,
+ decoder_input: Tensor,
+ attention_hidden: Tensor,
+ attention_cell: Tensor,
+ decoder_hidden: Tensor,
+ decoder_cell: Tensor,
+ attention_weights: Tensor,
+ attention_weights_cum: Tensor,
+ attention_context: Tensor,
+ memory: Tensor,
+ processed_memory: Tensor,
+ mask: Tensor,
+ ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
+ r"""Decoder step using stored states, attention and memory
+
+ Args:
+ decoder_input (Tensor): Output of the Prenet with shape (n_batch, ``prenet_dim``).
+ attention_hidden (Tensor): Hidden state of the attention LSTM with shape (n_batch, ``attention_rnn_dim``).
+ attention_cell (Tensor): Cell state of the attention LSTM with shape (n_batch, ``attention_rnn_dim``).
+ decoder_hidden (Tensor): Hidden state of the decoder LSTM with shape (n_batch, ``decoder_rnn_dim``).
+ decoder_cell (Tensor): Cell state of the decoder LSTM with shape (n_batch, ``decoder_rnn_dim``).
+ attention_weights (Tensor): Attention weights with shape (n_batch, max of ``text_lengths``).
+ attention_weights_cum (Tensor): Cumulated attention weights with shape (n_batch, max of ``text_lengths``).
+ attention_context (Tensor): Context vector with shape (n_batch, ``encoder_embedding_dim``).
+ memory (Tensor): Encoder output with shape (n_batch, max of ``text_lengths``, ``encoder_embedding_dim``).
+ processed_memory (Tensor): Processed Encoder outputs
+ with shape (n_batch, max of ``text_lengths``, ``attention_hidden_dim``).
+ mask (Tensor): Binary mask for padded data with shape (n_batch, current_num_frames).
+
+ Returns:
+ decoder_output: Predicted mel spectrogram for the current frame with shape (n_batch, ``n_mels``).
+ gate_prediction (Tensor): Prediction of the stop token with shape (n_batch, ``1``).
+ attention_hidden (Tensor): Hidden state of the attention LSTM with shape (n_batch, ``attention_rnn_dim``).
+ attention_cell (Tensor): Cell state of the attention LSTM with shape (n_batch, ``attention_rnn_dim``).
+ decoder_hidden (Tensor): Hidden state of the decoder LSTM with shape (n_batch, ``decoder_rnn_dim``).
+ decoder_cell (Tensor): Cell state of the decoder LSTM with shape (n_batch, ``decoder_rnn_dim``).
+ attention_weights (Tensor): Attention weights with shape (n_batch, max of ``text_lengths``).
+ attention_weights_cum (Tensor): Cumulated attention weights with shape (n_batch, max of ``text_lengths``).
+ attention_context (Tensor): Context vector with shape (n_batch, ``encoder_embedding_dim``).
+ """
+ cell_input = torch.cat((decoder_input, attention_context), -1)
+
+ attention_hidden, attention_cell = self.attention_rnn(cell_input, (attention_hidden, attention_cell))
+ attention_hidden = F.dropout(attention_hidden, self.attention_dropout, self.training)
+
+ attention_weights_cat = torch.cat((attention_weights.unsqueeze(1), attention_weights_cum.unsqueeze(1)), dim=1)
+ attention_context, attention_weights = self.attention_layer(
+ attention_hidden, memory, processed_memory, attention_weights_cat, mask
+ )
+
+ attention_weights_cum += attention_weights
+ decoder_input = torch.cat((attention_hidden, attention_context), -1)
+
+ decoder_hidden, decoder_cell = self.decoder_rnn(decoder_input, (decoder_hidden, decoder_cell))
+ decoder_hidden = F.dropout(decoder_hidden, self.decoder_dropout, self.training)
+
+ decoder_hidden_attention_context = torch.cat((decoder_hidden, attention_context), dim=1)
+ decoder_output = self.linear_projection(decoder_hidden_attention_context)
+
+ gate_prediction = self.gate_layer(decoder_hidden_attention_context)
+
+ return (
+ decoder_output,
+ gate_prediction,
+ attention_hidden,
+ attention_cell,
+ decoder_hidden,
+ decoder_cell,
+ attention_weights,
+ attention_weights_cum,
+ attention_context,
+ )
+
+ def forward(
+ self, memory: Tensor, mel_specgram_truth: Tensor, memory_lengths: Tensor
+ ) -> Tuple[Tensor, Tensor, Tensor]:
+ r"""Decoder forward pass for training.
+
+ Args:
+ memory (Tensor): Encoder outputs
+ with shape (n_batch, max of ``text_lengths``, ``encoder_embedding_dim``).
+ mel_specgram_truth (Tensor): Decoder ground-truth mel-specs for teacher forcing
+ with shape (n_batch, ``n_mels``, max of ``mel_specgram_lengths``).
+ memory_lengths (Tensor): Encoder output lengths for attention masking
+ (the same as ``text_lengths``) with shape (n_batch, ).
+
+ Returns:
+ mel_specgram (Tensor): Predicted mel spectrogram
+ with shape (n_batch, ``n_mels``, max of ``mel_specgram_lengths``).
+ gate_outputs (Tensor): Predicted stop token for each timestep
+ with shape (n_batch, max of ``mel_specgram_lengths``).
+ alignments (Tensor): Sequence of attention weights from the decoder
+ with shape (n_batch, max of ``mel_specgram_lengths``, max of ``text_lengths``).
+ """
+
+ decoder_input = self._get_initial_frame(memory).unsqueeze(0)
+ decoder_inputs = self._parse_decoder_inputs(mel_specgram_truth)
+ decoder_inputs = torch.cat((decoder_input, decoder_inputs), dim=0)
+ decoder_inputs = self.prenet(decoder_inputs)
+
+ mask = _get_mask_from_lengths(memory_lengths)
+ (
+ attention_hidden,
+ attention_cell,
+ decoder_hidden,
+ decoder_cell,
+ attention_weights,
+ attention_weights_cum,
+ attention_context,
+ processed_memory,
+ ) = self._initialize_decoder_states(memory)
+
+ mel_outputs, gate_outputs, alignments = [], [], []
+ while len(mel_outputs) < decoder_inputs.size(0) - 1:
+ decoder_input = decoder_inputs[len(mel_outputs)]
+ (
+ mel_output,
+ gate_output,
+ attention_hidden,
+ attention_cell,
+ decoder_hidden,
+ decoder_cell,
+ attention_weights,
+ attention_weights_cum,
+ attention_context,
+ ) = self.decode(
+ decoder_input,
+ attention_hidden,
+ attention_cell,
+ decoder_hidden,
+ decoder_cell,
+ attention_weights,
+ attention_weights_cum,
+ attention_context,
+ memory,
+ processed_memory,
+ mask,
+ )
+
+ mel_outputs += [mel_output.squeeze(1)]
+ gate_outputs += [gate_output.squeeze(1)]
+ alignments += [attention_weights]
+
+ mel_specgram, gate_outputs, alignments = self._parse_decoder_outputs(
+ torch.stack(mel_outputs), torch.stack(gate_outputs), torch.stack(alignments)
+ )
+
+ return mel_specgram, gate_outputs, alignments
+
+ def _get_go_frame(self, memory: Tensor) -> Tensor:
+ """Gets all zeros frames to use as the first decoder input
+
+ args:
+ memory (Tensor): Encoder outputs
+ with shape (n_batch, max of ``text_lengths``, ``encoder_embedding_dim``).
+
+ returns:
+ decoder_input (Tensor): All zeros frames with shape(n_batch, ``n_mels`` * ``n_frame_per_step``).
+ """
+
+ n_batch = memory.size(0)
+ dtype = memory.dtype
+ device = memory.device
+ decoder_input = torch.zeros(n_batch, self.n_mels * self.n_frames_per_step, dtype=dtype, device=device)
+ return decoder_input
+
+ @torch.jit.export
+ def infer(self, memory: Tensor, memory_lengths: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
+ """Decoder inference
+
+ Args:
+ memory (Tensor): Encoder outputs
+ with shape (n_batch, max of ``text_lengths``, ``encoder_embedding_dim``).
+ memory_lengths (Tensor): Encoder output lengths for attention masking
+ (the same as ``text_lengths``) with shape (n_batch, ).
+
+ Returns:
+ mel_specgram (Tensor): Predicted mel spectrogram
+ with shape (n_batch, ``n_mels``, max of ``mel_specgram_lengths``).
+ mel_specgram_lengths (Tensor): The length of the predicted mel spectrogram with shape (n_batch, ).
+ gate_outputs (Tensor): Predicted stop token for each timestep
+ with shape (n_batch, max of ``mel_specgram_lengths``).
+ alignments (Tensor): Sequence of attention weights from the decoder
+ with shape (n_batch, max of ``mel_specgram_lengths``, max of ``text_lengths``).
+ """
+ batch_size, device = memory.size(0), memory.device
+
+ decoder_input = self._get_go_frame(memory)
+
+ mask = _get_mask_from_lengths(memory_lengths)
+ (
+ attention_hidden,
+ attention_cell,
+ decoder_hidden,
+ decoder_cell,
+ attention_weights,
+ attention_weights_cum,
+ attention_context,
+ processed_memory,
+ ) = self._initialize_decoder_states(memory)
+
+ mel_specgram_lengths = torch.zeros([batch_size], dtype=torch.int32, device=device)
+ finished = torch.zeros([batch_size], dtype=torch.bool, device=device)
+ mel_specgrams: List[Tensor] = []
+ gate_outputs: List[Tensor] = []
+ alignments: List[Tensor] = []
+ for _ in range(self.decoder_max_step):
+ decoder_input = self.prenet(decoder_input)
+ (
+ mel_specgram,
+ gate_output,
+ attention_hidden,
+ attention_cell,
+ decoder_hidden,
+ decoder_cell,
+ attention_weights,
+ attention_weights_cum,
+ attention_context,
+ ) = self.decode(
+ decoder_input,
+ attention_hidden,
+ attention_cell,
+ decoder_hidden,
+ decoder_cell,
+ attention_weights,
+ attention_weights_cum,
+ attention_context,
+ memory,
+ processed_memory,
+ mask,
+ )
+
+ mel_specgrams.append(mel_specgram.unsqueeze(0))
+ gate_outputs.append(gate_output.transpose(0, 1))
+ alignments.append(attention_weights)
+ mel_specgram_lengths[~finished] += 1
+
+ finished |= torch.sigmoid(gate_output.squeeze(1)) > self.gate_threshold
+ if self.decoder_early_stopping and torch.all(finished):
+ break
+
+ decoder_input = mel_specgram
+
+ if len(mel_specgrams) == self.decoder_max_step:
+ warnings.warn(
+ "Reached max decoder steps. The generated spectrogram might not cover " "the whole transcript."
+ )
+
+ mel_specgrams = torch.cat(mel_specgrams, dim=0)
+ gate_outputs = torch.cat(gate_outputs, dim=0)
+ alignments = torch.cat(alignments, dim=0)
+
+ mel_specgrams, gate_outputs, alignments = self._parse_decoder_outputs(mel_specgrams, gate_outputs, alignments)
+
+ return mel_specgrams, mel_specgram_lengths, gate_outputs, alignments
+
+
+class Tacotron2(nn.Module):
+ r"""Tacotron2 model from *Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions*
+ :cite:`shen2018natural` based on the implementation from
+ `Nvidia Deep Learning Examples <https://github.com/NVIDIA/DeepLearningExamples>`_.
+
+ See Also:
+ * :class:`torchaudio.pipelines.Tacotron2TTSBundle`: TTS pipeline with pretrained model.
+
+ Args:
+ mask_padding (bool, optional): Use mask padding (Default: ``False``).
+ n_mels (int, optional): Number of mel bins (Default: ``80``).
+ n_symbol (int, optional): Number of symbols for the input text (Default: ``148``).
+ n_frames_per_step (int, optional): Number of frames processed per step, only 1 is supported (Default: ``1``).
+ symbol_embedding_dim (int, optional): Input embedding dimension (Default: ``512``).
+ encoder_n_convolution (int, optional): Number of encoder convolutions (Default: ``3``).
+ encoder_kernel_size (int, optional): Encoder kernel size (Default: ``5``).
+ encoder_embedding_dim (int, optional): Encoder embedding dimension (Default: ``512``).
+ decoder_rnn_dim (int, optional): Number of units in decoder LSTM (Default: ``1024``).
+ decoder_max_step (int, optional): Maximum number of output mel spectrograms (Default: ``2000``).
+ decoder_dropout (float, optional): Dropout probability for decoder LSTM (Default: ``0.1``).
+ decoder_early_stopping (bool, optional): Stop decoding when all samples are finished (Default: ``True``).
+ attention_rnn_dim (int, optional): Number of units in attention LSTM (Default: ``1024``).
+ attention_hidden_dim (int, optional): Dimension of attention hidden representation (Default: ``128``).
+ attention_location_n_filter (int, optional): Number of filters for attention model (Default: ``32``).
+ attention_location_kernel_size (int, optional): Kernel size for attention model (Default: ``31``).
+ attention_dropout (float, optional): Dropout probability for attention LSTM (Default: ``0.1``).
+ prenet_dim (int, optional): Number of ReLU units in prenet layers (Default: ``256``).
+ postnet_n_convolution (int, optional): Number of postnet convolutions (Default: ``5``).
+ postnet_kernel_size (int, optional): Postnet kernel size (Default: ``5``).
+ postnet_embedding_dim (int, optional): Postnet embedding dimension (Default: ``512``).
+ gate_threshold (float, optional): Probability threshold for stop token (Default: ``0.5``).
+ """
+
+ def __init__(
+ self,
+ mask_padding: bool = False,
+ n_mels: int = 80,
+ n_symbol: int = 148,
+ n_frames_per_step: int = 1,
+ symbol_embedding_dim: int = 512,
+ encoder_embedding_dim: int = 512,
+ encoder_n_convolution: int = 3,
+ encoder_kernel_size: int = 5,
+ decoder_rnn_dim: int = 1024,
+ decoder_max_step: int = 2000,
+ decoder_dropout: float = 0.1,
+ decoder_early_stopping: bool = True,
+ attention_rnn_dim: int = 1024,
+ attention_hidden_dim: int = 128,
+ attention_location_n_filter: int = 32,
+ attention_location_kernel_size: int = 31,
+ attention_dropout: float = 0.1,
+ prenet_dim: int = 256,
+ postnet_n_convolution: int = 5,
+ postnet_kernel_size: int = 5,
+ postnet_embedding_dim: int = 512,
+ gate_threshold: float = 0.5,
+ ) -> None:
+ super().__init__()
+
+ self.mask_padding = mask_padding
+ self.n_mels = n_mels
+ self.n_frames_per_step = n_frames_per_step
+ self.embedding = nn.Embedding(n_symbol, symbol_embedding_dim)
+ torch.nn.init.xavier_uniform_(self.embedding.weight)
+ self.encoder = _Encoder(encoder_embedding_dim, encoder_n_convolution, encoder_kernel_size)
+ self.decoder = _Decoder(
+ n_mels,
+ n_frames_per_step,
+ encoder_embedding_dim,
+ decoder_rnn_dim,
+ decoder_max_step,
+ decoder_dropout,
+ decoder_early_stopping,
+ attention_rnn_dim,
+ attention_hidden_dim,
+ attention_location_n_filter,
+ attention_location_kernel_size,
+ attention_dropout,
+ prenet_dim,
+ gate_threshold,
+ )
+ self.postnet = _Postnet(n_mels, postnet_embedding_dim, postnet_kernel_size, postnet_n_convolution)
+
+ def forward(
+ self,
+ tokens: Tensor,
+ token_lengths: Tensor,
+ mel_specgram: Tensor,
+ mel_specgram_lengths: Tensor,
+ ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
+ r"""Pass the input through the Tacotron2 model. This is in teacher
+ forcing mode, which is generally used for training.
+
+ The input ``tokens`` should be padded with zeros to length max of ``token_lengths``.
+ The input ``mel_specgram`` should be padded with zeros to length max of ``mel_specgram_lengths``.
+
+ Args:
+ tokens (Tensor): The input tokens to Tacotron2 with shape `(n_batch, max of token_lengths)`.
+ token_lengths (Tensor): The valid length of each sample in ``tokens`` with shape `(n_batch, )`.
+ mel_specgram (Tensor): The target mel spectrogram
+ with shape `(n_batch, n_mels, max of mel_specgram_lengths)`.
+ mel_specgram_lengths (Tensor): The length of each mel spectrogram with shape `(n_batch, )`.
+
+ Returns:
+ [Tensor, Tensor, Tensor, Tensor]:
+ Tensor
+ Mel spectrogram before Postnet with shape `(n_batch, n_mels, max of mel_specgram_lengths)`.
+ Tensor
+ Mel spectrogram after Postnet with shape `(n_batch, n_mels, max of mel_specgram_lengths)`.
+ Tensor
+ The output for stop token at each time step with shape `(n_batch, max of mel_specgram_lengths)`.
+ Tensor
+ Sequence of attention weights from the decoder with
+ shape `(n_batch, max of mel_specgram_lengths, max of token_lengths)`.
+ """
+
+ embedded_inputs = self.embedding(tokens).transpose(1, 2)
+
+ encoder_outputs = self.encoder(embedded_inputs, token_lengths)
+ mel_specgram, gate_outputs, alignments = self.decoder(
+ encoder_outputs, mel_specgram, memory_lengths=token_lengths
+ )
+
+ mel_specgram_postnet = self.postnet(mel_specgram)
+ mel_specgram_postnet = mel_specgram + mel_specgram_postnet
+
+ if self.mask_padding:
+ mask = _get_mask_from_lengths(mel_specgram_lengths)
+ mask = mask.expand(self.n_mels, mask.size(0), mask.size(1))
+ mask = mask.permute(1, 0, 2)
+
+ mel_specgram.masked_fill_(mask, 0.0)
+ mel_specgram_postnet.masked_fill_(mask, 0.0)
+ gate_outputs.masked_fill_(mask[:, 0, :], 1e3)
+
+ return mel_specgram, mel_specgram_postnet, gate_outputs, alignments
+
+ @torch.jit.export
+ def infer(self, tokens: Tensor, lengths: Optional[Tensor] = None) -> Tuple[Tensor, Tensor, Tensor]:
+ r"""Using Tacotron2 for inference. The input is a batch of encoded
+ sentences (``tokens``) and its corresponding lengths (``lengths``). The
+ output is the generated mel spectrograms, its corresponding lengths, and
+ the attention weights from the decoder.
+
+ The input `tokens` should be padded with zeros to length max of ``lengths``.
+
+ Args:
+ tokens (Tensor): The input tokens to Tacotron2 with shape `(n_batch, max of lengths)`.
+ lengths (Tensor or None, optional):
+ The valid length of each sample in ``tokens`` with shape `(n_batch, )`.
+ If ``None``, it is assumed that all the tokens are valid. Default: ``None``
+
+ Returns:
+ (Tensor, Tensor, Tensor):
+ Tensor
+ The predicted mel spectrogram with shape `(n_batch, n_mels, max of mel_specgram_lengths)`.
+ Tensor
+ The length of the predicted mel spectrogram with shape `(n_batch, )`.
+ Tensor
+ Sequence of attention weights from the decoder with shape
+ `(n_batch, max of mel_specgram_lengths, max of lengths)`.
+ """
+ n_batch, max_length = tokens.shape
+ if lengths is None:
+ lengths = torch.tensor([max_length]).expand(n_batch).to(tokens.device, tokens.dtype)
+
+ assert lengths is not None # For TorchScript compiler
+ embedded_inputs = self.embedding(tokens).transpose(1, 2)
+ encoder_outputs = self.encoder(embedded_inputs, lengths)
+ mel_specgram, mel_specgram_lengths, _, alignments = self.decoder.infer(encoder_outputs, lengths)
+
+ mel_outputs_postnet = self.postnet(mel_specgram)
+ mel_outputs_postnet = mel_specgram + mel_outputs_postnet
+
+ alignments = alignments.unfold(1, n_batch, n_batch).transpose(0, 2)
+
+ return mel_outputs_postnet, mel_specgram_lengths, alignments
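+
+
+ # A minimal inference sketch (illustrative only, not part of the module): it assumes the
+ # ``Tacotron2`` class defined above is constructed with its default arguments and that the
+ # token IDs lie inside the model's symbol range.
+ def _tacotron2_infer_sketch():
+     model = Tacotron2().eval()
+     tokens = torch.randint(1, 100, (2, 50))  # (n_batch, max token length); IDs assumed < n_symbol
+     tokens[1, 42:] = 0  # zero-pad the shorter sequence
+     lengths = torch.tensor([50, 42])  # valid length of each sample, in decreasing order
+     with torch.no_grad():
+         mel, mel_lengths, alignments = model.infer(tokens, lengths)
+     # mel: (n_batch, n_mels, max mel length), mel_lengths: (n_batch,),
+     # alignments: (n_batch, max mel length, max token length)
+     return mel, mel_lengths, alignments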
diff --git a/MLPY/Lib/site-packages/torchaudio/models/wav2letter.py b/MLPY/Lib/site-packages/torchaudio/models/wav2letter.py
new file mode 100644
index 0000000000000000000000000000000000000000..defe7902fbe3c20aeb9fdaa5b5f32840691f6c66
--- /dev/null
+++ b/MLPY/Lib/site-packages/torchaudio/models/wav2letter.py
@@ -0,0 +1,72 @@
+from torch import nn, Tensor
+
+__all__ = [
+ "Wav2Letter",
+]
+
+
+class Wav2Letter(nn.Module):
+ r"""Wav2Letter model architecture from *Wav2Letter: an End-to-End ConvNet-based Speech
+ Recognition System* :cite:`collobert2016wav2letter`.
+
+ See Also:
+ * `Training example `__
+
+ Args:
+ num_classes (int, optional): Number of classes to be classified. (Default: ``40``)
+ input_type (str, optional): Wav2Letter can use as input: ``waveform``, ``power_spectrum``
+ or ``mfcc`` (Default: ``waveform``).
+ num_features (int, optional): Number of input features that the network will receive (Default: ``1``).
+ """
+
+ def __init__(self, num_classes: int = 40, input_type: str = "waveform", num_features: int = 1) -> None:
+ super().__init__()
+
+ acoustic_num_features = 250 if input_type == "waveform" else num_features
+ acoustic_model = nn.Sequential(
+ nn.Conv1d(in_channels=acoustic_num_features, out_channels=250, kernel_size=48, stride=2, padding=23),
+ nn.ReLU(inplace=True),
+ nn.Conv1d(in_channels=250, out_channels=250, kernel_size=7, stride=1, padding=3),
+ nn.ReLU(inplace=True),
+ nn.Conv1d(in_channels=250, out_channels=250, kernel_size=7, stride=1, padding=3),
+ nn.ReLU(inplace=True),
+ nn.Conv1d(in_channels=250, out_channels=250, kernel_size=7, stride=1, padding=3),
+ nn.ReLU(inplace=True),
+ nn.Conv1d(in_channels=250, out_channels=250, kernel_size=7, stride=1, padding=3),
+ nn.ReLU(inplace=True),
+ nn.Conv1d(in_channels=250, out_channels=250, kernel_size=7, stride=1, padding=3),
+ nn.ReLU(inplace=True),
+ nn.Conv1d(in_channels=250, out_channels=250, kernel_size=7, stride=1, padding=3),
+ nn.ReLU(inplace=True),
+ nn.Conv1d(in_channels=250, out_channels=250, kernel_size=7, stride=1, padding=3),
+ nn.ReLU(inplace=True),
+ nn.Conv1d(in_channels=250, out_channels=2000, kernel_size=32, stride=1, padding=16),
+ nn.ReLU(inplace=True),
+ nn.Conv1d(in_channels=2000, out_channels=2000, kernel_size=1, stride=1, padding=0),
+ nn.ReLU(inplace=True),
+ nn.Conv1d(in_channels=2000, out_channels=num_classes, kernel_size=1, stride=1, padding=0),
+ nn.ReLU(inplace=True),
+ )
+
+ if input_type == "waveform":
+ waveform_model = nn.Sequential(
+ nn.Conv1d(in_channels=num_features, out_channels=250, kernel_size=250, stride=160, padding=45),
+ nn.ReLU(inplace=True),
+ )
+ self.acoustic_model = nn.Sequential(waveform_model, acoustic_model)
+
+ if input_type in ["power_spectrum", "mfcc"]:
+ self.acoustic_model = acoustic_model
+
+ def forward(self, x: Tensor) -> Tensor:
+ r"""
+ Args:
+ x (torch.Tensor): Tensor of dimension (batch_size, num_features, input_length).
+
+ Returns:
+ Tensor: Predictor tensor of dimension (batch_size, number_of_classes, input_length).
+ """
+
+ x = self.acoustic_model(x)
+ x = nn.functional.log_softmax(x, dim=1)
+ return x
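+
+
+ # A minimal usage sketch (illustrative only, not part of the module): with
+ # ``input_type="waveform"`` the first convolution has stride 160, so roughly one output
+ # frame is produced per 160 input samples (halved again by the stride-2 convolution that
+ # follows).
+ def _wav2letter_usage_sketch():
+     import torch
+
+     model = Wav2Letter(num_classes=40, input_type="waveform", num_features=1)
+     waveform = torch.randn(1, 1, 16000)  # (batch_size, num_features, input_length)
+     log_probs = model(waveform)  # (batch_size, num_classes, output_length), log-probabilities over classes
+     return log_probs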
diff --git a/MLPY/Lib/site-packages/torchaudio/models/wav2vec2/__init__.py b/MLPY/Lib/site-packages/torchaudio/models/wav2vec2/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7bd9720de106dd39f24f51ad267ec4b776cc3ab5
--- /dev/null
+++ b/MLPY/Lib/site-packages/torchaudio/models/wav2vec2/__init__.py
@@ -0,0 +1,45 @@
+from . import utils
+from .model import (
+ hubert_base,
+ hubert_large,
+ hubert_pretrain_base,
+ hubert_pretrain_large,
+ hubert_pretrain_model,
+ hubert_pretrain_xlarge,
+ hubert_xlarge,
+ HuBERTPretrainModel,
+ wav2vec2_base,
+ wav2vec2_large,
+ wav2vec2_large_lv60k,
+ wav2vec2_model,
+ wav2vec2_xlsr_1b,
+ wav2vec2_xlsr_2b,
+ wav2vec2_xlsr_300m,
+ Wav2Vec2Model,
+ wavlm_base,
+ wavlm_large,
+ wavlm_model,
+)
+
+__all__ = [
+ "Wav2Vec2Model",
+ "HuBERTPretrainModel",
+ "wavlm_model",
+ "wavlm_base",
+ "wavlm_large",
+ "wav2vec2_model",
+ "wav2vec2_base",
+ "wav2vec2_large",
+ "wav2vec2_large_lv60k",
+ "hubert_base",
+ "hubert_large",
+ "hubert_xlarge",
+ "hubert_pretrain_model",
+ "hubert_pretrain_base",
+ "hubert_pretrain_large",
+ "hubert_pretrain_xlarge",
+ "utils",
+ "wav2vec2_xlsr_300m",
+ "wav2vec2_xlsr_1b",
+ "wav2vec2_xlsr_2b",
+]
diff --git a/MLPY/Lib/site-packages/torchaudio/models/wav2vec2/components.py b/MLPY/Lib/site-packages/torchaudio/models/wav2vec2/components.py
new file mode 100644
index 0000000000000000000000000000000000000000..8489c5e1dd996d5e7f1a707b855ed87a6a7da99e
--- /dev/null
+++ b/MLPY/Lib/site-packages/torchaudio/models/wav2vec2/components.py
@@ -0,0 +1,1167 @@
+import logging
+from typing import List, Optional, Tuple
+
+import torch
+from torch import nn, Tensor
+from torch.nn import Module, Parameter
+
+from .wavlm_attention import WavLMSelfAttention
+
+_LG = logging.getLogger(__name__)
+
+
+def _init_transformer_params(module):
+ """
+ Initialize the weights of Transformer module in Wav2Vec2/HuBERT.
+
+ If the module is ``nn.Linear``, normalize the weight with mean 0 and standard deviation 0.02.
+ If ``bias`` is set to ``True`` in the module, set ``bias`` to 0.
+
+ If the module is ``nn.Embedding``, normalize the weight with mean 0 and standard deviation 0.02.
+ If ``padding_idx`` is not None, set the weight of padding to 0.
+
+ Note:
+ This method corresponds to
+ `init_bert_params
+ `__
+ in the original ``fairseq`` implementation.
+ """
+
+ def normal_(data):
+ data.copy_(data.cpu().normal_(mean=0.0, std=0.02).to(data.device))
+
+ if isinstance(module, nn.Linear):
+ normal_(module.weight.data)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ if isinstance(module, nn.Embedding):
+ normal_(module.weight.data)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+
+
+class LayerNorm(nn.LayerNorm):
+ """Layer norm with transpose"""
+
+ def forward(self, input: Tensor) -> Tensor:
+ x = input.transpose(-2, -1)
+ x = nn.functional.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
+ x = x.transpose(-2, -1)
+ return x
+
+
+class ConvLayerBlock(Module):
+ """Convolution unit of FeatureExtractor"""
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: int,
+ stride: int,
+ bias: bool,
+ layer_norm: Optional[Module],
+ ):
+ super().__init__()
+ self.kernel_size = kernel_size
+ self.stride = stride
+ self.layer_norm = layer_norm
+ self.conv = nn.Conv1d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ bias=bias,
+ )
+
+ def forward(
+ self,
+ x: Tensor,
+ length: Optional[Tensor],
+ ) -> Tuple[Tensor, Optional[Tensor]]:
+ """
+ Args:
+ x (Tensor): Shape: ``[batch, in_channels, in_frame]``.
+ length (Tensor or None, optional): Shape ``[batch, ]``.
+ Returns:
+ Tensor: Shape ``[batch, out_channels, out_frames]``.
+ Optional[Tensor]: Shape ``[batch, ]``.
+ """
+ x = self.conv(x)
+ if self.layer_norm is not None:
+ x = self.layer_norm(x)
+ x = nn.functional.gelu(x)
+
+ if length is not None:
+ length = torch.div(length - self.kernel_size, self.stride, rounding_mode="floor") + 1
+ # When input length is 0, the resulting length can be negative. So fix it here.
+ length = torch.max(torch.zeros_like(length), length)
+ return x, length
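+
+
+ # Worked example of the length update above (illustrative sketch, not part of the module):
+ # the first wav2vec2 extractor layer uses kernel_size=10 and stride=5, so a 16000-sample
+ # waveform yields (16000 - 10) // 5 + 1 = 3199 valid output frames.
+ def _conv_layer_block_length_sketch():
+     block = ConvLayerBlock(in_channels=1, out_channels=512, kernel_size=10, stride=5, bias=False, layer_norm=None)
+     x = torch.randn(1, 1, 16000)
+     out, out_length = block(x, torch.tensor([16000]))
+     # out: (1, 512, 3199), out_length: tensor([3199])
+     return out, out_length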
+
+
+class FeatureExtractor(Module):
+ """Extract features from audio
+
+ Args:
+ conv_layers (nn.ModuleList):
+ convolution layers
+ """
+
+ def __init__(
+ self,
+ conv_layers: nn.ModuleList,
+ ):
+ super().__init__()
+ self.conv_layers = conv_layers
+
+ def forward(
+ self,
+ x: Tensor,
+ length: Optional[Tensor],
+ ) -> Tuple[Tensor, Optional[Tensor]]:
+ """
+ Args:
+ x (Tensor):
+ Input Tensor representing a batch of audio,
+ shape: ``[batch, time]``.
+ length (Tensor or None, optional):
+ Valid length of each input sample. shape: ``[batch, ]``.
+
+ Returns:
+ Tensor:
+ The resulting feature, shape: ``[batch, frame, feature]``
+ Optional[Tensor]:
+ Valid length of each output sample. shape: ``[batch, ]``.
+ """
+ if x.ndim != 2:
+ raise ValueError(f"Expected the input Tensor to be 2D (batch, time). Found: {list(x.shape)}")
+
+ x = x.unsqueeze(1) # (batch, channel==1, frame)
+ for layer in self.conv_layers:
+ x, length = layer(x, length) # (batch, feature, frame)
+ x = x.transpose(1, 2) # (batch, frame, feature)
+ return x, length
+
+
+class FeatureProjection(Module):
+ """Layer that connects FeatureExtractor and Encoder
+
+ Projects features to encoder dimension.
+
+ Args:
+ in_features (int): Input feature dim.
+ out_features (int): Output feature dim.
+ dropout (float): Dropout probability.
+ """
+
+ def __init__(
+ self,
+ in_features: int,
+ out_features: int,
+ dropout: float,
+ ):
+ super().__init__()
+ self.layer_norm = nn.LayerNorm(in_features)
+ self.projection = nn.Linear(
+ in_features,
+ out_features,
+ )
+ self.dropout = nn.Dropout(dropout)
+
+ def forward(self, x):
+ """
+ Args:
+ x (Tensor):
+ Feature Tensor. shape: ``[batch, frame, in_feature]``
+ Returns:
+ Tensor: Projected features. ``[batch, frame, out_feature]``.
+ """
+ x = self.layer_norm(x)
+ x = self.projection(x)
+ x = self.dropout(x)
+ return x
+
+
+class ConvolutionalPositionalEmbedding(Module):
+ """Positional embedding which is placed at the beginning of Transformer.
+
+ Args:
+ embed_dim (int): Feature dimension of the input Tensor.
+ kernel_size (int): The number of frames to be used.
+ groups (int): The number of groups in feature dimensions.
+ """
+
+ def __init__(
+ self,
+ embed_dim: int,
+ kernel_size: int,
+ groups: int,
+ ):
+ super().__init__()
+ self.embed_dim = embed_dim
+ self.kernel_size = kernel_size
+ self.conv = nn.Conv1d(
+ in_channels=embed_dim,
+ out_channels=embed_dim,
+ kernel_size=kernel_size,
+ padding=kernel_size // 2,
+ groups=groups,
+ )
+
+ self.conv = nn.utils.parametrizations.weight_norm(self.conv, name="weight", dim=2)
+ self.num_remove: int = 1 if kernel_size % 2 == 0 else 0
+
+ def __prepare_scriptable__(self):
+ if self.conv.__class__.__name__ == "ParametrizedConv1d":
+ _LG.warning("Removing weight_norm from %s", self.__class__.__name__)
+ torch.nn.utils.parametrize.remove_parametrizations(self.conv, "weight")
+ return self
+
+ def forward(self, x):
+ """
+ Args:
+ x (Tensor): shape ``[batch, frame, feature]``.
+
+ Returns:
+ Tensor: The resulting feature. Shape ``[batch, frame, feature]``.
+ """
+ x = x.transpose(-2, -1)
+ x = self.conv(x)
+ if self.num_remove > 0:
+ x = x[..., : -self.num_remove]
+ x = torch.nn.functional.gelu(x)
+ x = x.transpose(-2, -1)
+ return x
+
+
+class SelfAttention(Module):
+ """Multihead Self Attention module
+
+ Args:
+ embed_dim (int): Total dimension of the model.
+ num_heads (int): The number of heads.
+ dropout (float, optional):
+ Dropout probability on attn_output_weights. Default: ``0.0``
+ """
+
+ def __init__(
+ self,
+ embed_dim: int,
+ num_heads: int,
+ dropout: float = 0.0,
+ ):
+ super().__init__()
+ head_dim = embed_dim // num_heads
+ if head_dim * num_heads != embed_dim:
+ raise ValueError(f"`embed_dim ({embed_dim})` is not divisible by `num_heads ({num_heads})`")
+
+ self.embed_dim = embed_dim
+ self.num_heads = num_heads
+ self.dropout = dropout
+ self.head_dim = head_dim
+
+ self.scaling = self.head_dim**-0.5
+
+ self.k_proj = nn.Linear(embed_dim, embed_dim, bias=True)
+ self.v_proj = nn.Linear(embed_dim, embed_dim, bias=True)
+ self.q_proj = nn.Linear(embed_dim, embed_dim, bias=True)
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)
+
+ def forward(
+ self,
+ x: Tensor,
+ attention_mask: Optional[Tensor] = None,
+ position_bias: Optional[Tensor] = None,
+ key_padding_mask: Optional[Tensor] = None,
+ ) -> Tuple[Tensor, Optional[Tensor]]:
+ """
+ Args:
+ x (Tensor): shape: ``[batch_size, sequence_length, embed_dim]``.
+ attention_mask (Tensor or ``None``, optional):
+ shape: ``[batch_size, 1, sequence_length, sequence_length]``
+ position_bias: Not used. Only for the compatibility with :py:class:`WavLMSelfAttention`.
+ key_padding_mask (Tensor or ``None``): Not used. Only for the compatibility with
+ :py:class:`WavLMSelfAttention`.
+ Returns:
+ (Tensor, ``None``): The resulting attention output and ``None`` (necessary for compatibility
+ with :py:class:`WavLMSelfAttention`).
+ Attention output shape: ``[batch, sequence_length, embed_dim]``.
+ """
+ if x.ndim != 3 or x.shape[2] != self.embed_dim:
+ raise ValueError(
+ f"The expected input shape is (batch, sequence, embed_dim=={self.embed_dim}). " f"Found {x.shape}."
+ )
+ batch_size, length, embed_dim = x.size()
+ if attention_mask is not None:
+ shape_ = (batch_size, 1, length, length)
+ if attention_mask.size() != shape_:
+ raise ValueError(f"The expected attention mask shape is {shape_}. " f"Found {attention_mask.size()}.")
+
+ shape = (batch_size, length, self.num_heads, self.head_dim)
+ q = self.q_proj(x).view(*shape).transpose(2, 1) # B, nH, L, Hd
+ k = self.k_proj(x).view(*shape).transpose(2, 1) # B, nH, L, Hd
+ v = self.v_proj(x).view(*shape).transpose(2, 1) # B, nH, L, Hd
+ dropout = self.dropout if self.training else 0.0
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
+ q, k, v, attn_mask=attention_mask, dropout_p=dropout, is_causal=False
+ )
+ attn_output = attn_output.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
+ output = self.out_proj(attn_output)
+ return output, None # Necessary for compatibility with WavLMSelfAttention
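+
+
+ # A minimal shape-checking sketch (illustrative only, not part of the module):
+ def _self_attention_sketch():
+     attention = SelfAttention(embed_dim=768, num_heads=12, dropout=0.1).eval()
+     x = torch.randn(2, 50, 768)  # (batch_size, sequence_length, embed_dim)
+     output, position_bias = attention(x)  # no mask: full self-attention over the sequence
+     # output: (2, 50, 768); position_bias is None (kept only for WavLM compatibility)
+     return output, position_bias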
+
+
+class FeedForward(Module):
+ """Layer that follows attention layer in encoder layer."""
+
+ def __init__(
+ self,
+ io_features: int,
+ intermediate_features: int,
+ intermediate_dropout: float,
+ output_dropout: float,
+ ):
+ super().__init__()
+ self.intermediate_dense = nn.Linear(io_features, intermediate_features)
+ self.intermediate_dropout = nn.Dropout(intermediate_dropout)
+ self.output_dense = nn.Linear(intermediate_features, io_features)
+ self.output_dropout = nn.Dropout(output_dropout)
+
+ def forward(self, x):
+ """
+ Args:
+ x (Tensor): shape: `(batch, sequence_length, io_features)`
+ Returns:
+ x (Tensor): shape: `(batch, sequence_length, io_features)`
+ """
+ x = self.intermediate_dense(x)
+ x = torch.nn.functional.gelu(x)
+ x = self.intermediate_dropout(x)
+
+ x = self.output_dense(x)
+ x = self.output_dropout(x)
+ return x
+
+
+class EncoderLayer(Module):
+ """A layer unit in encoder. Combines multihead self attention and feed forward."""
+
+ def __init__(
+ self,
+ attention: Module,
+ dropout: float,
+ layer_norm_first: bool,
+ feed_forward: Module,
+ ):
+ super().__init__()
+ self.attention = attention
+ self.dropout = nn.Dropout(dropout)
+ self.layer_norm = nn.LayerNorm(attention.embed_dim)
+ self.layer_norm_first = layer_norm_first
+ self.feed_forward = feed_forward
+ self.final_layer_norm = nn.LayerNorm(attention.embed_dim)
+
+ def forward(
+ self,
+ x: Tensor,
+ attention_mask: Optional[Tensor] = None,
+ position_bias: Optional[Tensor] = None,
+ key_padding_mask: Optional[Tensor] = None,
+ ) -> Tuple[Tensor, Optional[Tensor]]:
+ """
+ Args:
+ x (Tensor): Input of shape ``(batch, sequence_length, embed_dim)``.
+ attention_mask (Tensor or ``None``, optional): attention mask
+ of shape ``(batch, 1, sequence_length, sequence_length)``. (Default: ``None``)
+ position_bias (Tensor or ``None``, optional): position bias of shape
+ ``(batch_size * num_heads, src_len, src_len)``.
+ Only necessary for WavLM model, ``None`` otherwise. (Default: ``None``)
+ key_padding_mask (Tensor or ``None``, optional): key padding mask of shape ``(batch_size, src_len)``.
+ Only used for WavLM model, ignored otherwise. (Default: ``None``)
+ Returns:
+ (x, position_bias): Shapes are the same as in the input. Position bias is only relevant for the WavLM model,
+ ``None`` otherwise.
+ """
+ residual = x
+
+ if self.layer_norm_first:
+ x = self.layer_norm(x)
+
+ x, position_bias = self.attention(
+ x, attention_mask=attention_mask, position_bias=position_bias, key_padding_mask=key_padding_mask
+ )
+
+ x = self.dropout(x)
+ x = residual + x
+
+ if self.layer_norm_first:
+ x = x + self.feed_forward(self.final_layer_norm(x))
+ else:
+ x = self.layer_norm(x)
+ x = self.final_layer_norm(x + self.feed_forward(x))
+ return x, position_bias
+
+
+class Transformer(Module):
+ def __init__(
+ self,
+ pos_conv_embed: Module,
+ dropout: float,
+ layers: Module,
+ layer_norm_first: bool,
+ layer_drop: float,
+ ):
+ super().__init__()
+ self.pos_conv_embed = pos_conv_embed
+ self.layer_norm = nn.LayerNorm(pos_conv_embed.embed_dim)
+ self.layer_norm_first = layer_norm_first
+ self.layer_drop = layer_drop
+ self.dropout = nn.Dropout(dropout)
+ self.layers = layers
+
+ def _preprocess(self, x: Tensor):
+ x = x + self.pos_conv_embed(x)
+
+ if self.layer_norm_first:
+ x = self.layer_norm(x)
+
+ x = self.dropout(x)
+ return x
+
+ def forward(
+ self,
+ x: Tensor,
+ attention_mask: Optional[Tensor] = None,
+ position_bias: Optional[Tensor] = None,
+ ) -> Tensor:
+ x = self._preprocess(x)
+ for layer in self.layers:
+ if not (self.training and torch.rand(1).item() <= self.layer_drop):
+ x, position_bias = layer(x, attention_mask, position_bias=position_bias)
+
+ if not self.layer_norm_first:
+ x = self.layer_norm(x)
+ return x
+
+ def get_intermediate_outputs(
+ self,
+ x: Tensor,
+ attention_mask: Optional[Tensor] = None,
+ num_layers: Optional[int] = None,
+ ) -> List[Tensor]:
+ if num_layers is not None:
+ if not 0 < num_layers <= len(self.layers):
+ raise ValueError(f"`num_layers` must be between [1, {len(self.layers)}]")
+
+ ret: List[Tensor] = []
+ position_bias = None
+ x = self._preprocess(x)
+ for layer in self.layers:
+ x, position_bias = layer(x, attention_mask, position_bias=position_bias)
+ ret.append(x)
+ if num_layers is not None and len(ret) >= num_layers:
+ return ret
+ return ret
+
+
+class Encoder(Module):
+ def __init__(
+ self,
+ feature_projection: Module,
+ transformer: Module,
+ ):
+ super().__init__()
+ self.feature_projection = feature_projection
+ self.transformer = transformer
+
+ def _preprocess(
+ self,
+ features: Tensor,
+ lengths: Optional[Tensor] = None,
+ ) -> Tuple[Tensor, Optional[Tensor]]:
+ x = self.feature_projection(features)
+
+ mask: Optional[Tensor] = None
+ if lengths is not None:
+ batch_size, max_len, _ = x.shape
+ # create mask for padded elements and zero-out them
+ mask = torch.arange(max_len, device=lengths.device).expand(batch_size, max_len) >= lengths[:, None]
+ x[mask] = 0.0
+ # extend the mask to attention shape and set weight
+ mask = -10000.0 * mask[:, None, None, :].to(dtype=features.dtype)
+ mask = mask.expand(batch_size, 1, max_len, max_len)
+ return x, mask
+
+ def forward(
+ self,
+ features: Tensor,
+ lengths: Optional[Tensor] = None,
+ ) -> Tensor:
+ x, mask = self._preprocess(features, lengths)
+ x = self.transformer(x, attention_mask=mask)
+ return x
+
+ def extract_features(
+ self,
+ features: Tensor,
+ lengths: Optional[Tensor] = None,
+ num_layers: Optional[int] = None,
+ ) -> List[Tensor]:
+ x, masks = self._preprocess(features, lengths)
+ return self.transformer.get_intermediate_outputs(x, attention_mask=masks, num_layers=num_layers)
+
+
+################################################################################
+def _get_feature_extractor(
+ norm_mode: str,
+ shapes: List[Tuple[int, int, int]],
+ bias: bool,
+) -> FeatureExtractor:
+ """
+ Args:
+ norm_mode (str):
+ Either "group_norm" or "layer_norm".
+ If "group_norm", then a single normalization is applied
+ in the first convolution block. Otherwise, all the convolution
+ blocks will have layer normalization.
+ This option corresponds to "extractor_mode" from fairseq.
+ Expected values are "group_norm" for Base arch, and
+ "layer_norm" for Large arch.
+ shapes (list of tuple of int):
+ Configuration of convolution layers. List of convolution configuration,
+ i.e. ``[(output_channel, kernel_size, stride), ...]``
+ This option corresponds to "conv_feature_layers" from fairseq.
+ Expected values are
+ ``[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2``
+ for all the architectures.
+ bias (bool):
+ Whether to include bias term to each convolution operation.
+ This option corresponds to "conv_bias" from fairseq.
+ Expected values are False for Base arch, and True for Large arch.
+
+ See Also:
+ * Original implementation
+ https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L666-L733
+ * "extractor_mode"
+ - Def and base:
+ https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L38-L45
+ - Large:
+ https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml#L52
+ * "conv_feature_layers"
+ - Def, base and large:
+ https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L94-L100
+ * "conv_bias"
+ - Def and base:
+ https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L101-L103
+ - Large:
+ https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml#L61
+ """
+ if norm_mode not in ["group_norm", "layer_norm"]:
+ raise ValueError("Invalid norm mode")
+ blocks = []
+ in_channels = 1
+ for i, (out_channels, kernel_size, stride) in enumerate(shapes):
+ normalization = None
+ if norm_mode == "group_norm" and i == 0:
+ normalization = nn.GroupNorm(
+ num_groups=out_channels,
+ num_channels=out_channels,
+ affine=True,
+ )
+ elif norm_mode == "layer_norm":
+ normalization = LayerNorm(
+ normalized_shape=out_channels,
+ elementwise_affine=True,
+ )
+ blocks.append(
+ ConvLayerBlock(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ bias=bias,
+ layer_norm=normalization,
+ )
+ )
+ in_channels = out_channels
+ return FeatureExtractor(nn.ModuleList(blocks))
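+
+
+ # A minimal construction sketch (illustrative only, not part of the module), using the
+ # default convolution configuration documented above for the Base/Large architectures:
+ def _feature_extractor_sketch():
+     shapes = [(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2
+     extractor = _get_feature_extractor(norm_mode="group_norm", shapes=shapes, bias=False)
+     waveform = torch.randn(2, 16000)  # (batch, time)
+     features, lengths = extractor(waveform, torch.tensor([16000, 8000]))
+     # features: (batch, frame, 512); lengths: valid number of frames per sample
+     return features, lengths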
+
+
+def _get_encoder(
+ in_features: int,
+ embed_dim: int,
+ dropout_input: float,
+ pos_conv_kernel: int,
+ pos_conv_groups: int,
+ num_layers: int,
+ num_heads: int,
+ attention_dropout: float,
+ ff_interm_features: int,
+ ff_interm_dropout: float,
+ dropout: float,
+ layer_norm_first: bool,
+ layer_drop: float,
+) -> Encoder:
+ """
+ Args:
+ in_features (int): The number of input features.
+ embed_dim (int):
+ The dimension of embedding.
+ This option corresponds to "encoder_embed_dim" from fairseq.
+ Expected values are 768 for Base arch, and 1024 for Large arch.
+ dropout_input (float):
+ The dropout probability applied after the input feature is projected
+ to ``embed_dim``.
+ This option corresponds to "dropout_input" from fairseq.
+ Expected values are 0.1 for both Base and Large arch.
+ pos_conv_kernel (int):
+ The kernel size of convolutional positional embeddings.
+ This option corresponds to "conv_pos" from fairseq.
+ Expected values are 128 for both Base and Large arch.
+ pos_conv_groups (int):
+ The number of groups of convolutional positional embeddings.
+ This option corresponds to "conv_pos_groups" from fairseq.
+ Expected values are 16 for both Base and Large arch.
+ num_layers (int):
+ The number of self attention layers in transformer block.
+ This option corresponds to "encoder_layers" from fairseq.
+ Expected values are 12 for Base and 24 for Large arch.
+ num_heads (int):
+ The number of heads in self attention layers.
+ This option corresponds to "encoder_attention_heads" from fairseq.
+ Expected values are 12 for Base and 16 for Large arch.
+ attention_dropout (float):
+ The dropout probability applied after softmax in self-attention layer.
+ This option corresponds to "attention_dropout" from fairseq.
+ Expected values are 0.1 for Base and 0.0 for Large arch.
+ ff_interm_features (int):
+ The dimension of hidden features in feed forward layer.
+ This option corresponds to "encoder_ffn_embed_dim" from fairseq.
+ Expected values are 3072 for Base and 4096 for Large arch.
+ ff_interm_dropout (float):
+ The dropout probability applied in feedforward layer.
+ This option corresponds to "activation_dropout" from fairseq.
+ Expected values are 0.1 for both Base and Large arch.
+ dropout (float):
+ The dropout probability applied at the end of feed forward layer.
+ This option corresponds to "dropout" from fairseq.
+ Expected values are 0.1 for Base and 0.0 for Large arch.
+ layer_norm_first (bool):
+ Control the order of layer norm in transformer layer and each encoder layer.
+ If True, in transformer layer, layer norm is applied before features are fed
+ to encoder layers. In encoder layer, two layer norms are applied before and after
+ self attention.
+ If False, in transformer layer, layer norm is applied after features are fed
+ to encoder layers. In encoder layer, two layer norms are applied after self
+ attention, before and after feed forward.
+ This option corresponds to "layer_norm_first" from fairseq.
+ Expected values are False for Base and True for Large arch.
+ layer_drop (float):
+ Probability to drop each encoder layer during training.
+ This option corresponds to "layerdrop" from fairseq.
+ Expected values are 0.1 for both Base and Large arch.
+
+ See Also:
+ * "encoder_embed_dim"
+ - Def and base
+ https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L49-L51
+ - Large
+ https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml#L64
+ * "dropout_input"
+ - Def, base and large
+ https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L75-L78
+ * "conv_pos"
+ - Def, base and large
+ NOTE: The description is wrong.
+ https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L204-L207
+ - Usage
+ https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L756
+ * "conv_pos_groups"
+ - Def, base and large
+ https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L208-L211
+ * "encoder_layers"
+ - Def and base
+ https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L46-L48
+ - Large
+ https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml#L63
+ * "encoder_attention_heads"
+ - Def and base
+ https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L55-L57
+ - Large
+ https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml#L66
+ * "attention_dropout"
+ - Def and base
+ https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L66-L68
+ - Large
+ https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml#L60
+ * "encoder_ffn_embed_dim"
+ - Def and base
+ https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L52-L54
+ - Large
+ https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml#L65
+ * "activation_dropout"
+ - Def
+ https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L69-L71
+ - Base
+ https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/finetuning/base_960h.yaml#L55
+ - Large
+ https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/finetuning/vox_960h.yaml#L55
+ * "dropout"
+ - Def and base
+ https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L63-L65
+ - Large
+ https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml#L59
+ * "layer_norm_first"
+ - Def and base
+ https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L91-L93
+ - Large
+ https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml#L53
+ * "layerdrop"
+ - Def
+ https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L72-L74
+ - Base
+ https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/finetuning/base_960h.yaml#L54
+ - Large
+ https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/finetuning/vox_960h.yaml#L54
+ """
+ feature_projection = FeatureProjection(in_features, embed_dim, dropout_input)
+ pos_conv = ConvolutionalPositionalEmbedding(embed_dim, pos_conv_kernel, pos_conv_groups)
+
+ # Original impl
+ # https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L768-L782
+ encoder_layers = nn.ModuleList()
+ for _ in range(num_layers):
+ attention = SelfAttention(
+ embed_dim=embed_dim,
+ num_heads=num_heads,
+ dropout=attention_dropout,
+ )
+ feed_forward = FeedForward(
+ io_features=embed_dim,
+ intermediate_features=ff_interm_features,
+ intermediate_dropout=ff_interm_dropout,
+ output_dropout=dropout,
+ )
+ encoder_layers.append(
+ EncoderLayer(
+ attention=attention,
+ dropout=dropout,
+ layer_norm_first=layer_norm_first,
+ feed_forward=feed_forward,
+ )
+ )
+ transformer = Transformer(
+ pos_conv_embed=pos_conv,
+ dropout=dropout,
+ layers=encoder_layers,
+ layer_norm_first=not layer_norm_first,
+ layer_drop=layer_drop,
+ )
+ return Encoder(feature_projection, transformer)
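+
+
+ # A minimal construction sketch (illustrative only, not part of the module), using the
+ # Base-architecture values listed in the docstring above:
+ def _encoder_sketch():
+     encoder = _get_encoder(
+         in_features=512,
+         embed_dim=768,
+         dropout_input=0.1,
+         pos_conv_kernel=128,
+         pos_conv_groups=16,
+         num_layers=12,
+         num_heads=12,
+         attention_dropout=0.1,
+         ff_interm_features=3072,
+         ff_interm_dropout=0.1,
+         dropout=0.1,
+         layer_norm_first=False,
+         layer_drop=0.1,
+     )
+     features = torch.randn(2, 49, 512)  # output of the feature extractor
+     out = encoder(features, lengths=torch.tensor([49, 25]))
+     # out: (2, 49, 768)
+     return out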
+
+
+def _get_wavlm_encoder(
+ in_features: int,
+ embed_dim: int,
+ dropout_input: float,
+ pos_conv_kernel: int,
+ pos_conv_groups: int,
+ num_layers: int,
+ num_heads: int,
+ num_buckets: int,
+ max_distance: int,
+ attention_dropout: float,
+ ff_interm_features: int,
+ ff_interm_dropout: float,
+ dropout: float,
+ layer_norm_first: bool,
+ layer_drop: float,
+) -> Encoder:
+ """
+ Construct encoder for WavLM model :cite:`chen2022wavlm`. The structure of the encoder and most of the arguments are
+ the same as in :py:func:`_get_encoder`, so refer there for documentation. The only difference from the Wav2Vec2 encoder
+ is usage of `WavLMSelfAttention` instead of `SelfAttention` and two additional parameters: `num_buckets` and
+ `max_distance`.
+ Args:
+ in_features (int): See :py:func:`_get_encoder`.
+ embed_dim (int): See :py:func:`_get_encoder`.
+ dropout_input (float): See :py:func:`_get_encoder`.
+ pos_conv_kernel (int): See :py:func:`_get_encoder`.
+ pos_conv_groups (int): See :py:func:`_get_encoder`.
+ num_layers (int): See :py:func:`_get_encoder`.
+ num_heads (int): See :py:func:`_get_encoder`.
+ num_buckets (int): Number of buckets for relative position embedding.
+ max_distance (int): Maximum distance for relative position embedding.
+ attention_dropout (float): See :py:func:`_get_encoder`.
+ ff_interm_features (int): See :py:func:`_get_encoder`.
+ ff_interm_dropout (float): See :py:func:`_get_encoder`.
+ dropout (float): See :py:func:`_get_encoder`.
+ layer_norm_first (bool): See :py:func:`_get_encoder`.
+ layer_drop (float): See :py:func:`_get_encoder`.
+
+ """
+ feature_projection = FeatureProjection(in_features, embed_dim, dropout_input)
+ pos_conv = ConvolutionalPositionalEmbedding(embed_dim, pos_conv_kernel, pos_conv_groups)
+
+ # Original impl
+ # https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L768-L782
+ encoder_layers = nn.ModuleList()
+ for i in range(num_layers):
+ attention = WavLMSelfAttention(
+ embed_dim=embed_dim,
+ num_heads=num_heads,
+ num_buckets=num_buckets,
+ max_distance=max_distance,
+ dropout=attention_dropout,
+ has_relative_attention_bias=(i == 0), # Position embedding is only necessary in the first layer.
+ )
+ feed_forward = FeedForward(
+ io_features=embed_dim,
+ intermediate_features=ff_interm_features,
+ intermediate_dropout=ff_interm_dropout,
+ output_dropout=dropout,
+ )
+ encoder_layers.append(
+ EncoderLayer(
+ attention=attention,
+ dropout=dropout,
+ layer_norm_first=layer_norm_first,
+ feed_forward=feed_forward,
+ )
+ )
+ transformer = Transformer(
+ pos_conv_embed=pos_conv,
+ dropout=dropout,
+ layers=encoder_layers,
+ layer_norm_first=not layer_norm_first,
+ layer_drop=layer_drop,
+ )
+ return Encoder(feature_projection, transformer)
+
+
+def _compute_mask_indices(
+ shape: Tuple[int, int],
+ padding_mask: Optional[Tensor],
+ mask_prob: float,
+ mask_length: int,
+ mask_type: str = "static",
+ mask_other: float = 0.0,
+ min_masks: int = 0,
+ no_overlap: bool = False,
+ min_space: int = 0,
+) -> Tensor:
+ """Computes random mask spans for a given shape.
+ Args:
+ shape (int, int): The shape for which to compute masks.
+ The first element is batch size and second is the number of frames.
+ padding_mask (Tensor or None): The padding mask of the same dimension as shape,
+ which will prevent masking padded elements.
+ mask_prob (float): Probability for each token to be chosen as start of the span to be masked.
+ This will be multiplied by number of timesteps divided by length of mask span to mask
+ approximately this percentage of all elements. However due to overlaps, the actual number
+ will be smaller (unless no_overlap is True).
+ mask_length (int): The length of each masked span.
+ mask_type (str): How to compute mask lengths. Options: [``static``, ``uniform``, ``normal``, ``poisson``].
+ ``static``: Fixed size
+ ``uniform``: Sample from uniform distribution [mask_other, mask_length*2]
+ ``normal``: Sample from normal distribution with mean ``mask_length`` and stdev ``mask_other``.
+ ``poisson``: Sample from Poisson distribution with lambda = ``mask_length``.
+ mask_other (float): Secondary mask argument; its interpretation depends on ``mask_type`` (see above).
+ min_masks (int): Minimum number of masked spans.
+ no_overlap (bool): If True, switch to an alternative recursive algorithm
+ that prevents spans from overlapping.
+ min_space (int): How many frames to keep unmasked between spans (only used if no_overlap is True).
+
+ Returns:
+ (Tensor): The mask indices of dimension `[batch, frame]`.
+ """
+
+ batch_size, frame = shape
+ mask = torch.full((batch_size, frame), False)
+ # add a random number for probabilistic rounding
+ all_num_mask = int(mask_prob * frame / float(mask_length) + torch.rand(1))
+
+ all_num_mask = max(min_masks, all_num_mask)
+
+ mask_idcs = []
+ for i in range(batch_size):
+ if padding_mask is not None:
+ sz = frame - padding_mask[i].long().sum().item()
+ # add a random number for probabilistic rounding
+ num_mask = int(mask_prob * sz / float(mask_length) + torch.rand(1))
+ num_mask = max(min_masks, num_mask)
+ else:
+ sz = frame
+ num_mask = all_num_mask
+
+ if mask_type == "static":
+ lengths = torch.full((num_mask,), mask_length)
+ elif mask_type == "uniform":
+ lengths = torch.randint(int(mask_other), mask_length * 2 + 1, size=(num_mask,))
+ elif mask_type == "normal":
+ lengths = torch.normal(mask_length, mask_other, size=(num_mask,))
+ lengths = torch.maximum(torch.ones(1), torch.round(lengths)).int()
+ elif mask_type == "poisson":
+ lengths = torch.poisson(mask_length, size=(num_mask,))
+ lengths = torch.round(lengths).int()
+ else:
+ raise Exception(f"unknown mask selection: {mask_type}")
+
+ if sum(lengths) == 0:
+ lengths[0] = min(mask_length, sz - 1)
+
+ if no_overlap:
+ mask_idc = []
+
+ def arrange(s, e, length, keep_length):
+ span_start = torch.randint(s, e - length, size=(1,))
+ mask_idc.extend(span_start + i for i in range(length))
+
+ new_parts = []
+ if span_start - s - min_space >= keep_length:
+ new_parts.append((s, span_start - min_space + 1))
+ if e - span_start - keep_length - min_space > keep_length:
+ new_parts.append((span_start + length + min_space, e))
+ return new_parts
+
+ parts = [(0, sz)]
+ min_length = min(lengths)
+ for length in sorted(lengths, reverse=True):
+ lens = torch.tensor([e - s for s, e in parts], dtype=torch.int)
+ lens[lens < length + min_space] = 0
+ l_sum = lens.sum()
+ if l_sum == 0:
+ break
+ probs = lens / l_sum
+ c = torch.distributions.categorical.Categorical(probs).sample()
+ s, e = parts.pop(c)
+ parts.extend(arrange(s, e, length, min_length))
+ mask_idc = torch.tensor(mask_idc)
+ else:
+ min_len = min(lengths)
+ if sz - min_len <= num_mask:
+ min_len = sz - num_mask - 1
+
+ mask_idc = torch.randperm(sz - min_len)[:num_mask]
+ mask_idc = torch.tensor(
+ [mask_idc[j] + offset for j in range(len(mask_idc)) for offset in range(lengths[j])]
+ )
+
+ mask_idcs.append(torch.unique(mask_idc[mask_idc < sz]))
+
+ min_len = min([len(m) for m in mask_idcs])
+ for i, mask_idc in enumerate(mask_idcs):
+ if len(mask_idc) > min_len:
+ mask_idc = mask_idc[torch.randperm(len(mask_idc))[:min_len].long()]
+ mask[i, mask_idc] = True
+
+ return mask
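+
+
+ # A minimal usage sketch (illustrative only, not part of the module): mask roughly 65% of
+ # 100 frames with fixed-length spans of 10 frames, values in the range used for
+ # wav2vec 2.0 style pretraining.
+ def _compute_mask_indices_sketch():
+     mask = _compute_mask_indices(
+         shape=(2, 100),
+         padding_mask=None,
+         mask_prob=0.65,
+         mask_length=10,
+         mask_type="static",
+         min_masks=2,
+     )
+     # mask: boolean Tensor of shape (2, 100); True marks frames selected for masking
+     return mask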
+
+
+def _get_padding_mask(input: Tensor, lengths: Tensor) -> Tensor:
+ """Generate the padding mask given the padded input and the lengths Tensors.
+ Args:
+ input (Tensor): The padded Tensor of dimension `[batch, max_len, frequency]`.
+ lengths (Tensor): The lengths Tensor of dimension `[batch,]`.
+
+ Returns:
+ (Tensor): The padding mask.
+ """
+ batch_size, max_len, _ = input.shape
+ mask = torch.arange(max_len, device=lengths.device).expand(batch_size, max_len) >= lengths[:, None]
+ return mask
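+
+
+ # Worked example (illustrative only, not part of the module): ``True`` marks padded
+ # (invalid) frames.
+ def _padding_mask_sketch():
+     padded = torch.zeros(2, 4, 512)  # (batch, max_len, frequency)
+     mask = _get_padding_mask(padded, lengths=torch.tensor([2, 4]))
+     # mask == tensor([[False, False, True,  True],
+     #                 [False, False, False, False]])
+     return mask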
+
+
+class MaskGenerator(Module):
+ """Generate the masks for masked prediction.
+ Args:
+ encoder_embed_dim (int): The dimension of the transformer embedding output.
+ mask_prob (float): Probability for each token to be chosen as start of the span to be masked.
+ This will be multiplied by number of timesteps divided by length of mask span to mask
+ approximately this percentage of all elements. However due to overlaps, the actual number
+ will be smaller (unless no_overlap is True).
+ mask_selection (str): How to choose the mask length.
+ Options: [``static``, ``uniform``, ``normal``, ``poisson``].
+ mask_other (float): Secondary mask argument (used for more complex distributions).
+ mask_length (int): The lengths of the mask.
+ no_mask_overlap (bool): Whether to allow masks to overlap.
+ mask_min_space (int): Minimum space between spans (if no overlap is enabled).
+ mask_channel_prob (float): The probability of replacing a feature with 0.
+ mask_channel_selection (str): How to choose the mask length for channel masking.
+ Options: [``static``, ``uniform``, ``normal``, ``poisson``].
+ mask_channel_other (float): Secondary mask argument for channel masking (used for more complex distributions).
+ mask_channel_length (int): The length of each masked span for channel masking.
+ no_mask_channel_overlap (bool): Whether to allow channel masks to overlap.
+ mask_channel_min_space (int): Minimum space between spans for channel masking (if no overlap is enabled).
+ """
+
+ def __init__(
+ self,
+ encoder_embed_dim: int,
+ mask_prob: float,
+ mask_selection: str,
+ mask_other: float,
+ mask_length: int,
+ no_mask_overlap: bool,
+ mask_min_space: int,
+ mask_channel_prob: float,
+ mask_channel_selection: str,
+ mask_channel_other: float,
+ mask_channel_length: int,
+ no_mask_channel_overlap: bool,
+ mask_channel_min_space: int,
+ ):
+ super().__init__()
+ self.mask_prob = mask_prob
+ self.mask_selection = mask_selection
+ self.mask_other = mask_other
+ self.mask_length = mask_length
+ self.no_mask_overlap = no_mask_overlap
+ self.mask_min_space = mask_min_space
+ self.mask_channel_prob = mask_channel_prob
+ self.mask_channel_selection = mask_channel_selection
+ self.mask_channel_other = mask_channel_other
+ self.mask_channel_length = mask_channel_length
+ self.no_mask_channel_overlap = no_mask_channel_overlap
+ self.mask_channel_min_space = mask_channel_min_space
+ self.mask_embedding = Parameter(torch.FloatTensor(encoder_embed_dim))
+ torch.nn.init.uniform_(self.mask_embedding)
+
+ def forward(self, x: Tensor, padding_mask: Optional[Tensor]) -> Tuple[Tensor, Optional[Tensor]]:
+ """
+ Args:
+ x (Tensor): The encoded representations after feature extraction module.
+ padding_mask (Tensor or None): The padding mask of the same dimension as shape,
+ which will prevent masking padded elements.
+
+ Returns:
+ Tensor: The feature representations after masking.
+ Tensor: The generated mask indices.
+ """
+ B, T, C = x.shape
+ if self.mask_prob > 0:
+ mask_indices = _compute_mask_indices(
+ (B, T),
+ padding_mask,
+ self.mask_prob,
+ self.mask_length,
+ self.mask_selection,
+ self.mask_other,
+ min_masks=2,
+ no_overlap=self.no_mask_overlap,
+ min_space=self.mask_min_space,
+ )
+ mask_indices = mask_indices.to(x.device)
+ # change dtype of mask_embedding to x for mixed-precision training.
+ # see https://github.com/pytorch/audio/issues/2847 for details.
+ x[mask_indices] = self.mask_embedding.to(x.dtype)
+ else:
+ mask_indices = None
+
+ if self.mask_channel_prob > 0:
+ mask_channel_indices = _compute_mask_indices(
+ (B, C),
+ None,
+ self.mask_channel_prob,
+ self.mask_channel_length,
+ self.mask_channel_selection,
+ self.mask_channel_other,
+ no_overlap=self.no_mask_channel_overlap,
+ min_space=self.mask_channel_min_space,
+ )
+ mask_channel_indices = mask_channel_indices.to(x.device).unsqueeze(1).expand(-1, T, -1)
+ x[mask_channel_indices] = 0
+
+ return x, mask_indices
+
+
+def _compute_logits(
+ proj_x: Tensor,
+ target: Tensor,
+ label_embeddings: Parameter,
+) -> Tensor:
+ """Compute the logits of the embeddings.
+ Args:
+ proj_x (Tensor): The projected masked representations of dimension `[batch, frame, final_dim]`.
+ target (Tensor): The target Tensor of dimension `[batch, frame, final_dim]`.
+ label_embeddings (Parameter): The trainable embeddings of target of dimension `[num_class, final_dim]`.
+
+ Returns:
+ (Tensor): The logits of the inputs.
+ """
+ logit_temp = 0.1
+ pos = torch.index_select(label_embeddings, 0, target.long())
+ negs = label_embeddings.unsqueeze(1).expand(-1, proj_x.size(0), -1)
+ neg_is_pos = (pos == negs).all(-1)
+ pos = pos.unsqueeze(0)
+ targets = torch.cat([pos, negs], dim=0)
+
+ logits = torch.cosine_similarity(proj_x.float(), targets.float(), dim=-1).type_as(proj_x)
+ logits /= logit_temp
+ if neg_is_pos.any():
+ logits[1:][neg_is_pos] = float("-inf")
+ logits = logits.transpose(0, 1) # (num_x, num_cls+1)
+ return logits
+
+
+class LogitGenerator(Module):
+ """Generate the logits of masked and unmasked inputs.
+ Args:
+ encoder_embed_dim (int): The dimension of the transformer embedding output.
+ num_classes (int): The number of classes in the labels.
+ final_dim (int): Project final representations and targets to `final_dim`.
+ skip_masked (bool): If True, skip computing losses over masked frames.
+ skip_nomask (bool): If True, skip computing losses over unmasked frames.
+ """
+
+ def __init__(
+ self,
+ encoder_embed_dim: int,
+ num_classes: int,
+ final_dim: int,
+ skip_masked: bool,
+ skip_nomask: bool,
+ ):
+ super().__init__()
+ self.label_embeddings = Parameter(torch.FloatTensor(num_classes, final_dim))
+ torch.nn.init.uniform_(self.label_embeddings)
+ self.final_proj = torch.nn.Linear(encoder_embed_dim, final_dim)
+ self.skip_masked = skip_masked
+ self.skip_nomask = skip_nomask
+
+ def forward(self, x: Tensor, label: Tensor, mask_m: Tensor, mask_u: Tensor) -> Tuple[Tensor, Tensor]:
+ """
+ Args:
+ x (Tensor): The feature representation of the last transformer layer.
+ label (Tensor): The label Tensor of dimension `[batch, frame]`.
+ mask_m (Tensor): The masked indices of dimension `[batch, frame]`.
+ mask_u (Tensor): The unmasked indices of dimension `[batch, frame]`.
+
+ Returns:
+ Tensor: The logits of masked frames. Tensor of dimension `[masked_frame, final_dim]`.
+ Tensor: The logits of unmasked frames. Tensor of dimension `[unmasked_frame, final_dim]`.
+ """
+ proj_x = self.final_proj(x)
+ if self.skip_masked:
+ logit_m = None
+ else:
+ proj_x_m = proj_x[mask_m]
+ label_m = label[mask_m]
+ logit_m = _compute_logits(proj_x_m, label_m, self.label_embeddings)
+
+ if self.skip_nomask:
+ logit_u = None
+ else:
+ proj_x_u = proj_x[mask_u]
+ label_u = label[mask_u]
+ logit_u = _compute_logits(proj_x_u, label_u, self.label_embeddings)
+ return logit_m, logit_u
+
+
+class GradMultiply(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, x, scale):
+ ctx.scale = scale
+ res = x.new(x)
+ return res
+
+ @staticmethod
+ def backward(ctx, grad):
+ return grad * ctx.scale, None
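+
+
+ # A minimal sketch of how GradMultiply is used (illustrative only, not part of the
+ # module): it scales gradients flowing back into the feature extractor while leaving the
+ # forward values untouched (see ``HuBERTPretrainModel.forward`` in ``model.py``).
+ def _grad_multiply_sketch():
+     x = torch.randn(3, requires_grad=True)
+     y = GradMultiply.apply(x, 0.1)  # forward pass: y has the same values as x
+     y.sum().backward()
+     # x.grad is 0.1 everywhere instead of 1.0
+     return x.grad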
diff --git a/MLPY/Lib/site-packages/torchaudio/models/wav2vec2/model.py b/MLPY/Lib/site-packages/torchaudio/models/wav2vec2/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c4d7b1ad1651af8a3ec69215e9c885dbe240e75
--- /dev/null
+++ b/MLPY/Lib/site-packages/torchaudio/models/wav2vec2/model.py
@@ -0,0 +1,1579 @@
+import math
+from typing import List, Optional, Tuple
+
+import torch
+from torch import Tensor
+from torch.nn import Module
+
+from . import components
+
+
+class Wav2Vec2Model(Module):
+ """Acoustic model used in *wav2vec 2.0* :cite:`baevski2020wav2vec`.
+
+ Note:
+ To build the model, please use one of the factory functions.
+
+ See Also:
+ * :class:`torchaudio.pipelines.Wav2Vec2Bundle`: Pretrained models (without fine-tuning)
+ * :class:`torchaudio.pipelines.Wav2Vec2ASRBundle`: ASR pipelines with pretrained models.
+
+ Args:
+ feature_extractor (torch.nn.Module):
+ Feature extractor that extracts feature vectors from raw audio Tensor.
+
+ encoder (torch.nn.Module):
+ Encoder that converts the audio features into the sequence of probability
+ distribution (in negative log-likelihood) over labels.
+
+ aux (torch.nn.Module or None, optional):
+ Auxiliary module. If provided, the output from encoder is passed to this module.
+ """ # noqa: E501
+
+ def __init__(
+ self,
+ feature_extractor: Module,
+ encoder: Module,
+ aux: Optional[Module] = None,
+ ):
+ super().__init__()
+ self.feature_extractor = feature_extractor
+ self.encoder = encoder
+ self.aux = aux
+
+ @torch.jit.export
+ def extract_features(
+ self,
+ waveforms: Tensor,
+ lengths: Optional[Tensor] = None,
+ num_layers: Optional[int] = None,
+ ) -> Tuple[List[Tensor], Optional[Tensor]]:
+ """Extract feature vectors from raw waveforms
+
+ This returns the list of outputs from the intermediate layers of
+ transformer block in encoder.
+
+ Args:
+ waveforms (Tensor): Audio tensor of shape `(batch, frames)`.
+ lengths (Tensor or None, optional):
+ Indicates the valid length of each audio in the batch.
+ Shape: `(batch, )`.
+ When the ``waveforms`` contains audios with different durations,
+ by providing ``lengths`` argument, the model will compute
+ the corresponding valid output lengths and apply proper mask in
+ transformer attention layer.
+ If ``None``, it is assumed that the entire audio waveform
+ length is valid.
+ num_layers (int or None, optional):
+ If given, limit the number of intermediate layers to go through.
+ Providing `1` will stop the computation after going through one
+ intermediate layer. If not given, the outputs from all the
+ intermediate layers are returned.
+
+ Returns:
+ (List[Tensor], Optional[Tensor]):
+ List of Tensors
+ Features from requested layers.
+ Each Tensor is of shape: `(batch, time frame, feature dimension)`
+ Tensor or None
+ If ``lengths`` argument was provided, a Tensor of shape `(batch, )`
+ is returned.
+ It indicates the valid length in time axis of each feature Tensor.
+ """
+ x, lengths = self.feature_extractor(waveforms, lengths)
+ x = self.encoder.extract_features(x, lengths, num_layers)
+ return x, lengths
+
+ def forward(
+ self,
+ waveforms: Tensor,
+ lengths: Optional[Tensor] = None,
+ ) -> Tuple[Tensor, Optional[Tensor]]:
+ """Compute the sequence of probability distribution over labels.
+
+ Args:
+ waveforms (Tensor): Audio tensor of shape `(batch, frames)`.
+ lengths (Tensor or None, optional):
+ Indicates the valid length of each audio in the batch.
+ Shape: `(batch, )`.
+ When the ``waveforms`` contains audios with different durations,
+ by providing ``lengths`` argument, the model will compute
+ the corresponding valid output lengths and apply proper mask in
+ transformer attention layer.
+ If ``None``, it is assumed that all the audio in ``waveforms``
+ have valid length. Default: ``None``.
+
+ Returns:
+ (Tensor, Optional[Tensor]):
+ Tensor
+ The sequences of probability distribution (in logit) over labels.
+ Shape: `(batch, frames, num labels)`.
+ Tensor or None
+ If ``lengths`` argument was provided, a Tensor of shape `(batch, )`
+ is returned.
+ It indicates the valid length in time axis of the output Tensor.
+ """
+ x, lengths = self.feature_extractor(waveforms, lengths)
+ x = self.encoder(x, lengths)
+ if self.aux is not None:
+ x = self.aux(x)
+ return x, lengths
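+
+
+ # A minimal usage sketch (illustrative only, not part of the module), building a model with
+ # the ``wav2vec2_model`` factory defined later in this file and the Base-architecture values
+ # documented in ``components.py``:
+ def _wav2vec2_usage_sketch():
+     model = wav2vec2_model(
+         extractor_mode="group_norm",
+         extractor_conv_layer_config=None,
+         extractor_conv_bias=False,
+         encoder_embed_dim=768,
+         encoder_projection_dropout=0.1,
+         encoder_pos_conv_kernel=128,
+         encoder_pos_conv_groups=16,
+         encoder_num_layers=12,
+         encoder_num_heads=12,
+         encoder_attention_dropout=0.1,
+         encoder_ff_interm_features=3072,
+         encoder_ff_interm_dropout=0.1,
+         encoder_dropout=0.1,
+         encoder_layer_norm_first=False,
+         encoder_layer_drop=0.1,
+         aux_num_out=None,
+     ).eval()
+     waveforms = torch.randn(2, 16000)  # (batch, frames) of raw audio
+     lengths = torch.tensor([16000, 8000])
+     with torch.no_grad():
+         features, out_lengths = model(waveforms, lengths)
+         layer_outputs, _ = model.extract_features(waveforms, lengths, num_layers=4)
+     # With ``aux_num_out=None`` the forward output is the encoder features, shape (2, frame, 768);
+     # layer_outputs is a list of 4 Tensors from the first transformer layers.
+     return features, out_lengths, layer_outputs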
+
+
+class HuBERTPretrainModel(Module):
+ """HuBERTPretrainModel()
+
+ HuBERT model used for pretraining in *HuBERT* :cite:`hsu2021hubert`.
+
+ Note:
+ To build the model, please use one of the factory functions.
+
+ See Also:
+ `HuBERT Pre-training and Fine-tuning Recipes
+ `__
+
+ Args:
+ wav2vec2 (Wav2Vec2Model):
+ Wav2Vec2 encoder that generates the transformer outputs.
+
+ mask_generator (torch.nn.Module):
+ Mask generator that generates the mask for masked prediction during the training.
+
+ logit_generator (torch.nn.Module):
+ Logit generator that predicts the logits of the masked and unmasked inputs.
+
+ feature_grad_mult (float or None):
+ The factor to scale the convolutional feature extraction layer gradients by.
+ If ``None``, the gradients of feature extraction layers are not affected.
+ The scale factor will not affect the forward pass.
+ """
+
+ def __init__(
+ self,
+ wav2vec2: Wav2Vec2Model,
+ mask_generator: Module,
+ logit_generator: Module,
+ feature_grad_mult: Optional[float],
+ ):
+ super().__init__()
+ self.wav2vec2 = wav2vec2
+ self.mask_generator = mask_generator
+ self.logit_generator = logit_generator
+ if feature_grad_mult is not None and not 0.0 < feature_grad_mult < 1.0:
+ raise ValueError(
+ f"The value of `feature_grad_mult` must be ``None``or between (0, 1). Found {feature_grad_mult}"
+ )
+ self.feature_grad_mult = feature_grad_mult
+
+ def forward(
+ self,
+ waveforms: Tensor,
+ labels: Tensor,
+ audio_lengths: Optional[Tensor] = None,
+ ) -> Tuple[Tensor, Tensor, Tensor]:
+ """Compute the sequence of probability distribution over labels.
+
+ Args:
+ waveforms (Tensor): Audio tensor of dimension `[batch, frames]`.
+ labels (Tensor): Label for pre-training. A Tensor of dimension `[batch, frames]`.
+ audio_lengths (Tensor or None, optional):
+ Indicates the valid length of each audio in the batch.
+ Shape: `[batch, ]`.
+ When the ``waveforms`` contains audios with different durations,
+ by providing ``lengths`` argument, the model will compute
+ the corresponding valid output lengths and apply proper mask in
+ transformer attention layer.
+ If ``None``, it is assumed that all the audio in ``waveforms``
+ have valid length. Default: ``None``.
+
+ Returns:
+ (Tensor, Tensor, Tensor):
+ Tensor
+ The masked sequences of probability distribution (in logit).
+ Shape: `(masked_frames, num labels)`.
+ Tensor
+ The unmasked sequence of probability distribution (in logit).
+ Shape: `(unmasked_frames, num labels)`.
+ Tensor
+ The feature mean value for additional penalty loss.
+ Shape: `(1,)`.
+ """
+ x, lengths = self.wav2vec2.feature_extractor(waveforms, audio_lengths)
+ if self.feature_grad_mult is not None and self.feature_grad_mult < 1.0:
+ x = components.GradMultiply.apply(x, self.feature_grad_mult)
+ features_pen = x.float().pow(2).mean()
+ if lengths is not None:
+ padding_mask = components._get_padding_mask(x, lengths)
+ else:
+ padding_mask = None
+ x, attention_mask = self.wav2vec2.encoder._preprocess(x, lengths)
+ x, mask = self.mask_generator(x, padding_mask)
+ x = self.wav2vec2.encoder.transformer(x, attention_mask=attention_mask)
+ if x.shape[1] != labels.shape[1]:
+ raise ValueError("The length of label must match that of HuBERT model output")
+ if padding_mask is not None:
+ mask_m = torch.logical_and(~padding_mask, mask)
+ mask_u = torch.logical_and(~padding_mask, ~mask_m)
+ else:
+ mask_m = mask
+ mask_u = ~mask_m
+
+ logit_m, logit_u = self.logit_generator(x, labels, mask_m, mask_u)
+
+ return logit_m, logit_u, features_pen
+
+
+def wav2vec2_model(
+ extractor_mode: str,
+ extractor_conv_layer_config: Optional[List[Tuple[int, int, int]]],
+ extractor_conv_bias: bool,
+ encoder_embed_dim: int,
+ encoder_projection_dropout: float,
+ encoder_pos_conv_kernel: int,
+ encoder_pos_conv_groups: int,
+ encoder_num_layers: int,
+ encoder_num_heads: int,
+ encoder_attention_dropout: float,
+ encoder_ff_interm_features: int,
+ encoder_ff_interm_dropout: float,
+ encoder_dropout: float,
+ encoder_layer_norm_first: bool,
+ encoder_layer_drop: float,
+ aux_num_out: Optional[int],
+) -> Wav2Vec2Model:
+ """Builds custom :class:`~torchaudio.models.Wav2Vec2Model`.
+
+ Note:
+ The "feature extractor" below corresponds to
+ `ConvFeatureExtractionModel `__
+ in the original ``fairseq`` implementation.
+ This is referred as "(convolutional) feature encoder" in the *wav2vec 2.0*
+ :cite:`baevski2020wav2vec` paper.
+
+ The "encoder" below corresponds to `TransformerEncoder `__,
+ and this is referred as "Transformer" in the paper.
+
+ Args:
+ extractor_mode (str): Operation mode of feature extractor.
+ Valid values are ``"group_norm"`` or ``"layer_norm"``.
+ If ``"group_norm"``, then a single normalization is applied
+ in the first convolution block. Otherwise, all the convolution
+ blocks will have layer normalization.
+
+ This option corresponds to ``extractor_mode`` from ``fairseq``.
+ extractor_conv_layer_config (list of integer tuples or None):
+ Configuration of convolution layers in feature extractor.
+ List of convolution configuration,
+ i.e. ``[(output_channel, kernel_size, stride), ...]``
+
+ If ``None`` is provided, then the following default value is used.
+
+ .. code-block:: python
+
+ [
+ (512, 10, 5),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 2, 2),
+ (512, 2, 2),
+ ]
+
+ This option corresponds to ``conv_feature_layers`` from ``fairseq``.
+
+ extractor_conv_bias (bool):
+            Whether to include a bias term in each convolution operation.
+
+ This option corresponds to ``conv_bias`` from ``fairseq``.
+
+ encoder_embed_dim (int):
+ The dimension of embedding in encoder.
+
+ This option corresponds to ``encoder_embed_dim`` from ``fairseq``.
+
+ encoder_projection_dropout (float):
+ The dropout probability applied after the input feature is projected
+ to ``encoder_embed_dim``.
+
+ This option corresponds to ``dropout_input`` from ``fairseq``.
+
+ encoder_pos_conv_kernel (int):
+ The kernel size of convolutional positional embeddings.
+
+ This option corresponds to ``conv_pos`` from ``fairseq``.
+
+ encoder_pos_conv_groups (int):
+ The number of groups of convolutional positional embeddings.
+
+ This option corresponds to ``conv_pos_groups`` from ``fairseq``.
+
+ encoder_num_layers (int):
+ The number of self attention layers in transformer block.
+
+ This option corresponds to ``encoder_layers`` from ``fairseq``.
+
+ encoder_num_heads (int):
+ The number of heads in self attention layers.
+
+ This option corresponds to ``encoder_attention_heads`` from ``fairseq``.
+
+ encoder_attention_dropout (float):
+ The dropout probability applied after softmax in self-attention layer.
+
+ This option corresponds to ``attention_dropout`` from ``fairseq``.
+
+ encoder_ff_interm_features (int):
+ The dimension of hidden features in feed forward layer.
+
+ This option corresponds to ``encoder_ffn_embed_dim`` from ``fairseq``.
+
+ encoder_ff_interm_dropout (float):
+ The dropout probability applied in feedforward layer.
+
+            This option corresponds to ``activation_dropout`` from ``fairseq``.
+
+ encoder_dropout (float):
+ The dropout probability applied at the end of feed forward layer.
+
+ This option corresponds to ``dropout`` from ``fairseq``.
+
+ encoder_layer_norm_first (bool):
+ Control the order of layer norm in transformer layer and each encoder layer.
+ If True, in transformer layer, layer norm is applied before features are fed
+ to encoder layers. In encoder layer, two layer norms are applied before and after
+ self attention.
+ If False, in transformer layer, layer norm is applied after features are fed
+ to encoder layers. In encoder layer, two layer norms are applied after self
+ attention, before and after feed forward.
+
+ This option corresponds to ``layer_norm_first`` from ``fairseq``.
+
+ encoder_layer_drop (float):
+ Probability to drop each encoder layer during training.
+
+ This option corresponds to ``layerdrop`` from ``fairseq``.
+
+ aux_num_out (int or None):
+            When provided, attach an extra linear layer on top of the encoder, which can
+            be used for fine-tuning.
+
+ Returns:
+ Wav2Vec2Model:
+ The resulting model.
+ """ # noqa: E501
+ if extractor_conv_layer_config is None:
+ extractor_conv_layer_config = [(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2
+
+ feature_extractor = components._get_feature_extractor(
+ extractor_mode, extractor_conv_layer_config, extractor_conv_bias
+ )
+ encoder = components._get_encoder(
+ in_features=extractor_conv_layer_config[-1][0],
+ embed_dim=encoder_embed_dim,
+ dropout_input=encoder_projection_dropout,
+ pos_conv_kernel=encoder_pos_conv_kernel,
+ pos_conv_groups=encoder_pos_conv_groups,
+ num_layers=encoder_num_layers,
+ num_heads=encoder_num_heads,
+ attention_dropout=encoder_attention_dropout,
+ ff_interm_features=encoder_ff_interm_features,
+ ff_interm_dropout=encoder_ff_interm_dropout,
+ dropout=encoder_dropout,
+ layer_norm_first=encoder_layer_norm_first,
+ layer_drop=encoder_layer_drop,
+ )
+ aux = None
+ if aux_num_out is not None:
+ aux = torch.nn.Linear(in_features=encoder_embed_dim, out_features=aux_num_out)
+ return Wav2Vec2Model(feature_extractor, encoder, aux)
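+
+
+# Illustrative sketch (hypothetical helper, not part of the original source): building a
+# deliberately small custom model with an explicit feature-extractor configuration. Every
+# hyper-parameter value below is an arbitrary example, not a published configuration.
+def _example_custom_wav2vec2() -> Wav2Vec2Model:
+    return wav2vec2_model(
+        extractor_mode="group_norm",
+        extractor_conv_layer_config=[(256, 10, 5), (256, 3, 2), (256, 2, 2)],
+        extractor_conv_bias=False,
+        encoder_embed_dim=256,
+        encoder_projection_dropout=0.1,
+        encoder_pos_conv_kernel=128,
+        encoder_pos_conv_groups=16,
+        encoder_num_layers=4,
+        encoder_num_heads=4,
+        encoder_attention_dropout=0.1,
+        encoder_ff_interm_features=1024,
+        encoder_ff_interm_dropout=0.1,
+        encoder_dropout=0.1,
+        encoder_layer_norm_first=False,
+        encoder_layer_drop=0.0,
+        aux_num_out=None,
+    )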
+
+
+def wav2vec2_base(
+ encoder_projection_dropout: float = 0.1,
+ encoder_attention_dropout: float = 0.1,
+ encoder_ff_interm_dropout: float = 0.1,
+ encoder_dropout: float = 0.1,
+ encoder_layer_drop: float = 0.1,
+ aux_num_out: Optional[int] = None,
+) -> Wav2Vec2Model:
+ """Builds "base" :class:`~torchaudio.models.Wav2Vec2Model` from *wav2vec 2.0* :cite:`baevski2020wav2vec`
+
+ Args:
+ encoder_projection_dropout (float):
+ See :py:func:`wav2vec2_model`.
+ encoder_attention_dropout (float):
+ See :py:func:`wav2vec2_model`.
+ encoder_ff_interm_dropout (float):
+ See :py:func:`wav2vec2_model`.
+ encoder_dropout (float):
+ See :py:func:`wav2vec2_model`.
+ encoder_layer_drop (float):
+ See :py:func:`wav2vec2_model`.
+ aux_num_out (int or None, optional):
+ See :py:func:`wav2vec2_model`.
+
+ Returns:
+ Wav2Vec2Model:
+ The resulting model.
+ """ # noqa: E501
+ return wav2vec2_model(
+ extractor_mode="group_norm",
+ extractor_conv_layer_config=None,
+ extractor_conv_bias=False,
+ encoder_embed_dim=768,
+ encoder_projection_dropout=encoder_projection_dropout,
+ encoder_pos_conv_kernel=128,
+ encoder_pos_conv_groups=16,
+ encoder_num_layers=12,
+ encoder_num_heads=12,
+ encoder_attention_dropout=encoder_attention_dropout,
+ encoder_ff_interm_features=3072,
+ encoder_ff_interm_dropout=encoder_ff_interm_dropout,
+ encoder_dropout=encoder_dropout,
+ encoder_layer_norm_first=False,
+ encoder_layer_drop=encoder_layer_drop,
+ aux_num_out=aux_num_out,
+ )
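+
+
+# Illustrative usage sketch (hypothetical helper, not part of the original source):
+# extracting intermediate transformer features from a batch of dummy 16 kHz waveforms
+# with the "base" architecture. The waveforms and lengths are random placeholders.
+def _example_wav2vec2_base_features():
+    model = wav2vec2_base().eval()
+    waveforms = torch.randn(2, 16000)  # [batch, num_samples]
+    lengths = torch.tensor([16000, 12000])  # valid samples per waveform
+    with torch.no_grad():
+        features, out_lengths = model.extract_features(waveforms, lengths)
+    # ``features`` is a list with one tensor per transformer layer, each of shape
+    # [batch, frames, 768] for the base configuration.
+    return features, out_lengths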
+
+
+def wav2vec2_large(
+ encoder_projection_dropout: float = 0.1,
+ encoder_attention_dropout: float = 0.1,
+ encoder_ff_interm_dropout: float = 0.1,
+ encoder_dropout: float = 0.1,
+ encoder_layer_drop: float = 0.1,
+ aux_num_out: Optional[int] = None,
+) -> Wav2Vec2Model:
+ """Builds "large" :class:`~torchaudio.models.Wav2Vec2Model` from *wav2vec 2.0* :cite:`baevski2020wav2vec`
+
+ Args:
+ encoder_projection_dropout (float):
+ See :py:func:`wav2vec2_model`.
+ encoder_attention_dropout (float):
+ See :py:func:`wav2vec2_model`.
+ encoder_ff_interm_dropout (float):
+ See :py:func:`wav2vec2_model`.
+ encoder_dropout (float):
+ See :py:func:`wav2vec2_model`.
+ encoder_layer_drop (float):
+ See :py:func:`wav2vec2_model`.
+ aux_num_out (int or None, optional):
+ See :py:func:`wav2vec2_model`.
+
+ Returns:
+ Wav2Vec2Model:
+ The resulting model.
+ """ # noqa: E501
+ return wav2vec2_model(
+ extractor_mode="group_norm",
+ extractor_conv_layer_config=None,
+ extractor_conv_bias=False,
+ encoder_embed_dim=1024,
+ encoder_projection_dropout=encoder_projection_dropout,
+ encoder_pos_conv_kernel=128,
+ encoder_pos_conv_groups=16,
+ encoder_num_layers=24,
+ encoder_num_heads=16,
+ encoder_attention_dropout=encoder_attention_dropout,
+ encoder_ff_interm_features=4096,
+ encoder_ff_interm_dropout=encoder_ff_interm_dropout,
+ encoder_dropout=encoder_dropout,
+ encoder_layer_norm_first=False,
+ encoder_layer_drop=encoder_layer_drop,
+ aux_num_out=aux_num_out,
+ )
+
+
+def wav2vec2_large_lv60k(
+ encoder_projection_dropout: float = 0.1,
+ encoder_attention_dropout: float = 0.0,
+ encoder_ff_interm_dropout: float = 0.1,
+ encoder_dropout: float = 0.0,
+ encoder_layer_drop: float = 0.1,
+ aux_num_out: Optional[int] = None,
+) -> Wav2Vec2Model:
+ """Builds "large lv-60k" :class:`~torchaudio.models.Wav2Vec2Model` from *wav2vec 2.0* :cite:`baevski2020wav2vec`
+
+ Args:
+ encoder_projection_dropout (float):
+ See :py:func:`wav2vec2_model`.
+ encoder_attention_dropout (float):
+ See :py:func:`wav2vec2_model`.
+ encoder_ff_interm_dropout (float):
+ See :py:func:`wav2vec2_model`.
+ encoder_dropout (float):
+ See :py:func:`wav2vec2_model`.
+ encoder_layer_drop (float):
+ See :py:func:`wav2vec2_model`.
+ aux_num_out (int or None, optional):
+ See :py:func:`wav2vec2_model`.
+
+ Returns:
+ Wav2Vec2Model:
+ The resulting model.
+ """ # noqa: E501
+ return wav2vec2_model(
+ extractor_mode="layer_norm",
+ extractor_conv_layer_config=None,
+ extractor_conv_bias=True,
+ encoder_embed_dim=1024,
+ encoder_projection_dropout=encoder_projection_dropout,
+ encoder_pos_conv_kernel=128,
+ encoder_pos_conv_groups=16,
+ encoder_num_layers=24,
+ encoder_num_heads=16,
+ encoder_attention_dropout=encoder_attention_dropout,
+ encoder_ff_interm_features=4096,
+ encoder_ff_interm_dropout=encoder_ff_interm_dropout,
+ encoder_dropout=encoder_dropout,
+ encoder_layer_norm_first=True,
+ encoder_layer_drop=encoder_layer_drop,
+ aux_num_out=aux_num_out,
+ )
+
+
+def hubert_base(
+ encoder_projection_dropout: float = 0.1,
+ encoder_attention_dropout: float = 0.1,
+ encoder_ff_interm_dropout: float = 0.0,
+ encoder_dropout: float = 0.1,
+ encoder_layer_drop: float = 0.05,
+ aux_num_out: Optional[int] = None,
+) -> Wav2Vec2Model:
+ """Builds "base" :class:`HuBERT ` from *HuBERT* :cite:`hsu2021hubert`
+
+ Args:
+ encoder_projection_dropout (float):
+ See :py:func:`wav2vec2_model`.
+ encoder_attention_dropout (float):
+ See :py:func:`wav2vec2_model`.
+ encoder_ff_interm_dropout (float):
+ See :py:func:`wav2vec2_model`.
+ encoder_dropout (float):
+ See :py:func:`wav2vec2_model`.
+ encoder_layer_drop (float):
+ See :py:func:`wav2vec2_model`.
+ aux_num_out (int or None, optional):
+ See :py:func:`wav2vec2_model`.
+
+ Returns:
+ Wav2Vec2Model:
+ The resulting model.
+ """ # noqa: E501
+ return wav2vec2_model(
+ extractor_mode="group_norm",
+ extractor_conv_layer_config=None,
+ extractor_conv_bias=False,
+ encoder_embed_dim=768,
+ encoder_projection_dropout=encoder_projection_dropout,
+ encoder_pos_conv_kernel=128,
+ encoder_pos_conv_groups=16,
+ encoder_num_layers=12,
+ encoder_num_heads=12,
+ encoder_attention_dropout=encoder_attention_dropout,
+ encoder_ff_interm_features=3072,
+ encoder_ff_interm_dropout=encoder_ff_interm_dropout,
+ encoder_dropout=encoder_dropout,
+ encoder_layer_norm_first=False,
+ encoder_layer_drop=encoder_layer_drop,
+ aux_num_out=aux_num_out,
+ )
+
+
+def hubert_large(
+ encoder_projection_dropout: float = 0.0,
+ encoder_attention_dropout: float = 0.0,
+ encoder_ff_interm_dropout: float = 0.0,
+ encoder_dropout: float = 0.0,
+ encoder_layer_drop: float = 0.0,
+ aux_num_out: Optional[int] = None,
+) -> Wav2Vec2Model:
+ """Builds "large" :class:`HuBERT ` from *HuBERT* :cite:`hsu2021hubert`
+
+ Args:
+ encoder_projection_dropout (float):
+ See :py:func:`wav2vec2_model`.
+ encoder_attention_dropout (float):
+ See :py:func:`wav2vec2_model`.
+ encoder_ff_interm_dropout (float):
+ See :py:func:`wav2vec2_model`.
+ encoder_dropout (float):
+ See :py:func:`wav2vec2_model`.
+ encoder_layer_drop (float):
+ See :py:func:`wav2vec2_model`.
+ aux_num_out (int or None, optional):
+ See :py:func:`wav2vec2_model`.
+
+ Returns:
+ Wav2Vec2Model:
+ The resulting model.
+ """ # noqa: E501
+ return wav2vec2_model(
+ extractor_mode="layer_norm",
+ extractor_conv_layer_config=None,
+ extractor_conv_bias=False,
+ encoder_embed_dim=1024,
+ encoder_projection_dropout=encoder_projection_dropout,
+ encoder_pos_conv_kernel=128,
+ encoder_pos_conv_groups=16,
+ encoder_num_layers=24,
+ encoder_num_heads=16,
+ encoder_attention_dropout=encoder_attention_dropout,
+ encoder_ff_interm_features=4096,
+ encoder_ff_interm_dropout=encoder_ff_interm_dropout,
+ encoder_dropout=encoder_dropout,
+ encoder_layer_norm_first=True,
+ encoder_layer_drop=encoder_layer_drop,
+ aux_num_out=aux_num_out,
+ )
+
+
+def hubert_xlarge(
+ encoder_projection_dropout: float = 0.0,
+ encoder_attention_dropout: float = 0.0,
+ encoder_ff_interm_dropout: float = 0.0,
+ encoder_dropout: float = 0.0,
+ encoder_layer_drop: float = 0.0,
+ aux_num_out: Optional[int] = None,
+) -> Wav2Vec2Model:
+ """Builds "extra large" :class:`HuBERT ` from *HuBERT* :cite:`hsu2021hubert`
+
+ Args:
+ encoder_projection_dropout (float):
+ See :py:func:`wav2vec2_model`.
+ encoder_attention_dropout (float):
+ See :py:func:`wav2vec2_model`.
+ encoder_ff_interm_dropout (float):
+ See :py:func:`wav2vec2_model`.
+ encoder_dropout (float):
+ See :py:func:`wav2vec2_model`.
+ encoder_layer_drop (float):
+ See :py:func:`wav2vec2_model`.
+ aux_num_out (int or None, optional):
+ See :py:func:`wav2vec2_model`.
+
+ Returns:
+ Wav2Vec2Model:
+ The resulting model.
+ """ # noqa: E501
+ return wav2vec2_model(
+ extractor_mode="layer_norm",
+ extractor_conv_layer_config=None,
+ extractor_conv_bias=False,
+ encoder_embed_dim=1280,
+ encoder_projection_dropout=encoder_projection_dropout,
+ encoder_pos_conv_kernel=128,
+ encoder_pos_conv_groups=16,
+ encoder_num_layers=48,
+ encoder_num_heads=16,
+ encoder_attention_dropout=encoder_attention_dropout,
+ encoder_ff_interm_features=5120,
+ encoder_ff_interm_dropout=encoder_ff_interm_dropout,
+ encoder_dropout=encoder_dropout,
+ encoder_layer_norm_first=True,
+ encoder_layer_drop=encoder_layer_drop,
+ aux_num_out=aux_num_out,
+ )
+
+
+def _init_hubert_pretrain_model(module):
+ if isinstance(module, components.ConvLayerBlock):
+ torch.nn.init.kaiming_normal_(module.conv.weight)
+ elif isinstance(module, components.ConvolutionalPositionalEmbedding):
+        # initialize the weight from a normal distribution.
+ std = math.sqrt(4.0 / (module.embed_dim * module.kernel_size))
+ torch.nn.init.normal_(module.conv.weight, mean=0.0, std=std)
+ torch.nn.init.constant_(module.conv.bias, 0.0)
+ elif isinstance(module, components.SelfAttention):
+        # initialize the query, key, value, and out_proj parameters of the self-attention module.
+ torch.nn.init.xavier_uniform_(module.k_proj.weight, gain=1 / math.sqrt(2))
+ torch.nn.init.xavier_uniform_(module.v_proj.weight, gain=1 / math.sqrt(2))
+ torch.nn.init.xavier_uniform_(module.q_proj.weight, gain=1 / math.sqrt(2))
+ torch.nn.init.xavier_uniform_(module.out_proj.weight)
+ torch.nn.init.constant_(module.out_proj.bias, 0.0)
+ elif isinstance(module, components.Transformer):
+ module.apply(components._init_transformer_params)
+ else:
+ pass
+
+
+def hubert_pretrain_model(
+ extractor_mode: str,
+ extractor_conv_layer_config: Optional[List[Tuple[int, int, int]]],
+ extractor_conv_bias: bool,
+ encoder_embed_dim: int,
+ encoder_projection_dropout: float,
+ encoder_pos_conv_kernel: int,
+ encoder_pos_conv_groups: int,
+ encoder_num_layers: int,
+ encoder_num_heads: int,
+ encoder_attention_dropout: float,
+ encoder_ff_interm_features: int,
+ encoder_ff_interm_dropout: float,
+ encoder_dropout: float,
+ encoder_layer_norm_first: bool,
+ encoder_layer_drop: float,
+ mask_prob: float,
+ mask_selection: str,
+ mask_other: float,
+ mask_length: int,
+ no_mask_overlap: bool,
+ mask_min_space: int,
+ mask_channel_prob: float,
+ mask_channel_selection: str,
+ mask_channel_other: float,
+ mask_channel_length: int,
+ no_mask_channel_overlap: bool,
+ mask_channel_min_space: int,
+ skip_masked: bool,
+ skip_nomask: bool,
+ num_classes: int,
+ final_dim: int,
+ feature_grad_mult: Optional[float],
+) -> HuBERTPretrainModel:
+ """Builds custom :class:`HuBERTPretrainModel` for training from scratch
+
+ Note:
+ The "feature extractor" below corresponds to
+ `ConvFeatureExtractionModel `__
+ in the original ``fairseq`` implementation.
+        This is referred to as "(convolutional) feature encoder" in the *wav2vec 2.0*
+ :cite:`baevski2020wav2vec` paper.
+
+ The "encoder" below corresponds to `TransformerEncoder `__,
+        and this is referred to as "Transformer" in the paper.
+
+ Args:
+ extractor_mode (str): Operation mode of feature extractor.
+ Valid values are ``"group_norm"`` or ``"layer_norm"``.
+ If ``"group_norm"``, then a single normalization is applied
+ in the first convolution block. Otherwise, all the convolution
+ blocks will have layer normalization.
+
+ This option corresponds to ``extractor_mode`` from ``fairseq``.
+
+ extractor_conv_layer_config (list of integer tuples or None):
+ Configuration of convolution layers in feature extractor.
+ List of convolution configuration,
+ i.e. ``[(output_channel, kernel_size, stride), ...]``
+
+ If ``None`` is provided, then the following default value is used.
+
+ .. code-block:: python
+
+ [
+ (512, 10, 5),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 2, 2),
+ (512, 2, 2),
+ ]
+
+ This option corresponds to ``conv_feature_layers`` from ``fairseq``.
+
+ extractor_conv_bias (bool):
+            Whether to include a bias term in each convolution operation.
+
+ This option corresponds to ``conv_bias`` from ``fairseq``.
+
+ encoder_embed_dim (int):
+ The dimension of embedding in encoder.
+
+ This option corresponds to ``encoder_embed_dim`` from ``fairseq``.
+
+ encoder_projection_dropout (float):
+ The dropout probability applied after the input feature is projected
+ to ``encoder_embed_dim``.
+
+ This option corresponds to ``dropout_input`` from ``fairseq``.
+
+ encoder_pos_conv_kernel (int):
+ The kernel size of convolutional positional embeddings.
+
+ This option corresponds to ``conv_pos`` from ``fairseq``.
+
+ encoder_pos_conv_groups (int):
+ The number of groups of convolutional positional embeddings.
+
+ This option corresponds to ``conv_pos_groups`` from ``fairseq``.
+
+ encoder_num_layers (int):
+ The number of self attention layers in transformer block.
+
+ This option corresponds to ``encoder_layers`` from ``fairseq``.
+
+ encoder_num_heads (int):
+ The number of heads in self attention layers.
+
+ This option corresponds to ``encoder_attention_heads`` from ``fairseq``.
+
+ encoder_attention_dropout (float):
+ The dropout probability applied after softmax in self-attention layer.
+
+ This option corresponds to ``attention_dropout`` from ``fairseq``.
+
+ encoder_ff_interm_features (int):
+ The dimension of hidden features in feed forward layer.
+
+ This option corresponds to ``encoder_ffn_embed_dim`` from ``fairseq``.
+
+ encoder_ff_interm_dropout (float):
+ The dropout probability applied in feedforward layer.
+
+            This option corresponds to ``activation_dropout`` from ``fairseq``.
+
+ encoder_dropout (float):
+ The dropout probability applied at the end of feed forward layer.
+
+ This option corresponds to ``dropout`` from ``fairseq``.
+
+ encoder_layer_norm_first (bool):
+ Control the order of layer norm in transformer layer and each encoder layer.
+ If True, in transformer layer, layer norm is applied before features are fed
+ to encoder layers. In encoder layer, two layer norms are applied before and after
+ self attention.
+ If False, in transformer layer, layer norm is applied after features are fed
+ to encoder layers. In encoder layer, two layer norms are applied after self
+ attention, before and after feed forward.
+
+ This option corresponds to ``layer_norm_first`` from ``fairseq``.
+
+ encoder_layer_drop (float):
+ Probability to drop each encoder layer during training.
+
+ This option corresponds to ``layerdrop`` from ``fairseq``.
+
+ mask_prob (float):
+            Probability for each token to be chosen as the start of a span to be masked. This value is
+            multiplied by the number of timesteps divided by the length of the mask span, so that roughly
+            this proportion of all elements is masked. However, due to overlaps, the actual number will be
+            smaller (unless ``no_mask_overlap`` is ``True``).
+
+ This option corresponds to ``mask_prob`` from ``fairseq``.
+
+ mask_selection (str):
+ How to choose the mask length. Options: [``static``, ``uniform``, ``normal``, ``poisson``].
+
+ This option corresponds to ``mask_selection`` from ``fairseq``.
+
+ mask_other (float):
+ Secondary mask argument (used for more complex distributions).
+
+ This option corresponds to ``mask_other`` from ``fairseq``.
+
+ mask_length (int):
+ The lengths of the mask.
+
+ This option corresponds to ``mask_length`` from ``fairseq``.
+
+ no_mask_overlap (bool):
+ Whether to allow masks to overlap.
+
+ This option corresponds to ``no_mask_overlap`` from ``fairseq``.
+
+ mask_min_space (int):
+ Minimum space between spans (if no overlap is enabled).
+
+ This option corresponds to ``mask_min_space`` from ``fairseq``.
+
+        mask_channel_prob (float):
+ The probability of replacing a feature with 0.
+
+ This option corresponds to ``mask_channel_prob`` from ``fairseq``.
+
+ mask_channel_selection (str):
+ How to choose the mask length for channel masking. Options: [``static``, ``uniform``, ``normal``, ``poisson``].
+
+ This option corresponds to ``mask_channel_selection`` from ``fairseq``.
+
+ mask_channel_other (float):
+            Secondary mask argument for channel masking (used for more complex distributions).
+
+ This option corresponds to ``mask_channel_other`` from ``fairseq``.
+
+ mask_channel_length (int):
+            The length of each channel mask span.
+
+ This option corresponds to ``mask_channel_length`` from ``fairseq``.
+
+ no_mask_channel_overlap (bool):
+ Whether to allow channel masks to overlap.
+
+ This option corresponds to ``no_mask_channel_overlap`` from ``fairseq``.
+
+ mask_channel_min_space (int):
+            Minimum space between spans for channel masking (if no overlap is enabled).
+
+ This option corresponds to ``mask_channel_min_space`` from ``fairseq``.
+
+ skip_masked (bool):
+ If True, skip computing losses over masked frames.
+
+ This option corresponds to ``skip_masked`` from ``fairseq``.
+
+ skip_nomask (bool):
+ If True, skip computing losses over unmasked frames.
+
+ This option corresponds to ``skip_nomask`` from ``fairseq``.
+
+ num_classes (int):
+ The number of classes in the labels.
+
+ final_dim (int):
+ Project final representations and targets to `final_dim`.
+
+ This option corresponds to ``final_dim`` from ``fairseq``.
+
+ feature_grad_mult (float or None):
+ The factor to scale the convolutional feature extraction layer gradients by.
+ The scale factor will not affect the forward pass.
+
+ This option corresponds to ``feature_grad_mult`` from ``fairseq``.
+
+ Returns:
+ HuBERTPretrainModel:
+ The resulting model.
+ """ # noqa: E501
+ if extractor_conv_layer_config is None:
+ extractor_conv_layer_config = [(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2
+
+ feature_extractor = components._get_feature_extractor(
+ extractor_mode, extractor_conv_layer_config, extractor_conv_bias
+ )
+ encoder = components._get_encoder(
+ in_features=extractor_conv_layer_config[-1][0],
+ embed_dim=encoder_embed_dim,
+ dropout_input=encoder_projection_dropout,
+ pos_conv_kernel=encoder_pos_conv_kernel,
+ pos_conv_groups=encoder_pos_conv_groups,
+ num_layers=encoder_num_layers,
+ num_heads=encoder_num_heads,
+ attention_dropout=encoder_attention_dropout,
+ ff_interm_features=encoder_ff_interm_features,
+ ff_interm_dropout=encoder_ff_interm_dropout,
+ dropout=encoder_dropout,
+ layer_norm_first=encoder_layer_norm_first,
+ layer_drop=encoder_layer_drop,
+ )
+ wav2vec2 = Wav2Vec2Model(feature_extractor, encoder)
+ mask_generator = components.MaskGenerator(
+ encoder_embed_dim,
+ mask_prob,
+ mask_selection,
+ mask_other,
+ mask_length,
+ no_mask_overlap,
+ mask_min_space,
+ mask_channel_prob,
+ mask_channel_selection,
+ mask_channel_other,
+ mask_channel_length,
+ no_mask_channel_overlap,
+ mask_channel_min_space,
+ )
+ logit_generator = components.LogitGenerator(
+ encoder_embed_dim,
+ num_classes,
+ final_dim,
+ skip_masked,
+ skip_nomask,
+ )
+ model = HuBERTPretrainModel(
+ wav2vec2=wav2vec2,
+ mask_generator=mask_generator,
+ logit_generator=logit_generator,
+ feature_grad_mult=feature_grad_mult,
+ )
+ # initialize the model for pre-training
+ model.apply(_init_hubert_pretrain_model)
+ return model
+
+
+def hubert_pretrain_base(
+ encoder_projection_dropout: float = 0.1,
+ encoder_attention_dropout: float = 0.1,
+ encoder_ff_interm_dropout: float = 0.0,
+ encoder_dropout: float = 0.1,
+ encoder_layer_drop: float = 0.05,
+ mask_prob: float = 0.8,
+ mask_channel_prob: float = 0.0,
+ mask_channel_length: int = 10,
+ feature_grad_mult: Optional[float] = 0.1,
+ num_classes: int = 100,
+) -> HuBERTPretrainModel:
+ """Builds "base" :class:`HuBERTPretrainModel` from *HuBERT* :cite:`hsu2021hubert` for pretraining.
+
+ Args:
+ encoder_projection_dropout (float):
+ See :py:func:`hubert_pretrain_model`.
+ encoder_attention_dropout (float):
+ See :py:func:`hubert_pretrain_model`.
+ encoder_ff_interm_dropout (float):
+ See :py:func:`hubert_pretrain_model`.
+ encoder_dropout (float):
+ See :py:func:`hubert_pretrain_model`.
+ encoder_layer_drop (float):
+ See :py:func:`hubert_pretrain_model`.
+ mask_prob (float):
+ See :py:func:`hubert_pretrain_model`.
+ mask_channel_prob (float):
+ See :py:func:`hubert_pretrain_model`.
+ mask_channel_length (int):
+ See :py:func:`hubert_pretrain_model`.
+ feature_grad_mult (float or None):
+ See :py:func:`hubert_pretrain_model`.
+ num_classes (int, optional):
+ See :py:func:`hubert_pretrain_model`.
+
+ Returns:
+ HuBERTPretrainModel:
+ The resulting model.
+ """ # noqa: E501
+ return hubert_pretrain_model(
+ extractor_mode="group_norm",
+ extractor_conv_layer_config=None,
+ extractor_conv_bias=False,
+ encoder_embed_dim=768,
+ encoder_projection_dropout=encoder_projection_dropout,
+ encoder_pos_conv_kernel=128,
+ encoder_pos_conv_groups=16,
+ encoder_num_layers=12,
+ encoder_num_heads=12,
+ encoder_attention_dropout=encoder_attention_dropout,
+ encoder_ff_interm_features=3072,
+ encoder_ff_interm_dropout=encoder_ff_interm_dropout,
+ encoder_dropout=encoder_dropout,
+ encoder_layer_norm_first=False,
+ encoder_layer_drop=encoder_layer_drop,
+ mask_prob=mask_prob,
+ mask_selection="static",
+ mask_other=0.0,
+ mask_length=10,
+ no_mask_overlap=False,
+ mask_min_space=1,
+ mask_channel_prob=mask_channel_prob,
+ mask_channel_selection="static",
+ mask_channel_other=0.0,
+ mask_channel_length=mask_channel_length,
+ no_mask_channel_overlap=False,
+ mask_channel_min_space=1,
+ skip_masked=False,
+ skip_nomask=False,
+ num_classes=num_classes,
+ final_dim=256,
+ feature_grad_mult=feature_grad_mult,
+ )
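+
+
+# Illustrative sketch (hypothetical helper, not part of the original source): a dummy
+# pretraining forward pass with the "base" configuration. Waveforms and frame-level labels
+# are random placeholders; the number of labels per utterance is derived from the feature
+# extractor so that it matches the model output length, as required by
+# ``HuBERTPretrainModel.forward``.
+def _example_hubert_pretrain_forward():
+    model = hubert_pretrain_base(num_classes=100)
+    waveforms = torch.randn(2, 16000)
+    lengths = torch.tensor([16000, 16000])
+    with torch.no_grad():
+        frames, _ = model.wav2vec2.feature_extractor(waveforms, lengths)
+    labels = torch.randint(0, 100, (2, frames.shape[1]))
+    logit_m, logit_u, features_pen = model(waveforms, labels, lengths)
+    return logit_m, logit_u, features_pen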
+
+
+def hubert_pretrain_large(
+ encoder_projection_dropout: float = 0.0,
+ encoder_attention_dropout: float = 0.0,
+ encoder_ff_interm_dropout: float = 0.0,
+ encoder_dropout: float = 0.0,
+ encoder_layer_drop: float = 0.0,
+ mask_prob: float = 0.8,
+ mask_channel_prob: float = 0.0,
+ mask_channel_length: int = 10,
+ feature_grad_mult: Optional[float] = None,
+) -> HuBERTPretrainModel:
+ """Builds "large" :class:`HuBERTPretrainModel` from *HuBERT* :cite:`hsu2021hubert` for pretraining.
+
+ Args:
+ encoder_projection_dropout (float):
+ See :py:func:`hubert_pretrain_model`.
+ encoder_attention_dropout (float):
+ See :py:func:`hubert_pretrain_model`.
+ encoder_ff_interm_dropout (float):
+ See :py:func:`hubert_pretrain_model`.
+ encoder_dropout (float):
+ See :py:func:`hubert_pretrain_model`.
+ encoder_layer_drop (float):
+ See :py:func:`hubert_pretrain_model`.
+ mask_prob (float):
+ See :py:func:`hubert_pretrain_model`.
+ mask_channel_prob (float):
+ See :py:func:`hubert_pretrain_model`.
+ mask_channel_length (int):
+ See :py:func:`hubert_pretrain_model`.
+ feature_grad_mult (float or None):
+ See :py:func:`hubert_pretrain_model`.
+
+ Returns:
+ HuBERTPretrainModel:
+ The resulting model.
+ """ # noqa: E501
+ return hubert_pretrain_model(
+ extractor_mode="layer_norm",
+ extractor_conv_layer_config=None,
+ extractor_conv_bias=False,
+ encoder_embed_dim=1024,
+ encoder_projection_dropout=encoder_projection_dropout,
+ encoder_pos_conv_kernel=128,
+ encoder_pos_conv_groups=16,
+ encoder_num_layers=24,
+ encoder_num_heads=16,
+ encoder_attention_dropout=encoder_attention_dropout,
+ encoder_ff_interm_features=4096,
+ encoder_ff_interm_dropout=encoder_ff_interm_dropout,
+ encoder_dropout=encoder_dropout,
+ encoder_layer_norm_first=True,
+ encoder_layer_drop=encoder_layer_drop,
+ mask_prob=mask_prob,
+ mask_selection="static",
+ mask_other=0.0,
+ mask_length=10,
+ no_mask_overlap=False,
+ mask_min_space=1,
+ mask_channel_prob=mask_channel_prob,
+ mask_channel_selection="static",
+ mask_channel_other=0.0,
+ mask_channel_length=mask_channel_length,
+ no_mask_channel_overlap=False,
+ mask_channel_min_space=1,
+ skip_masked=False,
+ skip_nomask=False,
+ num_classes=500,
+ final_dim=768,
+ feature_grad_mult=feature_grad_mult,
+ )
+
+
+def hubert_pretrain_xlarge(
+ encoder_projection_dropout: float = 0.0,
+ encoder_attention_dropout: float = 0.0,
+ encoder_ff_interm_dropout: float = 0.0,
+ encoder_dropout: float = 0.0,
+ encoder_layer_drop: float = 0.0,
+ mask_prob: float = 0.8,
+ mask_channel_prob: float = 0.0,
+ mask_channel_length: int = 10,
+ feature_grad_mult: Optional[float] = None,
+) -> HuBERTPretrainModel:
+ """Builds "extra large" :class:`HuBERTPretrainModel` from *HuBERT* :cite:`hsu2021hubert` for pretraining.
+
+ Args:
+ encoder_projection_dropout (float):
+ See :py:func:`hubert_pretrain_model`.
+ encoder_attention_dropout (float):
+ See :py:func:`hubert_pretrain_model`.
+ encoder_ff_interm_dropout (float):
+ See :py:func:`hubert_pretrain_model`.
+ encoder_dropout (float):
+ See :py:func:`hubert_pretrain_model`.
+ encoder_layer_drop (float):
+ See :py:func:`hubert_pretrain_model`.
+ mask_prob (float):
+ See :py:func:`hubert_pretrain_model`.
+ mask_channel_prob (float):
+ See :py:func:`hubert_pretrain_model`.
+ mask_channel_length (int):
+ See :py:func:`hubert_pretrain_model`.
+ feature_grad_mult (float or None):
+ See :py:func:`hubert_pretrain_model`.
+
+ Returns:
+ HuBERTPretrainModel:
+ The resulting model.
+ """ # noqa: E501
+ return hubert_pretrain_model(
+ extractor_mode="layer_norm",
+ extractor_conv_layer_config=None,
+ extractor_conv_bias=False,
+ encoder_embed_dim=1280,
+ encoder_projection_dropout=encoder_projection_dropout,
+ encoder_pos_conv_kernel=128,
+ encoder_pos_conv_groups=16,
+ encoder_num_layers=48,
+ encoder_num_heads=16,
+ encoder_attention_dropout=encoder_attention_dropout,
+ encoder_ff_interm_features=5120,
+ encoder_ff_interm_dropout=encoder_ff_interm_dropout,
+ encoder_dropout=encoder_dropout,
+ encoder_layer_norm_first=True,
+ encoder_layer_drop=encoder_layer_drop,
+ mask_prob=mask_prob,
+ mask_selection="static",
+ mask_other=0.0,
+ mask_length=10,
+ no_mask_overlap=False,
+ mask_min_space=1,
+ mask_channel_prob=mask_channel_prob,
+ mask_channel_selection="static",
+ mask_channel_other=0.0,
+ mask_channel_length=mask_channel_length,
+ no_mask_channel_overlap=False,
+ mask_channel_min_space=1,
+ skip_masked=False,
+ skip_nomask=False,
+ num_classes=500,
+ final_dim=1024,
+ feature_grad_mult=feature_grad_mult,
+ )
+
+
+def wavlm_model(
+ extractor_mode: str,
+ extractor_conv_layer_config: Optional[List[Tuple[int, int, int]]],
+ extractor_conv_bias: bool,
+ encoder_embed_dim: int,
+ encoder_projection_dropout: float,
+ encoder_pos_conv_kernel: int,
+ encoder_pos_conv_groups: int,
+ encoder_num_layers: int,
+ encoder_num_heads: int,
+ encoder_num_buckets: int,
+ encoder_max_distance: int,
+ encoder_attention_dropout: float,
+ encoder_ff_interm_features: int,
+ encoder_ff_interm_dropout: float,
+ encoder_dropout: float,
+ encoder_layer_norm_first: bool,
+ encoder_layer_drop: float,
+ aux_num_out: Optional[int],
+) -> Wav2Vec2Model:
+ """Builds custom WaveLM model :cite:`chen2022wavlm`. The architecture is compatible
+ with Wav2Vec2 model :cite:`baevski2020wav2vec`, and so the output object is
+ :class:`~torchaudio.models.Wav2Vec2Model`. Most of the arguments have the same meaning
+ as in :py:func:`~torchaudio.models.wav2vec2_model` so please refer there for documentation.
+
+ Args:
+ extractor_mode (str): Operation mode of feature extractor.
+ See :py:func:`~torchaudio.models.wav2vec2_model`.
+
+ extractor_conv_layer_config (list of integer tuples or None):
+ See :py:func:`~torchaudio.models.wav2vec2_model`.
+
+ extractor_conv_bias (bool):
+ See :py:func:`~torchaudio.models.wav2vec2_model`.
+
+ encoder_embed_dim (int):
+ See :py:func:`~torchaudio.models.wav2vec2_model`.
+
+ encoder_projection_dropout (float):
+ See :py:func:`~torchaudio.models.wav2vec2_model`.
+
+ encoder_pos_conv_kernel (int):
+ See :py:func:`~torchaudio.models.wav2vec2_model`.
+
+ encoder_pos_conv_groups (int):
+ See :py:func:`~torchaudio.models.wav2vec2_model`.
+
+ encoder_num_layers (int):
+ See :py:func:`~torchaudio.models.wav2vec2_model`.
+
+ encoder_num_heads (int):
+ See :py:func:`~torchaudio.models.wav2vec2_model`.
+
+        encoder_num_buckets (int):
+            Number of buckets for relative position embedding.
+
+        encoder_max_distance (int):
+            Maximum distance for relative position embedding.
+
+ encoder_attention_dropout (float):
+ See :py:func:`~torchaudio.models.wav2vec2_model`.
+
+ encoder_ff_interm_features (int):
+ See :py:func:`~torchaudio.models.wav2vec2_model`.
+
+ encoder_ff_interm_dropout (float):
+ See :py:func:`~torchaudio.models.wav2vec2_model`.
+
+ encoder_dropout (float):
+ See :py:func:`~torchaudio.models.wav2vec2_model`.
+
+ encoder_layer_norm_first (bool):
+ See :py:func:`~torchaudio.models.wav2vec2_model`.
+
+ encoder_layer_drop (float):
+ See :py:func:`~torchaudio.models.wav2vec2_model`.
+
+ aux_num_out (int or None):
+ See :py:func:`~torchaudio.models.wav2vec2_model`.
+
+ Returns:
+ Wav2Vec2Model:
+ The resulting model.
+ """
+ if extractor_conv_layer_config is None:
+ extractor_conv_layer_config = [(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2
+
+ feature_extractor = components._get_feature_extractor(
+ extractor_mode, extractor_conv_layer_config, extractor_conv_bias
+ )
+ encoder = components._get_wavlm_encoder(
+ in_features=extractor_conv_layer_config[-1][0],
+ embed_dim=encoder_embed_dim,
+ dropout_input=encoder_projection_dropout,
+ pos_conv_kernel=encoder_pos_conv_kernel,
+ pos_conv_groups=encoder_pos_conv_groups,
+ num_layers=encoder_num_layers,
+ num_heads=encoder_num_heads,
+ num_buckets=encoder_num_buckets,
+ max_distance=encoder_max_distance,
+ attention_dropout=encoder_attention_dropout,
+ ff_interm_features=encoder_ff_interm_features,
+ ff_interm_dropout=encoder_ff_interm_dropout,
+ dropout=encoder_dropout,
+ layer_norm_first=encoder_layer_norm_first,
+ layer_drop=encoder_layer_drop,
+ )
+ aux = None
+ if aux_num_out is not None:
+ aux = torch.nn.Linear(in_features=encoder_embed_dim, out_features=aux_num_out)
+ return Wav2Vec2Model(feature_extractor, encoder, aux)
+
+
+def wavlm_base(
+ encoder_projection_dropout: float = 0.1,
+ encoder_attention_dropout: float = 0.1,
+ encoder_ff_interm_dropout: float = 0.1,
+ encoder_dropout: float = 0.1,
+ encoder_layer_drop: float = 0.1,
+ aux_num_out: Optional[int] = None,
+) -> Wav2Vec2Model:
+ """Builds "base" WaveLM model :cite:`chen2022wavlm`. The architecture is compatible
+ with Wav2Vec2 model :cite:`baevski2020wav2vec`, and so the output class is
+ :class:`~torchaudio.models.Wav2Vec2Model`.
+
+ Args:
+ encoder_projection_dropout (float):
+ See :py:func:`~torchaudio.models.wav2vec2_model`.
+ encoder_attention_dropout (float):
+ See :py:func:`~torchaudio.models.wav2vec2_model`.
+ encoder_ff_interm_dropout (float):
+ See :py:func:`~torchaudio.models.wav2vec2_model`.
+ encoder_dropout (float):
+ See :py:func:`~torchaudio.models.wav2vec2_model`.
+ encoder_layer_drop (float):
+ See :py:func:`~torchaudio.models.wav2vec2_model`.
+ aux_num_out (int, optional):
+ See :py:func:`~torchaudio.models.wav2vec2_model`.
+
+ Returns:
+ Wav2Vec2Model:
+ The resulting model.
+ """
+ return wavlm_model(
+ extractor_mode="group_norm",
+ extractor_conv_layer_config=None,
+ extractor_conv_bias=False,
+ encoder_embed_dim=768,
+ encoder_projection_dropout=encoder_projection_dropout,
+ encoder_pos_conv_kernel=128,
+ encoder_pos_conv_groups=16,
+ encoder_num_layers=12,
+ encoder_num_heads=12,
+ encoder_num_buckets=320,
+ encoder_max_distance=800,
+ encoder_attention_dropout=encoder_attention_dropout,
+ encoder_ff_interm_features=3072,
+ encoder_ff_interm_dropout=encoder_ff_interm_dropout,
+ encoder_dropout=encoder_dropout,
+ encoder_layer_norm_first=False,
+ encoder_layer_drop=encoder_layer_drop,
+ aux_num_out=aux_num_out,
+ )
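+
+
+# Illustrative usage sketch (hypothetical helper, not part of the original source): WavLM
+# builders return a plain ``Wav2Vec2Model``, so the model is used exactly like the
+# wav2vec 2.0 builders above. The input below is a random placeholder waveform.
+def _example_wavlm_base_features():
+    model = wavlm_base().eval()
+    waveforms = torch.randn(1, 16000)
+    with torch.no_grad():
+        features, _ = model.extract_features(waveforms)
+    return features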
+
+
+def wavlm_large(
+ encoder_projection_dropout: float = 0.1,
+ encoder_attention_dropout: float = 0.1,
+ encoder_ff_interm_dropout: float = 0.0,
+ encoder_dropout: float = 0.1,
+ encoder_layer_drop: float = 0.1,
+ aux_num_out: Optional[int] = None,
+) -> Wav2Vec2Model:
+ """Builds "large" WaveLM model :cite:`chen2022wavlm`. The architecture is compatible
+ with Wav2Vec2 model :cite:`baevski2020wav2vec`, and so the output class is
+ :class:`~torchaudio.models.Wav2Vec2Model`.
+
+ Args:
+ encoder_projection_dropout (float):
+ See :py:func:`~torchaudio.models.wav2vec2_model`.
+ encoder_attention_dropout (float):
+ See :py:func:`~torchaudio.models.wav2vec2_model`.
+ encoder_ff_interm_dropout (float):
+ See :py:func:`~torchaudio.models.wav2vec2_model`.
+ encoder_dropout (float):
+ See :py:func:`~torchaudio.models.wav2vec2_model`.
+ encoder_layer_drop (float):
+ See :py:func:`~torchaudio.models.wav2vec2_model`.
+ aux_num_out (int, optional):
+ See :py:func:`~torchaudio.models.wav2vec2_model`.
+
+ Returns:
+ Wav2Vec2Model:
+ The resulting model.
+ """
+ return wavlm_model(
+ extractor_mode="layer_norm",
+ extractor_conv_layer_config=None,
+ extractor_conv_bias=False,
+ encoder_embed_dim=1024,
+ encoder_projection_dropout=encoder_projection_dropout,
+ encoder_pos_conv_kernel=128,
+ encoder_pos_conv_groups=16,
+ encoder_num_layers=24,
+ encoder_num_heads=16,
+ encoder_num_buckets=320,
+ encoder_max_distance=800,
+ encoder_attention_dropout=encoder_attention_dropout,
+ encoder_ff_interm_features=4096,
+ encoder_ff_interm_dropout=encoder_ff_interm_dropout,
+ encoder_dropout=encoder_dropout,
+ encoder_layer_norm_first=True,
+ encoder_layer_drop=encoder_layer_drop,
+ aux_num_out=aux_num_out,
+ )
+
+
+def wav2vec2_xlsr_300m(
+ encoder_projection_dropout: float = 0.0,
+ encoder_attention_dropout: float = 0.0,
+ encoder_ff_interm_dropout: float = 0.0,
+ encoder_dropout: float = 0.0,
+ encoder_layer_drop: float = 0.0,
+ aux_num_out: Optional[int] = None,
+) -> Wav2Vec2Model:
+ """Builds XLS-R model :cite:`babu2021xls` with 300 millions of parameters. The architecture is compatible
+ with Wav2Vec2 model :cite:`baevski2020wav2vec`, and so the output class is
+ :class:`~torchaudio.models.Wav2Vec2Model`.
+
+ Args:
+ encoder_projection_dropout (float):
+ See :py:func:`~torchaudio.models.wav2vec2_model`.
+ encoder_attention_dropout (float):
+ See :py:func:`~torchaudio.models.wav2vec2_model`.
+ encoder_ff_interm_dropout (float):
+ See :py:func:`~torchaudio.models.wav2vec2_model`.
+ encoder_dropout (float):
+ See :py:func:`~torchaudio.models.wav2vec2_model`.
+ encoder_layer_drop (float):
+ See :py:func:`~torchaudio.models.wav2vec2_model`.
+ aux_num_out (int, optional):
+ See :py:func:`~torchaudio.models.wav2vec2_model`.
+
+ Returns:
+ Wav2Vec2Model:
+ The resulting model.
+ """
+ return wav2vec2_model(
+ extractor_mode="layer_norm",
+ extractor_conv_layer_config=None,
+ extractor_conv_bias=True,
+ encoder_embed_dim=1024,
+ encoder_projection_dropout=encoder_projection_dropout,
+ encoder_pos_conv_kernel=128,
+ encoder_pos_conv_groups=16,
+ encoder_num_layers=24,
+ encoder_num_heads=16,
+ encoder_attention_dropout=encoder_attention_dropout,
+ encoder_ff_interm_features=4096,
+ encoder_ff_interm_dropout=encoder_ff_interm_dropout,
+ encoder_dropout=encoder_dropout,
+ encoder_layer_norm_first=True,
+ encoder_layer_drop=encoder_layer_drop,
+ aux_num_out=aux_num_out,
+ )
+
+
+def wav2vec2_xlsr_1b(
+ encoder_projection_dropout: float = 0.1,
+ encoder_attention_dropout: float = 0.0,
+ encoder_ff_interm_dropout: float = 0.0,
+ encoder_dropout: float = 0.0,
+ encoder_layer_drop: float = 0.0,
+ aux_num_out: Optional[int] = None,
+) -> Wav2Vec2Model:
+ """Builds XLS-R model :cite:`babu2021xls` with 1 billion of parameters. The architecture is compatible
+ with Wav2Vec2 model :cite:`baevski2020wav2vec`, and so the output class is
+ :class:`~torchaudio.models.Wav2Vec2Model`.
+
+ Args:
+ encoder_projection_dropout (float):
+ See :py:func:`~torchaudio.models.wav2vec2_model`.
+ encoder_attention_dropout (float):
+ See :py:func:`~torchaudio.models.wav2vec2_model`.
+ encoder_ff_interm_dropout (float):
+ See :py:func:`~torchaudio.models.wav2vec2_model`.
+ encoder_dropout (float):
+ See :py:func:`~torchaudio.models.wav2vec2_model`.
+ encoder_layer_drop (float):
+ See :py:func:`~torchaudio.models.wav2vec2_model`.
+ aux_num_out (int, optional):
+ See :py:func:`~torchaudio.models.wav2vec2_model`.
+
+ Returns:
+ Wav2Vec2Model:
+ The resulting model.
+ """
+ return wav2vec2_model(
+ extractor_mode="layer_norm",
+ extractor_conv_layer_config=None,
+ extractor_conv_bias=True,
+ encoder_embed_dim=1280,
+ encoder_projection_dropout=encoder_projection_dropout,
+ encoder_pos_conv_kernel=128,
+ encoder_pos_conv_groups=16,
+ encoder_num_layers=48,
+ encoder_num_heads=16,
+ encoder_attention_dropout=encoder_attention_dropout,
+ encoder_ff_interm_features=5120,
+ encoder_ff_interm_dropout=encoder_ff_interm_dropout,
+ encoder_dropout=encoder_dropout,
+ encoder_layer_norm_first=True,
+ encoder_layer_drop=encoder_layer_drop,
+ aux_num_out=aux_num_out,
+ )
+
+
+def wav2vec2_xlsr_2b(
+ encoder_projection_dropout: float = 0.1,
+ encoder_attention_dropout: float = 0.0,
+ encoder_ff_interm_dropout: float = 0.0,
+ encoder_dropout: float = 0.0,
+ encoder_layer_drop: float = 0.0,
+ aux_num_out: Optional[int] = None,
+) -> Wav2Vec2Model:
+ """Builds XLS-R model :cite:`babu2021xls` with 2 billions of parameters. The architecture is compatible
+ with Wav2Vec2 model :cite:`baevski2020wav2vec`, and so the output class is
+ :class:`~torchaudio.models.Wav2Vec2Model`.
+
+ Args:
+ encoder_projection_dropout (float):
+ See :py:func:`~torchaudio.models.wav2vec2_model`.
+ encoder_attention_dropout (float):
+ See :py:func:`~torchaudio.models.wav2vec2_model`.
+ encoder_ff_interm_dropout (float):
+ See :py:func:`~torchaudio.models.wav2vec2_model`.
+ encoder_dropout (float):
+ See :py:func:`~torchaudio.models.wav2vec2_model`.
+ encoder_layer_drop (float):
+ See :py:func:`~torchaudio.models.wav2vec2_model`.
+ aux_num_out (int, optional):
+ See :py:func:`~torchaudio.models.wav2vec2_model`.
+
+ Returns:
+ Wav2Vec2Model:
+ The resulting model.
+ """
+ return wav2vec2_model(
+ extractor_mode="layer_norm",
+ extractor_conv_layer_config=None,
+ extractor_conv_bias=True,
+ encoder_embed_dim=1920,
+ encoder_projection_dropout=encoder_projection_dropout,
+ encoder_pos_conv_kernel=128,
+ encoder_pos_conv_groups=16,
+ encoder_num_layers=48,
+ encoder_num_heads=16,
+ encoder_attention_dropout=encoder_attention_dropout,
+ encoder_ff_interm_features=7680,
+ encoder_ff_interm_dropout=encoder_ff_interm_dropout,
+ encoder_dropout=encoder_dropout,
+ encoder_layer_norm_first=True,
+ encoder_layer_drop=encoder_layer_drop,
+ aux_num_out=aux_num_out,
+ )
diff --git a/MLPY/Lib/site-packages/torchaudio/models/wav2vec2/utils/__init__.py b/MLPY/Lib/site-packages/torchaudio/models/wav2vec2/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a536ee2c28b470db9cc6b4f6d1dbfa664b3e17df
--- /dev/null
+++ b/MLPY/Lib/site-packages/torchaudio/models/wav2vec2/utils/__init__.py
@@ -0,0 +1,7 @@
+from .import_fairseq import import_fairseq_model
+from .import_huggingface import import_huggingface_model
+
+__all__ = [
+ "import_huggingface_model",
+ "import_fairseq_model",
+]
diff --git a/MLPY/Lib/site-packages/torchaudio/models/wav2vec2/utils/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torchaudio/models/wav2vec2/utils/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2d15057842e133442b8fa50d76e0eea64a10eb41
Binary files /dev/null and b/MLPY/Lib/site-packages/torchaudio/models/wav2vec2/utils/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torchaudio/models/wav2vec2/utils/__pycache__/import_fairseq.cpython-39.pyc b/MLPY/Lib/site-packages/torchaudio/models/wav2vec2/utils/__pycache__/import_fairseq.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2fb4b919eddb7f482a964a2d7ba8d3be2127f7a0
Binary files /dev/null and b/MLPY/Lib/site-packages/torchaudio/models/wav2vec2/utils/__pycache__/import_fairseq.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torchaudio/models/wav2vec2/utils/__pycache__/import_huggingface.cpython-39.pyc b/MLPY/Lib/site-packages/torchaudio/models/wav2vec2/utils/__pycache__/import_huggingface.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b4b31a3ae6ba37b16d53d8f9297bcb2b2e2b4eef
Binary files /dev/null and b/MLPY/Lib/site-packages/torchaudio/models/wav2vec2/utils/__pycache__/import_huggingface.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torchaudio/models/wav2vec2/utils/import_fairseq.py b/MLPY/Lib/site-packages/torchaudio/models/wav2vec2/utils/import_fairseq.py
new file mode 100644
index 0000000000000000000000000000000000000000..d5873446f1553cc6b7bf17a8e421ad1160772b57
--- /dev/null
+++ b/MLPY/Lib/site-packages/torchaudio/models/wav2vec2/utils/import_fairseq.py
@@ -0,0 +1,213 @@
+"""Import fariseq's wav2vec2.0 pretrained weights to torchaudios's format.
+
+For this module to work, you need `fairseq`.
+"""
+import re
+
+from torch.nn import Module
+
+from ..model import wav2vec2_model, Wav2Vec2Model
+
+
+def _parse_config(w2v_model):
+ encoder = w2v_model.encoder
+ conv_layers = w2v_model.feature_extractor.conv_layers
+
+ extractor_mode = "layer_norm"
+ if "GroupNorm" in conv_layers[0][2].__class__.__name__:
+ extractor_mode = "group_norm"
+ else:
+ extractor_mode = "layer_norm"
+
+ conv_layer_config = [(l[0].out_channels, l[0].kernel_size[0], l[0].stride[0]) for l in conv_layers]
+
+ if all(l[0].bias is None for l in conv_layers):
+ conv_bias = False
+ elif all(l[0].bias is not None for l in conv_layers):
+ conv_bias = True
+ else:
+ raise ValueError("Either all the convolutions layers have bias term or none of them should.")
+
+ config = {
+ "extractor_mode": extractor_mode,
+ "extractor_conv_layer_config": conv_layer_config,
+ "extractor_conv_bias": conv_bias,
+ "encoder_embed_dim": w2v_model.post_extract_proj.out_features,
+ "encoder_projection_dropout": w2v_model.dropout_input.p,
+ "encoder_pos_conv_kernel": encoder.pos_conv[0].kernel_size[0],
+ "encoder_pos_conv_groups": encoder.pos_conv[0].groups,
+ "encoder_num_layers": len(encoder.layers),
+ "encoder_num_heads": encoder.layers[0].self_attn.num_heads,
+ "encoder_attention_dropout": encoder.layers[0].self_attn.dropout_module.p,
+ "encoder_ff_interm_features": encoder.layers[0].fc1.out_features,
+ "encoder_ff_interm_dropout": encoder.layers[0].dropout2.p,
+ "encoder_dropout": encoder.layers[0].dropout3.p,
+ "encoder_layer_norm_first": encoder.layer_norm_first,
+ "encoder_layer_drop": encoder.layerdrop,
+ }
+ return config
+
+
+def _map_key(key):
+ key_ = key
+ if key.startswith("w2v_model."):
+ key = key.replace("w2v_model.", "")
+ if re.match(r"(mask_emb|quantizer|project_q|final_proj|mask_emb)", key):
+ return None
+ # Feature Extractor
+ # Group norm when "extractor_mode" is "default".
+ # (Only the first layer)
+ # "conv_layers.0.2.weight" -> "conv_layers.0.layer_norm.weight"
+ # "conv_layers.0.2.bias" -> "conv_layers.0.layer_norm.bias"
+ match = re.match(r"feature_extractor\.conv_layers\.0\.2\.(weight|bias)", key)
+ if match:
+ return f"feature_extractor.conv_layers.0.layer_norm.{match.group(1)}"
+ # Convolutions
+ # "conv_layers.X.0.weight" -> "conv_layers.X.conv.weight"
+ # "conv_layers.X.0.bias" -> "conv_layers.X.conv.bias"
+ match = re.match(r"feature_extractor\.conv_layers\.(\d+)\.0\.(weight|bias)", key)
+ if match:
+ return f"feature_extractor.conv_layers.{match.group(1)}.conv.{match.group(2)}"
+ # Layer norm when "extractor_mode" is "layer_norm".
+ # "conv_layers.X.2.1.weight" -> "conv_layers.X.layer_norm.weight"
+ # "conv_layers.X.2.1.bias" -> "conv_layers.X.layer_norm.bias"
+ match = re.match(r"feature_extractor\.conv_layers\.(\d+)\.2\.1\.(weight|bias)", key)
+ if match:
+ return f"feature_extractor.conv_layers.{match.group(1)}.layer_norm.{match.group(2)}"
+ match = re.match(r"post_extract_proj\.(weight|bias)", key)
+ # Encoder - Feature projection
+ if match:
+ return f"encoder.feature_projection.projection.{match.group(1)}"
+ match = re.match(r"layer_norm\.(weight|bias)", key)
+ if match:
+ return f"encoder.feature_projection.layer_norm.{match.group(1)}"
+ # Encoder - Transformer - Convolutional positional embedding
+ match = re.match(r"encoder\.pos_conv\.0\.(bias|weight_g|weight_v)", key)
+ if match:
+ return f"encoder.transformer.pos_conv_embed.conv.{match.group(1)}"
+ match = re.match(r"encoder\.layer_norm\.(weight|bias)", key)
+ if match:
+ return f"encoder.transformer.layer_norm.{match.group(1)}"
+ # Encoder - Transformer - Self attention layers
+ match = re.match(r"encoder\.layers\.(\d+)\.self_attn\.((k_|v_|q_|out_)proj\.(weight|bias))", key)
+ if match:
+ return f"encoder.transformer.layers.{match.group(1)}.attention.{match.group(2)}"
+ match = re.match(r"encoder\.layers\.(\d+)\.self_attn_layer_norm\.(weight|bias)", key)
+ if match:
+ return f"encoder.transformer.layers.{match.group(1)}.layer_norm.{match.group(2)}"
+ match = re.match(r"encoder\.layers\.(\d+)\.fc1\.(weight|bias)", key)
+ if match:
+ return f"encoder.transformer.layers.{match.group(1)}.feed_forward.intermediate_dense.{match.group(2)}"
+ match = re.match(r"encoder\.layers\.(\d+)\.fc2\.(weight|bias)", key)
+ if match:
+ return f"encoder.transformer.layers.{match.group(1)}.feed_forward.output_dense.{match.group(2)}"
+ match = re.match(r"encoder\.layers\.(\d+)\.final_layer_norm\.(weight|bias)", key)
+ if match:
+ return f"encoder.transformer.layers.{match.group(1)}.final_layer_norm.{match.group(2)}"
+ match = re.match(r"proj\.(weight|bias)", key)
+ # Auxiliary Module
+ # Only relevant when loading fine-tuned models
+ if match:
+ return f"aux.{match.group(1)}"
+ # HuBERT Extension
+ if key in ["label_embs_concat"]:
+ return key
+ raise ValueError(f"Unexpected key: {key_}")
+
+
+def _convert_state_dict(state_dict):
+ converted = {}
+ for k, v in state_dict.items():
+ k = _map_key(k)
+ if k is not None:
+ converted[k] = v
+ return converted
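+
+
+# Illustrative sketch (hypothetical helper, not part of the original source): how
+# ``_map_key`` renames a few representative fairseq parameter names into the torchaudio
+# layout. The keys below are example names chosen to match the regular expressions above;
+# quantizer-related keys are dropped (mapped to ``None``).
+def _example_key_mapping():
+    examples = [
+        "feature_extractor.conv_layers.1.0.weight",  # -> feature_extractor.conv_layers.1.conv.weight
+        "encoder.layers.3.fc1.bias",  # -> encoder.transformer.layers.3.feed_forward.intermediate_dense.bias
+        "quantizer.vars",  # -> None (not imported)
+    ]
+    return {key: _map_key(key) for key in examples}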
+
+
+def import_fairseq_model(original: Module) -> Wav2Vec2Model:
+ """Builds :class:`Wav2Vec2Model` from the corresponding model object of
+ `fairseq `_.
+
+ Args:
+ original (torch.nn.Module):
+ An instance of fairseq's Wav2Vec2.0 or HuBERT model.
+            One of ``fairseq.models.wav2vec.wav2vec2_asr.Wav2VecEncoder``,
+            ``fairseq.models.wav2vec.wav2vec2.Wav2Vec2Model``,
+            ``fairseq.models.hubert.hubert.HubertModel`` or
+            ``fairseq.models.hubert.hubert_asr.HubertEncoder``.
+
+ Returns:
+ Wav2Vec2Model: Imported model.
+
+ Example - Loading pretrain-only model
+ >>> from torchaudio.models.wav2vec2.utils import import_fairseq_model
+ >>>
+ >>> # Load model using fairseq
+ >>> model_file = 'wav2vec_small.pt'
+ >>> model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task([model_file])
+ >>> original = model[0]
+ >>> imported = import_fairseq_model(original)
+ >>>
+ >>> # Perform feature extraction
+ >>> waveform, _ = torchaudio.load('audio.wav')
+ >>> features, _ = imported.extract_features(waveform)
+ >>>
+ >>> # Compare result with the original model from fairseq
+ >>> reference = original.feature_extractor(waveform).transpose(1, 2)
+ >>> torch.testing.assert_allclose(features, reference)
+
+ Example - Fine-tuned model
+ >>> from torchaudio.models.wav2vec2.utils import import_fairseq_model
+ >>>
+ >>> # Load model using fairseq
+ >>> model_file = 'wav2vec_small_960h.pt'
+ >>> model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task([model_file])
+ >>> original = model[0]
+ >>> imported = import_fairseq_model(original.w2v_encoder)
+ >>>
+ >>> # Perform encoding
+ >>> waveform, _ = torchaudio.load('audio.wav')
+ >>> emission, _ = imported(waveform)
+ >>>
+ >>> # Compare result with the original model from fairseq
+ >>> mask = torch.zeros_like(waveform)
+ >>> reference = original(waveform, mask)['encoder_out'].transpose(0, 1)
+ >>> torch.testing.assert_allclose(emission, reference)
+ """
+ class_ = original.__class__.__name__
+ if class_ == "Wav2Vec2Model":
+ return _import_wav2vec2_pretraining(original)
+ if class_ == "Wav2VecEncoder":
+ return _import_wav2vec2_finetuning(original)
+ if class_ == "HubertModel":
+ return _import_hubert_pretraining(original)
+ if class_ == "HubertEncoder":
+ return _import_hubert_finetuning(original)
+ raise ValueError(f"Expected an instance of `Wav2Vec2Model` or `Wav2VecEncoder`. Found: {class_}")
+
+
+def _import_wav2vec2_finetuning(original: Module) -> Wav2Vec2Model:
+ config = _parse_config(original.w2v_model)
+ model = wav2vec2_model(**config, aux_num_out=original.proj.out_features)
+ model.load_state_dict(_convert_state_dict(original.state_dict()))
+ return model
+
+
+def _import_wav2vec2_pretraining(original: Module) -> Wav2Vec2Model:
+ config = _parse_config(original)
+ model = wav2vec2_model(**config, aux_num_out=None)
+ model.load_state_dict(_convert_state_dict(original.state_dict()), strict=False)
+ return model
+
+
+def _import_hubert_finetuning(original: Module) -> Wav2Vec2Model:
+ config = _parse_config(original.w2v_model)
+ model = wav2vec2_model(**config, aux_num_out=original.proj.out_features)
+ model.load_state_dict(_convert_state_dict(original.state_dict()), strict=False)
+ return model
+
+
+def _import_hubert_pretraining(original: Module) -> Wav2Vec2Model:
+ config = _parse_config(original)
+ model = wav2vec2_model(**config, aux_num_out=None)
+ model.load_state_dict(_convert_state_dict(original.state_dict()), strict=False)
+ return model
diff --git a/MLPY/Lib/site-packages/torchaudio/models/wav2vec2/utils/import_huggingface.py b/MLPY/Lib/site-packages/torchaudio/models/wav2vec2/utils/import_huggingface.py
new file mode 100644
index 0000000000000000000000000000000000000000..38703408f01d52b8259f39921202ccbd19a24a3f
--- /dev/null
+++ b/MLPY/Lib/site-packages/torchaudio/models/wav2vec2/utils/import_huggingface.py
@@ -0,0 +1,134 @@
+"""Import Hugging Face transformers's wav2vec2.0 pretrained weights to torchaudios's format.
+"""
+import logging
+from typing import Any, Dict
+
+import torch
+from torch.nn import Module
+
+from ..model import wav2vec2_model, Wav2Vec2Model, wavlm_model
+
+_LG = logging.getLogger(__name__)
+
+
+def _get_config(cfg):
+ config = {
+ "extractor_mode": f"{cfg.feat_extract_norm}_norm",
+ "extractor_conv_layer_config": list(zip(cfg.conv_dim, cfg.conv_kernel, cfg.conv_stride)),
+ "extractor_conv_bias": cfg.conv_bias,
+ "encoder_embed_dim": cfg.hidden_size,
+ "encoder_projection_dropout": cfg.feat_proj_dropout,
+ "encoder_pos_conv_kernel": cfg.num_conv_pos_embeddings,
+ "encoder_pos_conv_groups": cfg.num_conv_pos_embedding_groups,
+ "encoder_num_layers": cfg.num_hidden_layers,
+ "encoder_num_heads": cfg.num_attention_heads,
+ "encoder_attention_dropout": cfg.attention_dropout,
+ "encoder_ff_interm_features": cfg.intermediate_size,
+ "encoder_ff_interm_dropout": cfg.activation_dropout,
+ "encoder_dropout": cfg.hidden_dropout,
+ "encoder_layer_norm_first": cfg.do_stable_layer_norm,
+ "encoder_layer_drop": cfg.layerdrop,
+ }
+ return config
+
+
+def _get_config_wavlm(cfg):
+ config = {
+ "extractor_mode": f"{cfg.feat_extract_norm}_norm",
+ "extractor_conv_layer_config": list(zip(cfg.conv_dim, cfg.conv_kernel, cfg.conv_stride)),
+ "extractor_conv_bias": cfg.conv_bias,
+ "encoder_embed_dim": cfg.hidden_size,
+ "encoder_projection_dropout": cfg.feat_proj_dropout,
+ "encoder_pos_conv_kernel": cfg.num_conv_pos_embeddings,
+ "encoder_pos_conv_groups": cfg.num_conv_pos_embedding_groups,
+ "encoder_num_layers": cfg.num_hidden_layers,
+ "encoder_num_heads": cfg.num_attention_heads,
+ "encoder_num_buckets": cfg.num_buckets,
+ "encoder_max_distance": cfg.max_bucket_distance,
+ "encoder_attention_dropout": cfg.attention_dropout,
+ "encoder_ff_interm_features": cfg.intermediate_size,
+ "encoder_ff_interm_dropout": cfg.activation_dropout,
+ "encoder_dropout": cfg.hidden_dropout,
+ "encoder_layer_norm_first": cfg.do_stable_layer_norm,
+ "encoder_layer_drop": cfg.layerdrop,
+ }
+ return config
+
+
+def _build(config, original):
+ is_for_ctc = original.__class__.__name__ in ["Wav2Vec2ForCTC", "WavLMForCTC"]
+ if is_for_ctc:
+ aux_num_out = original.config.vocab_size
+ wav2vec2 = original.wav2vec2
+ else:
+ _LG.warning(
+ "The model is not an instance of Wav2Vec2ForCTC or WavLMForCTC. " '"lm_head" module is not imported.'
+ )
+ aux_num_out = None
+ wav2vec2 = original
+ is_wavlm = original.__class__.__name__ in ["WavLMModel", "WavLMForCTC"]
+ if is_wavlm:
+ imported = wavlm_model(**config, aux_num_out=aux_num_out)
+ else:
+ imported = wav2vec2_model(**config, aux_num_out=aux_num_out)
+ imported.feature_extractor.load_state_dict(wav2vec2.feature_extractor.state_dict())
+ imported.encoder.feature_projection.load_state_dict(wav2vec2.feature_projection.state_dict())
+ encoder_state_dict = wav2vec2.encoder.state_dict()
+ if is_wavlm: # Rename parameters of the linear projections to match the torch.nn.MultiheadAttention layout
+ transform_wavlm_encoder_state(encoder_state_dict, config["encoder_num_layers"])
+ imported.encoder.transformer.load_state_dict(encoder_state_dict)
+ if is_for_ctc:
+ imported.aux.load_state_dict(original.lm_head.state_dict())
+ return imported
+
+
+def transform_wavlm_encoder_state(state: Dict[str, Any], encoder_num_layers: int):
+ """Converts WavLM encoder state from HuggingFace format. In particular, concatenates linear projection weights and
+ biases to align with the structure of ``torch.nn.MultiheadAttention``.
+ """
+ for i in range(encoder_num_layers):
+ q_proj_bias = state.pop(f"layers.{i}.attention.q_proj.bias")
+ k_proj_bias = state.pop(f"layers.{i}.attention.k_proj.bias")
+ v_proj_bias = state.pop(f"layers.{i}.attention.v_proj.bias")
+ q_proj_weight = state.pop(f"layers.{i}.attention.q_proj.weight")
+ k_proj_weight = state.pop(f"layers.{i}.attention.k_proj.weight")
+ v_proj_weight = state.pop(f"layers.{i}.attention.v_proj.weight")
+ state[f"layers.{i}.attention.attention.in_proj_bias"] = torch.cat((q_proj_bias, k_proj_bias, v_proj_bias))
+ state[f"layers.{i}.attention.attention.in_proj_weight"] = torch.cat(
+ (q_proj_weight, k_proj_weight, v_proj_weight)
+ )
+
+ state[f"layers.{i}.attention.attention.out_proj.weight"] = state.pop(f"layers.{i}.attention.out_proj.weight")
+ state[f"layers.{i}.attention.attention.out_proj.bias"] = state.pop(f"layers.{i}.attention.out_proj.bias")
+
+
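+# The sketch below is illustrative only and is not part of the conversion code: it
+# shows how the q/k/v concatenation performed above matches the single
+# ``in_proj_weight`` / ``in_proj_bias`` layout used by ``torch.nn.MultiheadAttention``.
+# The helper name and the dimensions are made up for demonstration.
+def _example_in_proj_layout(embed_dim: int = 4):
+    q_w, k_w, v_w = (torch.randn(embed_dim, embed_dim) for _ in range(3))
+    q_b, k_b, v_b = (torch.zeros(embed_dim) for _ in range(3))
+    in_proj_weight = torch.cat((q_w, k_w, v_w))  # (3 * embed_dim, embed_dim)
+    in_proj_bias = torch.cat((q_b, k_b, v_b))  # (3 * embed_dim,)
+    return in_proj_weight.shape, in_proj_bias.shape
+
+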
+def import_huggingface_model(original: Module) -> Wav2Vec2Model:
+ """Builds :class:`Wav2Vec2Model` from the corresponding model object of
+ `Transformers `_.
+
+ Args:
+ original (torch.nn.Module): An instance of ``Wav2Vec2ForCTC`` from ``transformers``.
+
+ Returns:
+ Wav2Vec2Model: Imported model.
+
+ Example
+ >>> from torchaudio.models.wav2vec2.utils import import_huggingface_model
+ >>>
+ >>> original = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
+ >>> model = import_huggingface_model(original)
+ >>>
+ >>> waveforms, _ = torchaudio.load("audio.wav")
+ >>> logits, _ = model(waveforms)
+ """
+ _LG.info("Importing model.")
+ _LG.info("Loading model configuration.")
+ is_wavlm = original.__class__.__name__ in ["WavLMModel", "WavLMForCTC"]
+ if is_wavlm:
+ config = _get_config_wavlm(original.config)
+ else:
+ config = _get_config(original.config)
+ _LG.debug(" - config: %s", config)
+ _LG.info("Building model.")
+ imported = _build(config, original)
+ return imported
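+
+
+# Illustrative sketch, not part of the module: importing a WavLM checkpoint goes
+# through the same entry point; the WavLM-specific config and state-dict handling
+# above is selected from the class name. The checkpoint id below is an assumption
+# used only for demonstration.
+def _example_import_wavlm():
+    from transformers import WavLMModel  # requires the ``transformers`` package
+
+    original = WavLMModel.from_pretrained("microsoft/wavlm-base")
+    return import_huggingface_model(original)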
diff --git a/MLPY/Lib/site-packages/torchaudio/models/wav2vec2/wavlm_attention.py b/MLPY/Lib/site-packages/torchaudio/models/wav2vec2/wavlm_attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..2fcff2a5679511c48675b894bc3f3efd501b6d0a
--- /dev/null
+++ b/MLPY/Lib/site-packages/torchaudio/models/wav2vec2/wavlm_attention.py
@@ -0,0 +1,214 @@
+"""
+The MIT License (MIT)
+
+Copyright (c) Microsoft Corporation
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
+"""
+
+import math
+from typing import Optional, Tuple
+
+import torch
+from torch import nn, Tensor
+
+
+class WavLMSelfAttention(nn.Module):
+ """Multi-headed self-attention for WavLM model :cite:`chen2022wavlm`.
+ Wraps around ``torch.nn.MultiheadAttention``, creating relative position embeddings and passing them to multi-headed
+ attention as a mask.
+ Source: https://github.com/microsoft/unilm/blob/2d8302f09c99bca2b82e6e868d81d4281cceebc8/wavlm/modules.py#L303-L763
+
+ Args:
+ embed_dim (int): Total dimension of the model.
+ num_heads (int): The number of heads.
+ dropout (float, optional): Dropout probability on attn_output_weights. (Default: to ``0.0``)
+ bias (bool, optional): If ``True``, add bias to input / output projection layers. (Default: ``True``)
+ has_relative_attention_bias (bool, optional): If ``True``, apply relative position embedding.
+ Necessary in the first encoder layer, but not in the subsequent ones. (Default: ``False``)
+ num_buckets (int, optional): Number of buckets for relative position embedding. (Default: ``32``)
+ max_distance (int, optional): Maximum distance for relative position embedding. (Default: ``128``)
+ gru_rel_pos (bool, optional): If ``True``, apply gated relative position embedding. (Default: ``False``)
+ """
+
+ def __init__(
+ self,
+ embed_dim: int,
+ num_heads: int,
+ dropout: float = 0.0,
+ bias: bool = True,
+ has_relative_attention_bias: bool = False,
+ num_buckets: int = 32,
+ max_distance: int = 128,
+ gru_rel_pos: bool = True,
+ ):
+ super().__init__()
+ self.embed_dim = embed_dim
+ self.num_heads = num_heads
+ self.has_relative_attention_bias = has_relative_attention_bias
+ self.num_buckets = num_buckets
+ self.max_distance = max_distance
+
+ if has_relative_attention_bias:
+ self.rel_attn_embed = nn.Embedding(num_buckets, num_heads)
+ else:
+ self.rel_attn_embed = None
+
+ self.head_dim = embed_dim // num_heads
+ assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
+
+ self.dropout = dropout
+ self.attention = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, bias=bias, batch_first=True)
+
+ self.gru_rel_pos = gru_rel_pos
+ if self.gru_rel_pos:
+ self.gru_rel_pos_linear = nn.Linear(self.head_dim, 8)
+ self.gru_rel_pos_const = nn.Parameter(torch.ones(1, num_heads, 1, 1))
+ self.has_position_bias = True
+
+ def compute_bias(self, query_length: int, key_length: int) -> Tensor:
+ """Compute relative position embeddings for WavLM model.
+ Args:
+ query_length (int): Query position can take values between 0 and ``query_length - 1``.
+ key_length (int): Key position can take values between 0 and ``key_length - 1``.
+ Returns:
+ Tensor of shape `(num_heads, query_length, key_length)`, the relative position embeddings
+ """
+ context_position = torch.arange(query_length, dtype=torch.long)[:, None]
+ memory_position = torch.arange(key_length, dtype=torch.long)[None, :]
+ relative_position = memory_position - context_position # Shape (query_length, key_length)
+ relative_position_bucket = self._relative_positions_bucket(relative_position, bidirectional=True)
+ relative_position_bucket = relative_position_bucket.to(self.rel_attn_embed.weight.device)
+ values = self.rel_attn_embed(relative_position_bucket) # Shape (query_length, key_length, num_heads)
+ values = values.permute([2, 0, 1])
+ return values
+
+ def _relative_positions_bucket(self, relative_positions: Tensor, bidirectional: bool = True):
+ """Compute relative position buckets for WavLM model. Computation similar to formula (5) in WavLM
+ paper :cite:`chen2022wavlm`.
+ Args:
+ relative_positions (Tensor): Relative offsets between query and key positions,
+ of shape ``(query_length, key_length)``.
+ bidirectional (bool): If ``True``, values will be filled both above and below the diagonal in the resulting
+ matrix. If ``False``, the elements above the diagonal (i.e. with negative relative offsets) will be set
+ to zero. (Default ``True``)
+ Returns:
+ Tensor of shape ``(query_length, key_length)`` filled with bucketed values of the relative positions.
+ """
+ num_buckets = self.num_buckets
+ max_distance = self.max_distance
+ # Shape (query_length, key_length)
+ relative_buckets = torch.zeros_like(relative_positions, dtype=torch.long)
+
+ if bidirectional:
+ num_buckets = num_buckets // 2
+ relative_buckets += (relative_positions > 0).to(torch.long) * num_buckets
+ relative_positions = torch.abs(relative_positions)
+ else:
+ relative_positions = -torch.min(relative_positions, torch.zeros_like(relative_positions))
+
+ max_exact = num_buckets // 2
+ is_small = relative_positions < max_exact
+
+ relative_position_if_large = max_exact + (
+ torch.log(relative_positions.float() / max_exact)
+ / math.log(max_distance / max_exact)
+ * (num_buckets - max_exact)
+ ).to(torch.long)
+ relative_position_if_large = torch.min(
+ relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
+ )
+
+ relative_buckets += torch.where(is_small, relative_positions, relative_position_if_large)
+ return relative_buckets
+
+ def forward(
+ self,
+ query: Tensor,
+ key_padding_mask: Optional[Tensor] = None,
+ attention_mask: Optional[Tensor] = None,
+ position_bias: Optional[Tensor] = None,
+ ) -> Tuple[Tensor, Optional[Tensor]]:
+ """
+ Args:
+ query (Tensor): Input of shape ``(batch_size, src_len, embed_dim)``.
+ key_padding_mask (Tensor or None, optional): Mask to exclude keys that are pads, of shape
+ `(batch, src_len)`, where padding elements are indicated by 1s. (Default: ``None``)
+ attention_mask (Tensor or None, optional): Needs to be ``None``. The argument exists for compatibility with
+ ``EncoderLayer``. (Default: ``None``)
+ position_bias (Tensor or None, optional): Position bias of shape
+ ``(batch_size, num_heads, src_len, src_len)``. When used inside WavLM model encoder, will be
+ generated in the first layer and then passed from each encoder layer to the next one.
+ (Default: ``None``)
+ Returns:
+ attn_output (Tensor): Attention output of shape ``(batch_size, src_len, embed_dim)``.
+ position_bias (Tensor or None): Position bias of shape ``(batch_size, num_heads, src_len, src_len)``.
+ """
+ bsz, seq_len, embed_dim = query.size()
+ assert embed_dim == self.embed_dim
+ assert attention_mask is None
+
+ if self.rel_attn_embed is not None and position_bias is None:
+ position_bias = self.compute_bias(seq_len, seq_len)
+ position_bias = position_bias.unsqueeze(0).repeat(bsz, 1, 1, 1)
+
+ attn_mask_rel_pos: Optional[Tensor] = None
+ if position_bias is not None:
+ attn_mask_rel_pos = position_bias
+ if self.gru_rel_pos: # Apply gating on relative position bias
+ query_layer = query.view(bsz, seq_len, self.num_heads, -1)
+ query_layer = query_layer.permute(0, 2, 1, 3)
+
+ gate_a, gate_b = torch.sigmoid(
+ self.gru_rel_pos_linear(query_layer).view(bsz, self.num_heads, seq_len, 2, 4).sum(-1, keepdim=False)
+ ).chunk(2, dim=-1)
+ gate_a_1 = gate_a * (gate_b * self.gru_rel_pos_const - 1.0) + 2.0
+ attn_mask_rel_pos = gate_a_1.view(bsz, self.num_heads, -1, 1) * position_bias
+
+ attn_mask_rel_pos = attn_mask_rel_pos.view((bsz, self.num_heads, seq_len, seq_len))
+
+ if attn_mask_rel_pos is not None and key_padding_mask is not None:
+ key_padding_mask = key_padding_mask.view(bsz, 1, 1, seq_len).expand(-1, self.num_heads, -1, -1)
+ key_padding_mask = torch.nn.functional._canonical_mask(
+ mask=key_padding_mask,
+ mask_name="key_padding_mask",
+ other_type=torch.nn.functional._none_or_dtype(attn_mask_rel_pos),
+ other_name="",
+ target_type=query.dtype,
+ )
+ if attn_mask_rel_pos is not None and key_padding_mask is not None:
+ attn_mask_rel_pos = attn_mask_rel_pos + key_padding_mask
+ query_projected = torch.nn.functional.linear(query, self.attention.in_proj_weight, self.attention.in_proj_bias)
+ query, key, value = query_projected.chunk(3, -1)
+ shape = (bsz, seq_len, self.num_heads, self.head_dim)
+ query = query.view(shape).transpose(2, 1) # (batch, num_heads, seq_len, head_dim)
+ key = key.view(shape).transpose(2, 1) # (batch, num_heads, seq_len, head_dim)
+ value = value.view(shape).transpose(2, 1) # (batch, num_heads, seq_len, head_dim)
+ dropout = self.dropout if self.training else 0.0
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
+ query,
+ key,
+ value,
+ attn_mask=attn_mask_rel_pos,
+ dropout_p=dropout,
+ is_causal=False,
+ )
+ attn_output = attn_output.transpose(1, 2).reshape(bsz, -1, self.num_heads * self.head_dim)
+ attn_output = self.attention.out_proj(attn_output)
+ return attn_output, position_bias
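+
+
+# Illustrative usage sketch, not part of the module: run one self-attention layer
+# on random features and reuse the returned position bias, as the WavLM encoder
+# does between layers. All shapes below are made up for demonstration.
+def _example_wavlm_self_attention():
+    layer = WavLMSelfAttention(embed_dim=768, num_heads=12, has_relative_attention_bias=True)
+    x = torch.randn(2, 50, 768)  # (batch, src_len, embed_dim)
+    out, bias = layer(x)  # the first layer computes the relative position bias
+    out, _ = layer(out, position_bias=bias)  # later layers can reuse it
+    # The bucketing maps signed offsets to a bounded set of indices, e.g. with the
+    # default 32 buckets, offsets [-2, -1, 0, 1, 2] map to buckets [2, 1, 0, 17, 18].
+    buckets = layer._relative_positions_bucket(torch.tensor([[-2, -1, 0, 1, 2]]))
+    return out.shape, bias.shape, buckets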
diff --git a/MLPY/Lib/site-packages/torchaudio/models/wavernn.py b/MLPY/Lib/site-packages/torchaudio/models/wavernn.py
new file mode 100644
index 0000000000000000000000000000000000000000..90bc2fca7240235e2a8e67ba454ba29a4a9e667b
--- /dev/null
+++ b/MLPY/Lib/site-packages/torchaudio/models/wavernn.py
@@ -0,0 +1,409 @@
+import math
+from typing import List, Optional, Tuple
+
+import torch
+import torch.nn.functional as F
+from torch import nn, Tensor
+
+__all__ = [
+ "ResBlock",
+ "MelResNet",
+ "Stretch2d",
+ "UpsampleNetwork",
+ "WaveRNN",
+]
+
+
+class ResBlock(nn.Module):
+ r"""ResNet block based on *Efficient Neural Audio Synthesis* :cite:`kalchbrenner2018efficient`.
+
+ Args:
+ n_freq: the number of bins in a spectrogram. (Default: ``128``)
+
+ Examples
+ >>> resblock = ResBlock()
+ >>> input = torch.rand(10, 128, 512) # a random spectrogram
+ >>> output = resblock(input) # shape: (10, 128, 512)
+ """
+
+ def __init__(self, n_freq: int = 128) -> None:
+ super().__init__()
+
+ self.resblock_model = nn.Sequential(
+ nn.Conv1d(in_channels=n_freq, out_channels=n_freq, kernel_size=1, bias=False),
+ nn.BatchNorm1d(n_freq),
+ nn.ReLU(inplace=True),
+ nn.Conv1d(in_channels=n_freq, out_channels=n_freq, kernel_size=1, bias=False),
+ nn.BatchNorm1d(n_freq),
+ )
+
+ def forward(self, specgram: Tensor) -> Tensor:
+ r"""Pass the input through the ResBlock layer.
+ Args:
+ specgram (Tensor): the input sequence to the ResBlock layer (n_batch, n_freq, n_time).
+
+ Return:
+ Tensor shape: (n_batch, n_freq, n_time)
+ """
+
+ return self.resblock_model(specgram) + specgram
+
+
+class MelResNet(nn.Module):
+ r"""MelResNet layer uses a stack of ResBlocks on spectrogram.
+
+ Args:
+ n_res_block: the number of ResBlock in stack. (Default: ``10``)
+ n_freq: the number of bins in a spectrogram. (Default: ``128``)
+ n_hidden: the number of hidden dimensions of resblock. (Default: ``128``)
+ n_output: the number of output dimensions of melresnet. (Default: ``128``)
+ kernel_size: the kernel size of the first Conv1d layer. (Default: ``5``)
+
+ Examples
+ >>> melresnet = MelResNet()
+ >>> input = torch.rand(10, 128, 512) # a random spectrogram
+ >>> output = melresnet(input) # shape: (10, 128, 508)
+ """
+
+ def __init__(
+ self, n_res_block: int = 10, n_freq: int = 128, n_hidden: int = 128, n_output: int = 128, kernel_size: int = 5
+ ) -> None:
+ super().__init__()
+
+ ResBlocks = [ResBlock(n_hidden) for _ in range(n_res_block)]
+
+ self.melresnet_model = nn.Sequential(
+ nn.Conv1d(in_channels=n_freq, out_channels=n_hidden, kernel_size=kernel_size, bias=False),
+ nn.BatchNorm1d(n_hidden),
+ nn.ReLU(inplace=True),
+ *ResBlocks,
+ nn.Conv1d(in_channels=n_hidden, out_channels=n_output, kernel_size=1),
+ )
+
+ def forward(self, specgram: Tensor) -> Tensor:
+ r"""Pass the input through the MelResNet layer.
+ Args:
+ specgram (Tensor): the input sequence to the MelResNet layer (n_batch, n_freq, n_time).
+
+ Return:
+ Tensor shape: (n_batch, n_output, n_time - kernel_size + 1)
+ """
+
+ return self.melresnet_model(specgram)
+
+
+class Stretch2d(nn.Module):
+ r"""Upscale the frequency and time dimensions of a spectrogram.
+
+ Args:
+ time_scale: the scale factor in time dimension
+ freq_scale: the scale factor in frequency dimension
+
+ Examples
+ >>> stretch2d = Stretch2d(time_scale=10, freq_scale=5)
+
+ >>> input = torch.rand(10, 100, 512) # a random spectrogram
+ >>> output = stretch2d(input) # shape: (10, 500, 5120)
+ """
+
+ def __init__(self, time_scale: int, freq_scale: int) -> None:
+ super().__init__()
+
+ self.freq_scale = freq_scale
+ self.time_scale = time_scale
+
+ def forward(self, specgram: Tensor) -> Tensor:
+ r"""Pass the input through the Stretch2d layer.
+
+ Args:
+ specgram (Tensor): the input sequence to the Stretch2d layer (..., n_freq, n_time).
+
+ Return:
+ Tensor shape: (..., n_freq * freq_scale, n_time * time_scale)
+ """
+
+ return specgram.repeat_interleave(self.freq_scale, -2).repeat_interleave(self.time_scale, -1)
+
+
+class UpsampleNetwork(nn.Module):
+ r"""Upscale the dimensions of a spectrogram.
+
+ Args:
+ upsample_scales: the list of upsample scales.
+ n_res_block: the number of ResBlock in stack. (Default: ``10``)
+ n_freq: the number of bins in a spectrogram. (Default: ``128``)
+ n_hidden: the number of hidden dimensions of resblock. (Default: ``128``)
+ n_output: the number of output dimensions of melresnet. (Default: ``128``)
+ kernel_size: the kernel size of the first Conv1d layer. (Default: ``5``)
+
+ Examples
+ >>> upsamplenetwork = UpsampleNetwork(upsample_scales=[4, 4, 16])
+ >>> input = torch.rand(10, 128, 10) # a random spectrogram
+ >>> output = upsamplenetwork(input) # shape: (10, 128, 1536), (10, 128, 1536)
+ """
+
+ def __init__(
+ self,
+ upsample_scales: List[int],
+ n_res_block: int = 10,
+ n_freq: int = 128,
+ n_hidden: int = 128,
+ n_output: int = 128,
+ kernel_size: int = 5,
+ ) -> None:
+ super().__init__()
+
+ total_scale = 1
+ for upsample_scale in upsample_scales:
+ total_scale *= upsample_scale
+ self.total_scale: int = total_scale
+
+ self.indent = (kernel_size - 1) // 2 * total_scale
+ self.resnet = MelResNet(n_res_block, n_freq, n_hidden, n_output, kernel_size)
+ self.resnet_stretch = Stretch2d(total_scale, 1)
+
+ up_layers = []
+ for scale in upsample_scales:
+ stretch = Stretch2d(scale, 1)
+ conv = nn.Conv2d(
+ in_channels=1, out_channels=1, kernel_size=(1, scale * 2 + 1), padding=(0, scale), bias=False
+ )
+ torch.nn.init.constant_(conv.weight, 1.0 / (scale * 2 + 1))
+ up_layers.append(stretch)
+ up_layers.append(conv)
+ self.upsample_layers = nn.Sequential(*up_layers)
+
+ def forward(self, specgram: Tensor) -> Tuple[Tensor, Tensor]:
+ r"""Pass the input through the UpsampleNetwork layer.
+
+ Args:
+ specgram (Tensor): the input sequence to the UpsampleNetwork layer (n_batch, n_freq, n_time)
+
+ Return:
+ Tensor shape: (n_batch, n_freq, (n_time - kernel_size + 1) * total_scale),
+ (n_batch, n_output, (n_time - kernel_size + 1) * total_scale)
+ where total_scale is the product of all elements in upsample_scales.
+ """
+
+ resnet_output = self.resnet(specgram).unsqueeze(1)
+ resnet_output = self.resnet_stretch(resnet_output)
+ resnet_output = resnet_output.squeeze(1)
+
+ specgram = specgram.unsqueeze(1)
+ upsampling_output = self.upsample_layers(specgram)
+ upsampling_output = upsampling_output.squeeze(1)[:, :, self.indent : -self.indent]
+
+ return upsampling_output, resnet_output
+
+
+class WaveRNN(nn.Module):
+ r"""WaveRNN model from *Efficient Neural Audio Synthesis* :cite:`wavernn`
+ based on the implementation from `fatchord/WaveRNN `_.
+
+ The original implementation was introduced in *Efficient Neural Audio Synthesis*
+ :cite:`kalchbrenner2018efficient`. The input channels of waveform and spectrogram have to be 1.
+ The product of `upsample_scales` must equal `hop_length`.
+
+ See Also:
+ * `Training example `__
+ * :class:`torchaudio.pipelines.Tacotron2TTSBundle`: TTS pipeline with pretrained model.
+
+ Args:
+ upsample_scales: the list of upsample scales.
+ n_classes: the number of output classes.
+ hop_length: the number of samples between the starts of consecutive frames.
+ n_res_block: the number of ResBlock in stack. (Default: ``10``)
+ n_rnn: the dimension of RNN layer. (Default: ``512``)
+ n_fc: the dimension of fully connected layer. (Default: ``512``)
+ kernel_size: the kernel size of the first Conv1d layer. (Default: ``5``)
+ n_freq: the number of bins in a spectrogram. (Default: ``128``)
+ n_hidden: the number of hidden dimensions of resblock. (Default: ``128``)
+ n_output: the number of output dimensions of melresnet. (Default: ``128``)
+
+ Example
+ >>> wavernn = WaveRNN(upsample_scales=[5,5,8], n_classes=512, hop_length=200)
+ >>> waveform, sample_rate = torchaudio.load(file)
+ >>> # waveform shape: (n_batch, n_channel, (n_time - kernel_size + 1) * hop_length)
+ >>> specgram = MelSpectrogram(sample_rate)(waveform) # shape: (n_batch, n_channel, n_freq, n_time)
+ >>> output = wavernn(waveform, specgram)
+ >>> # output shape: (n_batch, n_channel, (n_time - kernel_size + 1) * hop_length, n_classes)
+ """
+
+ def __init__(
+ self,
+ upsample_scales: List[int],
+ n_classes: int,
+ hop_length: int,
+ n_res_block: int = 10,
+ n_rnn: int = 512,
+ n_fc: int = 512,
+ kernel_size: int = 5,
+ n_freq: int = 128,
+ n_hidden: int = 128,
+ n_output: int = 128,
+ ) -> None:
+ super().__init__()
+
+ self.kernel_size = kernel_size
+ self._pad = (kernel_size - 1 if kernel_size % 2 else kernel_size) // 2
+ self.n_rnn = n_rnn
+ self.n_aux = n_output // 4
+ self.hop_length = hop_length
+ self.n_classes = n_classes
+ self.n_bits: int = int(math.log2(self.n_classes))
+
+ total_scale = 1
+ for upsample_scale in upsample_scales:
+ total_scale *= upsample_scale
+ if total_scale != self.hop_length:
+ raise ValueError(f"Expected: total_scale == hop_length, but found {total_scale} != {hop_length}")
+
+ self.upsample = UpsampleNetwork(upsample_scales, n_res_block, n_freq, n_hidden, n_output, kernel_size)
+ self.fc = nn.Linear(n_freq + self.n_aux + 1, n_rnn)
+
+ self.rnn1 = nn.GRU(n_rnn, n_rnn, batch_first=True)
+ self.rnn2 = nn.GRU(n_rnn + self.n_aux, n_rnn, batch_first=True)
+
+ self.relu1 = nn.ReLU(inplace=True)
+ self.relu2 = nn.ReLU(inplace=True)
+
+ self.fc1 = nn.Linear(n_rnn + self.n_aux, n_fc)
+ self.fc2 = nn.Linear(n_fc + self.n_aux, n_fc)
+ self.fc3 = nn.Linear(n_fc, self.n_classes)
+
+ def forward(self, waveform: Tensor, specgram: Tensor) -> Tensor:
+ r"""Pass the input through the WaveRNN model.
+
+ Args:
+ waveform: the input waveform to the WaveRNN layer (n_batch, 1, (n_time - kernel_size + 1) * hop_length)
+ specgram: the input spectrogram to the WaveRNN layer (n_batch, 1, n_freq, n_time)
+
+ Return:
+ Tensor: shape (n_batch, 1, (n_time - kernel_size + 1) * hop_length, n_classes)
+ """
+
+ if waveform.size(1) != 1:
+ raise ValueError("Require the input channel of waveform is 1")
+ if specgram.size(1) != 1:
+ raise ValueError("Require the input channel of specgram is 1")
+ # remove channel dimension until the end
+ waveform, specgram = waveform.squeeze(1), specgram.squeeze(1)
+
+ batch_size = waveform.size(0)
+ h1 = torch.zeros(1, batch_size, self.n_rnn, dtype=waveform.dtype, device=waveform.device)
+ h2 = torch.zeros(1, batch_size, self.n_rnn, dtype=waveform.dtype, device=waveform.device)
+ # output of upsample:
+ # specgram: (n_batch, n_freq, (n_time - kernel_size + 1) * total_scale)
+ # aux: (n_batch, n_output, (n_time - kernel_size + 1) * total_scale)
+ specgram, aux = self.upsample(specgram)
+ specgram = specgram.transpose(1, 2)
+ aux = aux.transpose(1, 2)
+
+ aux_idx = [self.n_aux * i for i in range(5)]
+ a1 = aux[:, :, aux_idx[0] : aux_idx[1]]
+ a2 = aux[:, :, aux_idx[1] : aux_idx[2]]
+ a3 = aux[:, :, aux_idx[2] : aux_idx[3]]
+ a4 = aux[:, :, aux_idx[3] : aux_idx[4]]
+
+ x = torch.cat([waveform.unsqueeze(-1), specgram, a1], dim=-1)
+ x = self.fc(x)
+ res = x
+ x, _ = self.rnn1(x, h1)
+
+ x = x + res
+ res = x
+ x = torch.cat([x, a2], dim=-1)
+ x, _ = self.rnn2(x, h2)
+
+ x = x + res
+ x = torch.cat([x, a3], dim=-1)
+ x = self.fc1(x)
+ x = self.relu1(x)
+
+ x = torch.cat([x, a4], dim=-1)
+ x = self.fc2(x)
+ x = self.relu2(x)
+ x = self.fc3(x)
+
+ # bring back channel dimension
+ return x.unsqueeze(1)
+
+ @torch.jit.export
+ def infer(self, specgram: Tensor, lengths: Optional[Tensor] = None) -> Tuple[Tensor, Optional[Tensor]]:
+ r"""Inference method of WaveRNN.
+
+ This function currently only supports multinomial sampling, which assumes the
+ network is trained on cross entropy loss.
+
+ Args:
+ specgram (Tensor):
+ Batch of spectrograms. Shape: `(n_batch, n_freq, n_time)`.
+ lengths (Tensor or None, optional):
+ Indicates the valid length of each audio in the batch.
+ Shape: `(batch, )`.
+ When the ``specgram`` contains spectrograms with different durations,
+ by providing ``lengths`` argument, the model will compute
+ the corresponding valid output lengths.
+ If ``None``, it is assumed that all the audio in ``specgram``
+ has valid length. Default: ``None``.
+
+ Returns:
+ (Tensor, Optional[Tensor]):
+ Tensor
+ The inferred waveform of size `(n_batch, 1, n_time)`.
+ 1 stands for a single channel.
+ Tensor or None
+ If ``lengths`` argument was provided, a Tensor of shape `(batch, )`
+ is returned.
+ It indicates the valid length in time axis of the output Tensor.
+ """
+
+ device = specgram.device
+ dtype = specgram.dtype
+
+ specgram = torch.nn.functional.pad(specgram, (self._pad, self._pad))
+ specgram, aux = self.upsample(specgram)
+ if lengths is not None:
+ lengths = lengths * self.upsample.total_scale
+
+ output: List[Tensor] = []
+ b_size, _, seq_len = specgram.size()
+
+ h1 = torch.zeros((1, b_size, self.n_rnn), device=device, dtype=dtype)
+ h2 = torch.zeros((1, b_size, self.n_rnn), device=device, dtype=dtype)
+ x = torch.zeros((b_size, 1), device=device, dtype=dtype)
+
+ aux_split = [aux[:, self.n_aux * i : self.n_aux * (i + 1), :] for i in range(4)]
+
+ for i in range(seq_len):
+
+ m_t = specgram[:, :, i]
+
+ a1_t, a2_t, a3_t, a4_t = [a[:, :, i] for a in aux_split]
+
+ x = torch.cat([x, m_t, a1_t], dim=1)
+ x = self.fc(x)
+ _, h1 = self.rnn1(x.unsqueeze(1), h1)
+
+ x = x + h1[0]
+ inp = torch.cat([x, a2_t], dim=1)
+ _, h2 = self.rnn2(inp.unsqueeze(1), h2)
+
+ x = x + h2[0]
+ x = torch.cat([x, a3_t], dim=1)
+ x = F.relu(self.fc1(x))
+
+ x = torch.cat([x, a4_t], dim=1)
+ x = F.relu(self.fc2(x))
+
+ logits = self.fc3(x)
+
+ posterior = F.softmax(logits, dim=1)
+
+ x = torch.multinomial(posterior, 1).float()
+ # Transform label [0, 2 ** n_bits - 1] to waveform [-1, 1]
+ x = 2 * x / (2**self.n_bits - 1.0) - 1.0
+
+ output.append(x)
+
+ return torch.stack(output).permute(1, 2, 0), lengths
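+
+
+# Illustrative inference sketch, not part of the module: generate a waveform from
+# a random mel-like spectrogram. The hyperparameters below are made up; they only
+# need to satisfy the documented constraint that the product of ``upsample_scales``
+# equals ``hop_length``.
+def _example_wavernn_infer():
+    model = WaveRNN(upsample_scales=[5, 5, 8], n_classes=512, hop_length=200)
+    model.eval()
+    specgram = torch.rand(1, 128, 10)  # (n_batch, n_freq, n_time)
+    with torch.no_grad():
+        waveform, _ = model.infer(specgram)  # autoregressive, one sample per step
+    return waveform.shape  # torch.Size([1, 1, 2000]) == (n_batch, 1, n_time * hop_length)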
diff --git a/MLPY/Lib/site-packages/torchaudio/pipelines/__init__.py b/MLPY/Lib/site-packages/torchaudio/pipelines/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..cad3f14dfb5b27839d0954959428732acedb8a2a
--- /dev/null
+++ b/MLPY/Lib/site-packages/torchaudio/pipelines/__init__.py
@@ -0,0 +1,102 @@
+from ._source_separation_pipeline import (
+ CONVTASNET_BASE_LIBRI2MIX,
+ HDEMUCS_HIGH_MUSDB,
+ HDEMUCS_HIGH_MUSDB_PLUS,
+ SourceSeparationBundle,
+)
+from ._squim_pipeline import SQUIM_OBJECTIVE, SQUIM_SUBJECTIVE, SquimObjectiveBundle, SquimSubjectiveBundle
+from ._tts import (
+ TACOTRON2_GRIFFINLIM_CHAR_LJSPEECH,
+ TACOTRON2_GRIFFINLIM_PHONE_LJSPEECH,
+ TACOTRON2_WAVERNN_CHAR_LJSPEECH,
+ TACOTRON2_WAVERNN_PHONE_LJSPEECH,
+ Tacotron2TTSBundle,
+)
+from ._wav2vec2.impl import (
+ HUBERT_ASR_LARGE,
+ HUBERT_ASR_XLARGE,
+ HUBERT_BASE,
+ HUBERT_LARGE,
+ HUBERT_XLARGE,
+ MMS_FA,
+ VOXPOPULI_ASR_BASE_10K_DE,
+ VOXPOPULI_ASR_BASE_10K_EN,
+ VOXPOPULI_ASR_BASE_10K_ES,
+ VOXPOPULI_ASR_BASE_10K_FR,
+ VOXPOPULI_ASR_BASE_10K_IT,
+ WAV2VEC2_ASR_BASE_100H,
+ WAV2VEC2_ASR_BASE_10M,
+ WAV2VEC2_ASR_BASE_960H,
+ WAV2VEC2_ASR_LARGE_100H,
+ WAV2VEC2_ASR_LARGE_10M,
+ WAV2VEC2_ASR_LARGE_960H,
+ WAV2VEC2_ASR_LARGE_LV60K_100H,
+ WAV2VEC2_ASR_LARGE_LV60K_10M,
+ WAV2VEC2_ASR_LARGE_LV60K_960H,
+ WAV2VEC2_BASE,
+ WAV2VEC2_LARGE,
+ WAV2VEC2_LARGE_LV60K,
+ WAV2VEC2_XLSR53,
+ WAV2VEC2_XLSR_1B,
+ WAV2VEC2_XLSR_2B,
+ WAV2VEC2_XLSR_300M,
+ Wav2Vec2ASRBundle,
+ Wav2Vec2Bundle,
+ Wav2Vec2FABundle,
+ WAVLM_BASE,
+ WAVLM_BASE_PLUS,
+ WAVLM_LARGE,
+)
+from .rnnt_pipeline import EMFORMER_RNNT_BASE_LIBRISPEECH, RNNTBundle
+
+
+__all__ = [
+ "Wav2Vec2Bundle",
+ "Wav2Vec2ASRBundle",
+ "Wav2Vec2FABundle",
+ "WAV2VEC2_BASE",
+ "WAV2VEC2_LARGE",
+ "WAV2VEC2_LARGE_LV60K",
+ "WAV2VEC2_ASR_BASE_10M",
+ "WAV2VEC2_ASR_BASE_100H",
+ "WAV2VEC2_ASR_BASE_960H",
+ "WAV2VEC2_ASR_LARGE_10M",
+ "WAV2VEC2_ASR_LARGE_100H",
+ "WAV2VEC2_ASR_LARGE_960H",
+ "WAV2VEC2_ASR_LARGE_LV60K_10M",
+ "WAV2VEC2_ASR_LARGE_LV60K_100H",
+ "WAV2VEC2_ASR_LARGE_LV60K_960H",
+ "WAV2VEC2_XLSR53",
+ "WAV2VEC2_XLSR_300M",
+ "WAV2VEC2_XLSR_1B",
+ "WAV2VEC2_XLSR_2B",
+ "VOXPOPULI_ASR_BASE_10K_EN",
+ "VOXPOPULI_ASR_BASE_10K_ES",
+ "VOXPOPULI_ASR_BASE_10K_DE",
+ "VOXPOPULI_ASR_BASE_10K_FR",
+ "VOXPOPULI_ASR_BASE_10K_IT",
+ "HUBERT_BASE",
+ "HUBERT_LARGE",
+ "HUBERT_XLARGE",
+ "HUBERT_ASR_LARGE",
+ "HUBERT_ASR_XLARGE",
+ "MMS_FA",
+ "WAVLM_BASE",
+ "WAVLM_BASE_PLUS",
+ "WAVLM_LARGE",
+ "Tacotron2TTSBundle",
+ "TACOTRON2_GRIFFINLIM_CHAR_LJSPEECH",
+ "TACOTRON2_GRIFFINLIM_PHONE_LJSPEECH",
+ "TACOTRON2_WAVERNN_CHAR_LJSPEECH",
+ "TACOTRON2_WAVERNN_PHONE_LJSPEECH",
+ "RNNTBundle",
+ "EMFORMER_RNNT_BASE_LIBRISPEECH",
+ "SourceSeparationBundle",
+ "CONVTASNET_BASE_LIBRI2MIX",
+ "HDEMUCS_HIGH_MUSDB_PLUS",
+ "HDEMUCS_HIGH_MUSDB",
+ "SQUIM_OBJECTIVE",
+ "SQUIM_SUBJECTIVE",
+ "SquimObjectiveBundle",
+ "SquimSubjectiveBundle",
+]
diff --git a/MLPY/Lib/site-packages/torchaudio/pipelines/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torchaudio/pipelines/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6426d3923bda9fd1f427e3aba4cc112437804252
Binary files /dev/null and b/MLPY/Lib/site-packages/torchaudio/pipelines/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torchaudio/pipelines/__pycache__/_source_separation_pipeline.cpython-39.pyc b/MLPY/Lib/site-packages/torchaudio/pipelines/__pycache__/_source_separation_pipeline.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..717d4f7fd5e84f2f3b69dae81aeafc675c9ce855
Binary files /dev/null and b/MLPY/Lib/site-packages/torchaudio/pipelines/__pycache__/_source_separation_pipeline.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torchaudio/pipelines/__pycache__/_squim_pipeline.cpython-39.pyc b/MLPY/Lib/site-packages/torchaudio/pipelines/__pycache__/_squim_pipeline.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a3315c06d4f3a73de18a707f14f5941caa4fa521
Binary files /dev/null and b/MLPY/Lib/site-packages/torchaudio/pipelines/__pycache__/_squim_pipeline.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torchaudio/pipelines/__pycache__/rnnt_pipeline.cpython-39.pyc b/MLPY/Lib/site-packages/torchaudio/pipelines/__pycache__/rnnt_pipeline.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..48ebcda54b4b997cca261400ffbc269ec9050132
Binary files /dev/null and b/MLPY/Lib/site-packages/torchaudio/pipelines/__pycache__/rnnt_pipeline.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torchaudio/pipelines/_source_separation_pipeline.py b/MLPY/Lib/site-packages/torchaudio/pipelines/_source_separation_pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..eebf190bd2233bd65143ba4b8b0da0ba5f1c6eba
--- /dev/null
+++ b/MLPY/Lib/site-packages/torchaudio/pipelines/_source_separation_pipeline.py
@@ -0,0 +1,109 @@
+from dataclasses import dataclass
+from functools import partial
+from typing import Callable
+
+import torch
+import torchaudio
+
+from torchaudio.models import conv_tasnet_base, hdemucs_high
+
+
+@dataclass
+class SourceSeparationBundle:
+ """Dataclass that bundles components for performing source separation.
+
+ Example
+ >>> import torchaudio
+ >>> from torchaudio.pipelines import CONVTASNET_BASE_LIBRI2MIX
+ >>> import torch
+ >>>
+ >>> # Build the separation model.
+ >>> model = CONVTASNET_BASE_LIBRI2MIX.get_model()
+ >>> 100%|███████████████████████████████|19.1M/19.1M [00:04<00:00, 4.93MB/s]
+ >>>
+ >>> # Instantiate the test set of Libri2Mix dataset.
+ >>> dataset = torchaudio.datasets.LibriMix("/home/datasets/", subset="test")
+ >>>
+ >>> # Apply source separation on mixture audio.
+ >>> for i, data in enumerate(dataset):
+ >>> sample_rate, mixture, clean_sources = data
+ >>> # Make sure the shape of input suits the model requirement.
+ >>> mixture = mixture.reshape(1, 1, -1)
+ >>> estimated_sources = model(mixture)
+ >>> score = si_snr_pit(estimated_sources, clean_sources) # for demonstration
+ >>> print(f"Si-SNR score is : {score}.)
+ >>> break
+ >>> Si-SNR score is : 16.24.
+ >>>
+ """
+
+ _model_path: str
+ _model_factory_func: Callable[[], torch.nn.Module]
+ _sample_rate: int
+
+ @property
+ def sample_rate(self) -> int:
+ """Sample rate of the audio that the model is trained on.
+
+ :type: int
+ """
+ return self._sample_rate
+
+ def get_model(self) -> torch.nn.Module:
+ """Construct the model and load the pretrained weight."""
+ model = self._model_factory_func()
+ path = torchaudio.utils.download_asset(self._model_path)
+ state_dict = torch.load(path)
+ model.load_state_dict(state_dict)
+ model.eval()
+ return model
+
+
+CONVTASNET_BASE_LIBRI2MIX = SourceSeparationBundle(
+ _model_path="models/conv_tasnet_base_libri2mix.pt",
+ _model_factory_func=partial(conv_tasnet_base, num_sources=2),
+ _sample_rate=8000,
+)
+CONVTASNET_BASE_LIBRI2MIX.__doc__ = """Pre-trained Source Separation pipeline with *ConvTasNet*
+:cite:`Luo_2019` trained on *Libri2Mix dataset* :cite:`cosentino2020librimix`.
+
+The source separation model is constructed by :func:`~torchaudio.models.conv_tasnet_base`
+and is trained using the training script ``lightning_train.py``
+`here `__
+with default arguments.
+
+Please refer to :class:`SourceSeparationBundle` for usage instructions.
+"""
+
+
+HDEMUCS_HIGH_MUSDB_PLUS = SourceSeparationBundle(
+ _model_path="models/hdemucs_high_trained.pt",
+ _model_factory_func=partial(hdemucs_high, sources=["drums", "bass", "other", "vocals"]),
+ _sample_rate=44100,
+)
+HDEMUCS_HIGH_MUSDB_PLUS.__doc__ = """Pre-trained music source separation pipeline with
+*Hybrid Demucs* :cite:`defossez2021hybrid` trained on both training and test sets of
+MUSDB-HQ :cite:`MUSDB18HQ` and 150 additional songs from an internal database
+that was produced specifically for Meta.
+
+The model is constructed by :func:`~torchaudio.models.hdemucs_high`.
+
+Training was performed in the original HDemucs repository `here `__.
+
+Please refer to :class:`SourceSeparationBundle` for usage instructions.
+"""
+
+
+HDEMUCS_HIGH_MUSDB = SourceSeparationBundle(
+ _model_path="models/hdemucs_high_musdbhq_only.pt",
+ _model_factory_func=partial(hdemucs_high, sources=["drums", "bass", "other", "vocals"]),
+ _sample_rate=44100,
+)
+HDEMUCS_HIGH_MUSDB.__doc__ = """Pre-trained music source separation pipeline with
+*Hybrid Demucs* :cite:`defossez2021hybrid` trained on the training set of MUSDB-HQ :cite:`MUSDB18HQ`.
+
+The model is constructed by :func:`~torchaudio.models.hdemucs_high`.
+Training was performed in the original HDemucs repository `here `__.
+
+Please refer to :class:`SourceSeparationBundle` for usage instructions.
+"""
diff --git a/MLPY/Lib/site-packages/torchaudio/pipelines/_squim_pipeline.py b/MLPY/Lib/site-packages/torchaudio/pipelines/_squim_pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..a80731cd3ba5488a785fd99f9cbe5025b63b046a
--- /dev/null
+++ b/MLPY/Lib/site-packages/torchaudio/pipelines/_squim_pipeline.py
@@ -0,0 +1,176 @@
+from dataclasses import dataclass
+
+from torchaudio._internal import load_state_dict_from_url
+
+from torchaudio.models import squim_objective_base, squim_subjective_base, SquimObjective, SquimSubjective
+
+
+@dataclass
+class SquimObjectiveBundle:
+ """Data class that bundles associated information to use pretrained
+ :py:class:`~torchaudio.models.SquimObjective` model.
+
+ This class provides interfaces for instantiating the pretrained model along with
+ the information necessary to retrieve pretrained weights and additional data
+ to be used with the model.
+
+ Torchaudio library instantiates objects of this class, each of which represents
+ a different pretrained model. Client code should access pretrained models via these
+ instances.
+
+ This bundle can estimate objective metric scores for speech enhancement, such as STOI, PESQ, Si-SDR.
+ A typical use case would be a flow like `waveform -> list of scores`. Please see below for the code example.
+
+ Example: Estimate the objective metric scores for the input waveform.
+ >>> import torch
+ >>> import torchaudio
+ >>> from torchaudio.pipelines import SQUIM_OBJECTIVE as bundle
+ >>>
+ >>> # Load the SquimObjective bundle
+ >>> model = bundle.get_model()
+ Downloading: "https://download.pytorch.org/torchaudio/models/squim_objective_dns2020.pth"
+ 100%|████████████| 28.2M/28.2M [00:03<00:00, 9.24MB/s]
+ >>>
+ >>> # Resample audio to the expected sampling rate
+ >>> waveform = torchaudio.functional.resample(waveform, sample_rate, bundle.sample_rate)
+ >>>
+ >>> # Estimate objective metric scores
+ >>> scores = model(waveform)
+ >>> print(f"STOI: {scores[0].item()}, PESQ: {scores[1].item()}, SI-SDR: {scores[2].item()}.")
+ """ # noqa: E501
+
+ _path: str
+ _sample_rate: float
+
+ def _get_state_dict(self, dl_kwargs):
+ url = f"https://download.pytorch.org/torchaudio/models/{self._path}"
+ dl_kwargs = {} if dl_kwargs is None else dl_kwargs
+ state_dict = load_state_dict_from_url(url, **dl_kwargs)
+ return state_dict
+
+ def get_model(self, *, dl_kwargs=None) -> SquimObjective:
+ """Construct the SquimObjective model, and load the pretrained weight.
+
+ The weight file is downloaded from the internet and cached with
+ :func:`torch.hub.load_state_dict_from_url`
+
+ Args:
+ dl_kwargs (dictionary of keyword arguments): Passed to :func:`torch.hub.load_state_dict_from_url`.
+
+ Returns:
+ Variation of :py:class:`~torchaudio.models.SquimObjective`.
+ """
+ model = squim_objective_base()
+ model.load_state_dict(self._get_state_dict(dl_kwargs))
+ model.eval()
+ return model
+
+ @property
+ def sample_rate(self):
+ """Sample rate of the audio that the model is trained on.
+
+ :type: float
+ """
+ return self._sample_rate
+
+
+SQUIM_OBJECTIVE = SquimObjectiveBundle(
+ "squim_objective_dns2020.pth",
+ _sample_rate=16000,
+)
+SQUIM_OBJECTIVE.__doc__ = """SquimObjective pipeline trained using the approach described in
+ :cite:`kumar2023torchaudio` on the *DNS 2020 Dataset* :cite:`reddy2020interspeech`.
+
+ The underlying model is constructed by :py:func:`torchaudio.models.squim_objective_base`.
+ The weights are under `Creative Commons Attribution 4.0 International License
+ `__.
+
+ Please refer to :py:class:`SquimObjectiveBundle` for usage instructions.
+ """
+
+
+@dataclass
+class SquimSubjectiveBundle:
+ """Data class that bundles associated information to use pretrained
+ :py:class:`~torchaudio.models.SquimSubjective` model.
+
+ This class provides interfaces for instantiating the pretrained model along with
+ the information necessary to retrieve pretrained weights and additional data
+ to be used with the model.
+
+ Torchaudio library instantiates objects of this class, each of which represents
+ a different pretrained model. Client code should access pretrained models via these
+ instances.
+
+ This bundle can estimate subjective metric scores for speech enhancement, such as MOS.
+ A typical use case would be a flow like `waveform -> score`. Please see below for the code example.
+
+ Example: Estimate the subjective metric scores for the input waveform.
+ >>> import torch
+ >>> import torchaudio
+ >>> from torchaudio.pipelines import SQUIM_SUBJECTIVE as bundle
+ >>>
+ >>> # Load the SquimSubjective bundle
+ >>> model = bundle.get_model()
+ Downloading: "https://download.pytorch.org/torchaudio/models/squim_subjective_bvcc_daps.pth"
+ 100%|████████████| 360M/360M [00:09<00:00, 41.1MB/s]
+ >>>
+ >>> # Resample audio to the expected sampling rate
+ >>> waveform = torchaudio.functional.resample(waveform, sample_rate, bundle.sample_rate)
+ >>> # Use a clean reference (doesn't need to be the reference for the waveform) as the second input
+ >>> reference = torchaudio.functional.resample(reference, sample_rate, bundle.sample_rate)
+ >>>
+ >>> # Estimate subjective metric scores
+ >>> score = model(waveform, reference)
+ >>> print(f"MOS: {score}.")
+ """ # noqa: E501
+
+ _path: str
+ _sample_rate: float
+
+ def _get_state_dict(self, dl_kwargs):
+ url = f"https://download.pytorch.org/torchaudio/models/{self._path}"
+ dl_kwargs = {} if dl_kwargs is None else dl_kwargs
+ state_dict = load_state_dict_from_url(url, **dl_kwargs)
+ return state_dict
+
+ def get_model(self, *, dl_kwargs=None) -> SquimSubjective:
+ """Construct the SquimSubjective model, and load the pretrained weight.
+
+ The weight file is downloaded from the internet and cached with
+ :func:`torch.hub.load_state_dict_from_url`
+
+ Args:
+ dl_kwargs (dictionary of keyword arguments): Passed to :func:`torch.hub.load_state_dict_from_url`.
+
+ Returns:
+ Variation of :py:class:`~torchaudio.models.SquimSubjective`.
+ """
+ model = squim_subjective_base()
+ model.load_state_dict(self._get_state_dict(dl_kwargs))
+ model.eval()
+ return model
+
+ @property
+ def sample_rate(self):
+ """Sample rate of the audio that the model is trained on.
+
+ :type: float
+ """
+ return self._sample_rate
+
+
+SQUIM_SUBJECTIVE = SquimSubjectiveBundle(
+ "squim_subjective_bvcc_daps.pth",
+ _sample_rate=16000,
+)
+SQUIM_SUBJECTIVE.__doc__ = """SquimSubjective pipeline trained
+ as described in :cite:`manocha2022speech` and :cite:`kumar2023torchaudio`
+ on the *BVCC* :cite:`cooper2021voices` and *DAPS* :cite:`mysore2014can` datasets.
+
+ The underlying model is constructed by :py:func:`torchaudio.models.squim_subjective_base`.
+ The weights are under `Creative Commons Attribution Non Commercial 4.0 International
+ `__.
+
+ Please refer to :py:class:`SquimSubjectiveBundle` for usage instructions.
+ """
diff --git a/MLPY/Lib/site-packages/torchaudio/pipelines/_tts/__init__.py b/MLPY/Lib/site-packages/torchaudio/pipelines/_tts/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..488d121d458f65454bab2719f873c10262e1aac9
--- /dev/null
+++ b/MLPY/Lib/site-packages/torchaudio/pipelines/_tts/__init__.py
@@ -0,0 +1,16 @@
+from .impl import (
+ TACOTRON2_GRIFFINLIM_CHAR_LJSPEECH,
+ TACOTRON2_GRIFFINLIM_PHONE_LJSPEECH,
+ TACOTRON2_WAVERNN_CHAR_LJSPEECH,
+ TACOTRON2_WAVERNN_PHONE_LJSPEECH,
+)
+from .interface import Tacotron2TTSBundle
+
+
+__all__ = [
+ "Tacotron2TTSBundle",
+ "TACOTRON2_GRIFFINLIM_CHAR_LJSPEECH",
+ "TACOTRON2_GRIFFINLIM_PHONE_LJSPEECH",
+ "TACOTRON2_WAVERNN_CHAR_LJSPEECH",
+ "TACOTRON2_WAVERNN_PHONE_LJSPEECH",
+]
diff --git a/MLPY/Lib/site-packages/torchaudio/pipelines/_tts/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torchaudio/pipelines/_tts/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b7265bc9ec526d96728d1e2e7bb1baf6a39ab010
Binary files /dev/null and b/MLPY/Lib/site-packages/torchaudio/pipelines/_tts/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torchaudio/pipelines/_tts/__pycache__/impl.cpython-39.pyc b/MLPY/Lib/site-packages/torchaudio/pipelines/_tts/__pycache__/impl.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..42ac011d3e1832483af4c625d73f606ee7d21669
Binary files /dev/null and b/MLPY/Lib/site-packages/torchaudio/pipelines/_tts/__pycache__/impl.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torchaudio/pipelines/_tts/__pycache__/interface.cpython-39.pyc b/MLPY/Lib/site-packages/torchaudio/pipelines/_tts/__pycache__/interface.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e5e13047828bb790aba22f0ab9bbe15c03d40158
Binary files /dev/null and b/MLPY/Lib/site-packages/torchaudio/pipelines/_tts/__pycache__/interface.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torchaudio/pipelines/_tts/__pycache__/utils.cpython-39.pyc b/MLPY/Lib/site-packages/torchaudio/pipelines/_tts/__pycache__/utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..00d2bd3f605e705cbf84b0de6c306dca769688d5
Binary files /dev/null and b/MLPY/Lib/site-packages/torchaudio/pipelines/_tts/__pycache__/utils.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torchaudio/pipelines/_tts/impl.py b/MLPY/Lib/site-packages/torchaudio/pipelines/_tts/impl.py
new file mode 100644
index 0000000000000000000000000000000000000000..2b8ac89c4a940128e406a743d023b53835645c95
--- /dev/null
+++ b/MLPY/Lib/site-packages/torchaudio/pipelines/_tts/impl.py
@@ -0,0 +1,385 @@
+import re
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+from torch import Tensor
+from torchaudio._internal import load_state_dict_from_url
+from torchaudio.functional import mu_law_decoding
+from torchaudio.models import Tacotron2, WaveRNN
+from torchaudio.transforms import GriffinLim, InverseMelScale
+
+from . import utils
+from .interface import Tacotron2TTSBundle
+
+__all__ = []
+
+_BASE_URL = "https://download.pytorch.org/torchaudio/models"
+
+
+################################################################################
+# Pipeline implementation - Text Processor
+################################################################################
+
+
+class _EnglishCharProcessor(Tacotron2TTSBundle.TextProcessor):
+ def __init__(self):
+ super().__init__()
+ self._tokens = utils._get_chars()
+ self._mapping = {s: i for i, s in enumerate(self._tokens)}
+
+ @property
+ def tokens(self):
+ return self._tokens
+
+ def __call__(self, texts: Union[str, List[str]]) -> Tuple[Tensor, Tensor]:
+ if isinstance(texts, str):
+ texts = [texts]
+ indices = [[self._mapping[c] for c in t.lower() if c in self._mapping] for t in texts]
+ return utils._to_tensor(indices)
+
+
+class _EnglishPhoneProcessor(Tacotron2TTSBundle.TextProcessor):
+ def __init__(self, *, dl_kwargs=None):
+ super().__init__()
+ self._tokens = utils._get_phones()
+ self._mapping = {p: i for i, p in enumerate(self._tokens)}
+ self._phonemizer = utils._load_phonemizer("en_us_cmudict_forward.pt", dl_kwargs=dl_kwargs)
+ self._pattern = r"(\[[A-Z]+?\]|[_!'(),.:;? -])"
+
+ @property
+ def tokens(self):
+ return self._tokens
+
+ def __call__(self, texts: Union[str, List[str]]) -> Tuple[Tensor, Tensor]:
+ if isinstance(texts, str):
+ texts = [texts]
+
+ indices = []
+ for phones in self._phonemizer(texts, lang="en_us"):
+ # '[F][UW][B][AA][R]!' -> ['F', 'UW', 'B', 'AA', 'R', '!']
+ ret = [re.sub(r"[\[\]]", "", r) for r in re.findall(self._pattern, phones)]
+ indices.append([self._mapping[p] for p in ret])
+ return utils._to_tensor(indices)
+
+
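+# Illustrative sketch, not part of the pipeline: what the character processor
+# produces for a small batch. The values are indices into ``utils._get_chars()``
+# and the exact numbers depend on that token table; the shapes are the point here.
+def _example_char_encoding():
+    processor = _EnglishCharProcessor()
+    tokens, lengths = processor(["Hello world!", "TTS"])
+    return tokens.shape, lengths  # expected: torch.Size([2, 12]), tensor([12, 3])
+
+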
+################################################################################
+# Pipeline implementation - Vocoder
+################################################################################
+
+
+class _WaveRNNVocoder(torch.nn.Module, Tacotron2TTSBundle.Vocoder):
+ def __init__(self, model: WaveRNN, min_level_db: Optional[float] = -100):
+ super().__init__()
+ self._sample_rate = 22050
+ self._model = model
+ self._min_level_db = min_level_db
+
+ @property
+ def sample_rate(self):
+ return self._sample_rate
+
+ def forward(self, mel_spec, lengths=None):
+ mel_spec = torch.exp(mel_spec)
+ mel_spec = 20 * torch.log10(torch.clamp(mel_spec, min=1e-5))
+ if self._min_level_db is not None:
+ mel_spec = (self._min_level_db - mel_spec) / self._min_level_db
+ mel_spec = torch.clamp(mel_spec, min=0, max=1)
+ waveform, lengths = self._model.infer(mel_spec, lengths)
+ waveform = utils._unnormalize_waveform(waveform, self._model.n_bits)
+ waveform = mu_law_decoding(waveform, self._model.n_classes)
+ waveform = waveform.squeeze(1)
+ return waveform, lengths
+
+
+class _GriffinLimVocoder(torch.nn.Module, Tacotron2TTSBundle.Vocoder):
+ def __init__(self):
+ super().__init__()
+ self._sample_rate = 22050
+ self._inv_mel = InverseMelScale(
+ n_stft=(1024 // 2 + 1),
+ n_mels=80,
+ sample_rate=self.sample_rate,
+ f_min=0.0,
+ f_max=8000.0,
+ mel_scale="slaney",
+ norm="slaney",
+ )
+ self._griffin_lim = GriffinLim(
+ n_fft=1024,
+ power=1,
+ hop_length=256,
+ win_length=1024,
+ )
+
+ @property
+ def sample_rate(self):
+ return self._sample_rate
+
+ def forward(self, mel_spec, lengths=None):
+ mel_spec = torch.exp(mel_spec)
+ mel_spec = mel_spec.clone().detach().requires_grad_(True)
+ spec = self._inv_mel(mel_spec)
+ spec = spec.detach().requires_grad_(False)
+ waveforms = self._griffin_lim(spec)
+ return waveforms, lengths
+
+
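+# Illustrative sketch, not part of the pipeline: both vocoders accept the
+# log-scale mel spectrogram produced by ``Tacotron2.infer`` and return
+# ``(waveform, lengths)``. A random tensor stands in for a real spectrogram here.
+def _example_griffinlim_vocoder():
+    vocoder = _GriffinLimVocoder()
+    mel_spec = torch.randn(1, 80, 100)  # (batch, n_mels, time), log scale
+    waveform, _ = vocoder(mel_spec, lengths=None)
+    return waveform.shape, vocoder.sample_rate  # sample_rate is 22050
+
+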
+################################################################################
+# Bundle classes mixins
+################################################################################
+
+
+class _CharMixin:
+ def get_text_processor(self) -> Tacotron2TTSBundle.TextProcessor:
+ return _EnglishCharProcessor()
+
+
+class _PhoneMixin:
+ def get_text_processor(self, *, dl_kwargs=None) -> Tacotron2TTSBundle.TextProcessor:
+ return _EnglishPhoneProcessor(dl_kwargs=dl_kwargs)
+
+
+@dataclass
+class _Tacotron2Mixin:
+ _tacotron2_path: str
+ _tacotron2_params: Dict[str, Any]
+
+ def get_tacotron2(self, *, dl_kwargs=None) -> Tacotron2:
+ model = Tacotron2(**self._tacotron2_params)
+ url = f"{_BASE_URL}/{self._tacotron2_path}"
+ dl_kwargs = {} if dl_kwargs is None else dl_kwargs
+ state_dict = load_state_dict_from_url(url, **dl_kwargs)
+ model.load_state_dict(state_dict)
+ model.eval()
+ return model
+
+
+@dataclass
+class _WaveRNNMixin:
+ _wavernn_path: Optional[str]
+ _wavernn_params: Optional[Dict[str, Any]]
+
+ def get_vocoder(self, *, dl_kwargs=None):
+ wavernn = self._get_wavernn(dl_kwargs=dl_kwargs)
+ return _WaveRNNVocoder(wavernn)
+
+ def _get_wavernn(self, *, dl_kwargs=None):
+ model = WaveRNN(**self._wavernn_params)
+ url = f"{_BASE_URL}/{self._wavernn_path}"
+ dl_kwargs = {} if dl_kwargs is None else dl_kwargs
+ state_dict = load_state_dict_from_url(url, **dl_kwargs)
+ model.load_state_dict(state_dict)
+ model.eval()
+ return model
+
+
+class _GriffinLimMixin:
+ def get_vocoder(self, **_):
+ return _GriffinLimVocoder()
+
+
+################################################################################
+# Bundle classes
+################################################################################
+
+
+@dataclass
+class _Tacotron2WaveRNNCharBundle(_WaveRNNMixin, _Tacotron2Mixin, _CharMixin, Tacotron2TTSBundle):
+ pass
+
+
+@dataclass
+class _Tacotron2WaveRNNPhoneBundle(_WaveRNNMixin, _Tacotron2Mixin, _PhoneMixin, Tacotron2TTSBundle):
+ pass
+
+
+@dataclass
+class _Tacotron2GriffinLimCharBundle(_GriffinLimMixin, _Tacotron2Mixin, _CharMixin, Tacotron2TTSBundle):
+ pass
+
+
+@dataclass
+class _Tacotron2GriffinLimPhoneBundle(_GriffinLimMixin, _Tacotron2Mixin, _PhoneMixin, Tacotron2TTSBundle):
+ pass
+
+
+################################################################################
+# Instantiate bundle objects
+################################################################################
+
+
+TACOTRON2_GRIFFINLIM_CHAR_LJSPEECH = _Tacotron2GriffinLimCharBundle(
+ _tacotron2_path="tacotron2_english_characters_1500_epochs_ljspeech.pth",
+ _tacotron2_params=utils._get_taco_params(n_symbols=38),
+)
+TACOTRON2_GRIFFINLIM_CHAR_LJSPEECH.__doc__ = """Character-based TTS pipeline with :py:class:`~torchaudio.models.Tacotron2` trained on *LJSpeech* :cite:`ljspeech17` for 1,500 epochs, and
+:py:class:`~torchaudio.transforms.GriffinLim` as vocoder.
+
+The text processor encodes the input texts character-by-character.
+
+You can find the training script `here `__.
+The default parameters were used.
+
+Please refer to :func:`torchaudio.pipelines.Tacotron2TTSBundle` for the usage.
+
+Example - "Hello world! T T S stands for Text to Speech!"
+
+ .. image:: https://download.pytorch.org/torchaudio/doc-assets/TACOTRON2_GRIFFINLIM_CHAR_LJSPEECH.png
+ :alt: Spectrogram generated by Tacotron2
+
+ .. raw:: html
+
+
+
+Example - "The examination and testimony of the experts enabled the Commission to conclude that five shots may have been fired,"
+
+ .. image:: https://download.pytorch.org/torchaudio/doc-assets/TACOTRON2_GRIFFINLIM_CHAR_LJSPEECH_v2.png
+ :alt: Spectrogram generated by Tacotron2
+
+ .. raw:: html
+
+
+""" # noqa: E501
+
+TACOTRON2_GRIFFINLIM_PHONE_LJSPEECH = _Tacotron2GriffinLimPhoneBundle(
+ _tacotron2_path="tacotron2_english_phonemes_1500_epochs_ljspeech.pth",
+ _tacotron2_params=utils._get_taco_params(n_symbols=96),
+)
+TACOTRON2_GRIFFINLIM_PHONE_LJSPEECH.__doc__ = """Phoneme-based TTS pipeline with :py:class:`~torchaudio.models.Tacotron2` trained on *LJSpeech* :cite:`ljspeech17` for 1,500 epochs and
+:py:class:`~torchaudio.transforms.GriffinLim` as vocoder.
+
+The text processor encodes the input texts based on phonemes.
+It uses `DeepPhonemizer `__ to convert
+graphemes to phonemes.
+The model (*en_us_cmudict_forward*) was trained on
+`CMUDict `__.
+
+You can find the training script `here `__.
+The text processor is set to the *"english_phonemes"*.
+
+Please refer to :func:`torchaudio.pipelines.Tacotron2TTSBundle` for the usage.
+
+Example - "Hello world! T T S stands for Text to Speech!"
+
+ .. image:: https://download.pytorch.org/torchaudio/doc-assets/TACOTRON2_GRIFFINLIM_PHONE_LJSPEECH.png
+ :alt: Spectrogram generated by Tacotron2
+
+ .. raw:: html
+
+
+
+Example - "The examination and testimony of the experts enabled the Commission to conclude that five shots may have been fired,"
+
+ .. image:: https://download.pytorch.org/torchaudio/doc-assets/TACOTRON2_GRIFFINLIM_PHONE_LJSPEECH_v2.png
+ :alt: Spectrogram generated by Tacotron2
+
+ .. raw:: html
+
+
+
+""" # noqa: E501
+
+TACOTRON2_WAVERNN_CHAR_LJSPEECH = _Tacotron2WaveRNNCharBundle(
+ _tacotron2_path="tacotron2_english_characters_1500_epochs_wavernn_ljspeech.pth",
+ _tacotron2_params=utils._get_taco_params(n_symbols=38),
+ _wavernn_path="wavernn_10k_epochs_8bits_ljspeech.pth",
+ _wavernn_params=utils._get_wrnn_params(),
+)
+TACOTRON2_WAVERNN_CHAR_LJSPEECH.__doc__ = """Character-based TTS pipeline with :py:class:`~torchaudio.models.Tacotron2` trained on *LJSpeech* :cite:`ljspeech17` for 1,500 epochs and :py:class:`~torchaudio.models.WaveRNN` vocoder trained on 8-bit waveforms of *LJSpeech* :cite:`ljspeech17` for 10,000 epochs.
+
+The text processor encodes the input texts character-by-character.
+
+You can find the training script for Tacotron2 `here `__.
+The following parameters were used: ``win_length=1100``, ``hop_length=275``, ``n_fft=2048``,
+``mel_fmin=40``, and ``mel_fmax=11025``.
+
+You can find the training script for WaveRNN `here `__.
+
+Please refer to :func:`torchaudio.pipelines.Tacotron2TTSBundle` for the usage.
+
+Example - "Hello world! T T S stands for Text to Speech!"
+
+ .. image:: https://download.pytorch.org/torchaudio/doc-assets/TACOTRON2_WAVERNN_CHAR_LJSPEECH.png
+ :alt: Spectrogram generated by Tacotron2
+
+ .. raw:: html
+
+
+
+Example - "The examination and testimony of the experts enabled the Commission to conclude that five shots may have been fired,"
+
+ .. image:: https://download.pytorch.org/torchaudio/doc-assets/TACOTRON2_WAVERNN_CHAR_LJSPEECH_v2.png
+ :alt: Spectrogram generated by Tacotron2
+
+ .. raw:: html
+
+
+""" # noqa: E501
+
+TACOTRON2_WAVERNN_PHONE_LJSPEECH = _Tacotron2WaveRNNPhoneBundle(
+ _tacotron2_path="tacotron2_english_phonemes_1500_epochs_wavernn_ljspeech.pth",
+ _tacotron2_params=utils._get_taco_params(n_symbols=96),
+ _wavernn_path="wavernn_10k_epochs_8bits_ljspeech.pth",
+ _wavernn_params=utils._get_wrnn_params(),
+)
+TACOTRON2_WAVERNN_PHONE_LJSPEECH.__doc__ = """Phoneme-based TTS pipeline with :py:class:`~torchaudio.models.Tacotron2` trained on *LJSpeech* :cite:`ljspeech17` for 1,500 epochs, and
+:py:class:`~torchaudio.models.WaveRNN` vocoder trained on 8-bit waveforms of *LJSpeech* :cite:`ljspeech17` for 10,000 epochs.
+
+The text processor encodes the input texts based on phonemes.
+It uses `DeepPhonemizer `__ to convert
+graphemes to phonemes.
+The model (*en_us_cmudict_forward*) was trained on
+`CMUDict `__.
+
+You can find the training script for Tacotron2 `here `__.
+The following parameters were used: ``win_length=1100``, ``hop_length=275``, ``n_fft=2048``,
+``mel_fmin=40``, and ``mel_fmax=11025``.
+
+You can find the training script for WaveRNN `here `__.
+
+Please refer to :func:`torchaudio.pipelines.Tacotron2TTSBundle` for the usage.
+
+Example - "Hello world! T T S stands for Text to Speech!"
+
+ .. image:: https://download.pytorch.org/torchaudio/doc-assets/TACOTRON2_WAVERNN_PHONE_LJSPEECH.png
+ :alt: Spectrogram generated by Tacotron2
+
+ .. raw:: html
+
+
+
+
+Example - "The examination and testimony of the experts enabled the Commission to conclude that five shots may have been fired,"
+
+ .. image:: https://download.pytorch.org/torchaudio/doc-assets/TACOTRON2_WAVERNN_PHONE_LJSPEECH_v2.png
+ :alt: Spectrogram generated by Tacotron2
+
+ .. raw:: html
+
+
+""" # noqa: E501
diff --git a/MLPY/Lib/site-packages/torchaudio/pipelines/_tts/interface.py b/MLPY/Lib/site-packages/torchaudio/pipelines/_tts/interface.py
new file mode 100644
index 0000000000000000000000000000000000000000..273dfca2b14877cebc7cdb0716d60440693a775e
--- /dev/null
+++ b/MLPY/Lib/site-packages/torchaudio/pipelines/_tts/interface.py
@@ -0,0 +1,255 @@
+from abc import ABC, abstractmethod
+from typing import List, Optional, Tuple, Union
+
+from torch import Tensor
+from torchaudio.models import Tacotron2
+
+
+class _TextProcessor(ABC):
+ @property
+ @abstractmethod
+ def tokens(self):
+ """The tokens that the each value in the processed tensor represent.
+
+ :type: List[str]
+ """
+
+ @abstractmethod
+ def __call__(self, texts: Union[str, List[str]]) -> Tuple[Tensor, Tensor]:
+ """Encode the given (batch of) texts into numerical tensors
+
+ Args:
+ texts (str or list of str): The input texts.
+
+ Returns:
+ (Tensor, Tensor):
+ Tensor:
+ The encoded texts. Shape: `(batch, max length)`
+ Tensor:
+ The valid length of each sample in the batch. Shape: `(batch, )`.
+ """
+
+
+class _Vocoder(ABC):
+ @property
+ @abstractmethod
+ def sample_rate(self):
+ """The sample rate of the resulting waveform
+
+ :type: float
+ """
+
+ @abstractmethod
+ def __call__(self, specgrams: Tensor, lengths: Optional[Tensor] = None) -> Tuple[Tensor, Optional[Tensor]]:
+ """Generate waveform from the given input, such as spectrogram
+
+ Args:
+ specgrams (Tensor):
+ The input spectrogram. Shape: `(batch, frequency bins, time)`.
+ The expected shape depends on the implementation.
+ lengths (Tensor, or None, optional):
+ The valid length of each sample in the batch. Shape: `(batch, )`.
+ (Default: `None`)
+
+ Returns:
+ (Tensor, Optional[Tensor]):
+ Tensor:
+ The generated waveform. Shape: `(batch, max length)`
+ Tensor or None:
+ The valid length of each sample in the batch. Shape: `(batch, )`.
+ """
+
+
+class Tacotron2TTSBundle(ABC):
+ """Data class that bundles associated information to use pretrained Tacotron2 and vocoder.
+
+ This class provides interfaces for instantiating the pretrained model along with
+ the information necessary to retrieve pretrained weights and additional data
+ to be used with the model.
+
+ Torchaudio library instantiates objects of this class, each of which represents
+ a different pretrained model. Client code should access pretrained models via these
+ instances.
+
+ Please see below for the usage and the available values.
+
+ Example - Character-based TTS pipeline with Tacotron2 and WaveRNN
+ >>> import torchaudio
+ >>>
+ >>> text = "Hello, T T S !"
+ >>> bundle = torchaudio.pipelines.TACOTRON2_WAVERNN_CHAR_LJSPEECH
+ >>>
+ >>> # Build processor, Tacotron2 and WaveRNN model
+ >>> processor = bundle.get_text_processor()
+ >>> tacotron2 = bundle.get_tacotron2()
+ Downloading:
+ 100%|███████████████████████████████| 107M/107M [00:01<00:00, 87.9MB/s]
+ >>> vocoder = bundle.get_vocoder()
+ Downloading:
+ 100%|███████████████████████████████| 16.7M/16.7M [00:00<00:00, 78.1MB/s]
+ >>>
+ >>> # Encode text
+ >>> input, lengths = processor(text)
+ >>>
+ >>> # Generate (mel-scale) spectrogram
+ >>> specgram, lengths, _ = tacotron2.infer(input, lengths)
+ >>>
+ >>> # Convert spectrogram to waveform
+ >>> waveforms, lengths = vocoder(specgram, lengths)
+ >>>
+ >>> torchaudio.save('hello-tts.wav', waveforms, vocoder.sample_rate)
+
+ Example - Phoneme-based TTS pipeline with Tacotron2 and WaveRNN
+ >>>
+ >>> # Note:
+ >>> # This bundle uses pre-trained DeepPhonemizer as
+ >>> # the text pre-processor.
+ >>> # Please install deep-phonemizer.
+ >>> # See https://github.com/as-ideas/DeepPhonemizer
+ >>> # The pretrained weight is automatically downloaded.
+ >>>
+ >>> import torchaudio
+ >>>
+ >>> text = "Hello, TTS!"
+ >>> bundle = torchaudio.pipelines.TACOTRON2_WAVERNN_PHONE_LJSPEECH
+ >>>
+ >>> # Build processor, Tacotron2 and WaveRNN model
+ >>> processor = bundle.get_text_processor()
+ Downloading:
+ 100%|███████████████████████████████| 63.6M/63.6M [00:04<00:00, 15.3MB/s]
+ >>> tacotron2 = bundle.get_tacotron2()
+ Downloading:
+ 100%|███████████████████████████████| 107M/107M [00:01<00:00, 87.9MB/s]
+ >>> vocoder = bundle.get_vocoder()
+ Downloading:
+ 100%|███████████████████████████████| 16.7M/16.7M [00:00<00:00, 78.1MB/s]
+ >>>
+ >>> # Encode text
+ >>> input, lengths = processor(text)
+ >>>
+ >>> # Generate (mel-scale) spectrogram
+ >>> specgram, lengths, _ = tacotron2.infer(input, lengths)
+ >>>
+ >>> # Convert spectrogram to waveform
+ >>> waveforms, lengths = vocoder(specgram, lengths)
+ >>>
+ >>> torchaudio.save('hello-tts.wav', waveforms, vocoder.sample_rate)
+ """
+
+ # Using the inner class so that these interfaces are not directly exposed on
+ # `torchaudio.pipelines`, but still listed in documentation.
+ # Text processing and vocoding are generic, and we do not know what kinds of
+ # new text processors and vocoders will be added in the future, so we keep these
+ # interfaces specific to this Tacotron2TTS pipeline.
+
+ class TextProcessor(_TextProcessor):
+ """Interface of the text processing part of Tacotron2TTS pipeline
+
+ See :func:`torchaudio.pipelines.Tacotron2TTSBundle.get_text_processor` for the usage.
+ """
+
+ class Vocoder(_Vocoder):
+ """Interface of the vocoder part of Tacotron2TTS pipeline
+
+ See :func:`torchaudio.pipelines.Tacotron2TTSBundle.get_vocoder` for the usage.
+ """
+
+ @abstractmethod
+ def get_text_processor(self, *, dl_kwargs=None) -> TextProcessor:
+ """Create a text processor
+
+ For the character-based pipeline, this processor splits the input text by character.
+ For the phoneme-based pipeline, this processor converts the input text (graphemes) to
+ phonemes.
+
+ If a pre-trained weight file is necessary,
+ :func:`torch.hub.download_url_to_file` is used to download it.
+
+ Args:
+ dl_kwargs (dictionary of keyword arguments):
+ Passed to :func:`torch.hub.download_url_to_file`.
+
+ Returns:
+ TextProcessor:
+ A callable which takes a string or a list of strings as input and
+ returns Tensor of encoded texts and Tensor of valid lengths.
+ The object also has a ``tokens`` property, which allows you to recover the
+ tokenized form.
+
+ Example - Character-based
+ >>> text = [
+ >>> "Hello World!",
+ >>> "Text-to-speech!",
+ >>> ]
+ >>> bundle = torchaudio.pipelines.TACOTRON2_WAVERNN_CHAR_LJSPEECH
+ >>> processor = bundle.get_text_processor()
+ >>> input, lengths = processor(text)
+ >>>
+ >>> print(input)
+ tensor([[19, 16, 23, 23, 26, 11, 34, 26, 29, 23, 15, 2, 0, 0, 0],
+ [31, 16, 35, 31, 1, 31, 26, 1, 30, 27, 16, 16, 14, 19, 2]],
+ dtype=torch.int32)
+ >>>
+ >>> print(lengths)
+ tensor([12, 15], dtype=torch.int32)
+ >>>
+ >>> print([processor.tokens[i] for i in input[0, :lengths[0]]])
+ ['h', 'e', 'l', 'l', 'o', ' ', 'w', 'o', 'r', 'l', 'd', '!']
+ >>> print([processor.tokens[i] for i in input[1, :lengths[1]]])
+ ['t', 'e', 'x', 't', '-', 't', 'o', '-', 's', 'p', 'e', 'e', 'c', 'h', '!']
+
+ Example - Phoneme-based
+ >>> text = [
+ >>> "Hello, T T S !",
+ >>> "Text-to-speech!",
+ >>> ]
+ >>> bundle = torchaudio.pipelines.TACOTRON2_WAVERNN_PHONE_LJSPEECH
+ >>> processor = bundle.get_text_processor()
+ Downloading:
+ 100%|███████████████████████████████| 63.6M/63.6M [00:04<00:00, 15.3MB/s]
+ >>> input, lengths = processor(text)
+ >>>
+ >>> print(input)
+ tensor([[54, 20, 65, 69, 11, 92, 44, 65, 38, 2, 0, 0, 0, 0],
+ [81, 40, 64, 79, 81, 1, 81, 20, 1, 79, 77, 59, 37, 2]],
+ dtype=torch.int32)
+ >>>
+ >>> print(lengths)
+ tensor([10, 14], dtype=torch.int32)
+ >>>
+ >>> print([processor.tokens[i] for i in input[0]])
+ ['HH', 'AH', 'L', 'OW', ' ', 'W', 'ER', 'L', 'D', '!', '_', '_', '_', '_']
+ >>> print([processor.tokens[i] for i in input[1]])
+ ['T', 'EH', 'K', 'S', 'T', '-', 'T', 'AH', '-', 'S', 'P', 'IY', 'CH', '!']
+ """
+
+ @abstractmethod
+ def get_vocoder(self, *, dl_kwargs=None) -> Vocoder:
+ """Create a vocoder module, based off of either WaveRNN or GriffinLim.
+
+ If a pre-trained weight file is necessary,
+ :func:`torch.hub.load_state_dict_from_url` is used to download it.
+
+ Args:
+ dl_kwargs (dictionary of keyword arguments):
+ Passed to :func:`torch.hub.load_state_dict_from_url`.
+
+ Returns:
+ Vocoder:
+ A vocoder module, which takes spectrogram Tensor and an optional
+ length Tensor, then returns resulting waveform Tensor and an optional
+ length Tensor.
+ """
+
+ @abstractmethod
+ def get_tacotron2(self, *, dl_kwargs=None) -> Tacotron2:
+ """Create a Tacotron2 model with pre-trained weight.
+
+ Args:
+ dl_kwargs (dictionary of keyword arguments):
+ Passed to :func:`torch.hub.load_state_dict_from_url`.
+
+ Returns:
+ Tacotron2:
+ The resulting model.
+ """
diff --git a/MLPY/Lib/site-packages/torchaudio/pipelines/_tts/utils.py b/MLPY/Lib/site-packages/torchaudio/pipelines/_tts/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..c94c21ec519b92647033a81d1bb026e5296ffc64
--- /dev/null
+++ b/MLPY/Lib/site-packages/torchaudio/pipelines/_tts/utils.py
@@ -0,0 +1,228 @@
+import logging
+import os
+
+import torch
+from torchaudio._internal import download_url_to_file, module_utils as _mod_utils
+
+
+def _get_chars():
+ return (
+ "_",
+ "-",
+ "!",
+ "'",
+ "(",
+ ")",
+ ",",
+ ".",
+ ":",
+ ";",
+ "?",
+ " ",
+ "a",
+ "b",
+ "c",
+ "d",
+ "e",
+ "f",
+ "g",
+ "h",
+ "i",
+ "j",
+ "k",
+ "l",
+ "m",
+ "n",
+ "o",
+ "p",
+ "q",
+ "r",
+ "s",
+ "t",
+ "u",
+ "v",
+ "w",
+ "x",
+ "y",
+ "z",
+ )
+
+
+def _get_phones():
+ return (
+ "_",
+ "-",
+ "!",
+ "'",
+ "(",
+ ")",
+ ",",
+ ".",
+ ":",
+ ";",
+ "?",
+ " ",
+ "AA",
+ "AA0",
+ "AA1",
+ "AA2",
+ "AE",
+ "AE0",
+ "AE1",
+ "AE2",
+ "AH",
+ "AH0",
+ "AH1",
+ "AH2",
+ "AO",
+ "AO0",
+ "AO1",
+ "AO2",
+ "AW",
+ "AW0",
+ "AW1",
+ "AW2",
+ "AY",
+ "AY0",
+ "AY1",
+ "AY2",
+ "B",
+ "CH",
+ "D",
+ "DH",
+ "EH",
+ "EH0",
+ "EH1",
+ "EH2",
+ "ER",
+ "ER0",
+ "ER1",
+ "ER2",
+ "EY",
+ "EY0",
+ "EY1",
+ "EY2",
+ "F",
+ "G",
+ "HH",
+ "IH",
+ "IH0",
+ "IH1",
+ "IH2",
+ "IY",
+ "IY0",
+ "IY1",
+ "IY2",
+ "JH",
+ "K",
+ "L",
+ "M",
+ "N",
+ "NG",
+ "OW",
+ "OW0",
+ "OW1",
+ "OW2",
+ "OY",
+ "OY0",
+ "OY1",
+ "OY2",
+ "P",
+ "R",
+ "S",
+ "SH",
+ "T",
+ "TH",
+ "UH",
+ "UH0",
+ "UH1",
+ "UH2",
+ "UW",
+ "UW0",
+ "UW1",
+ "UW2",
+ "V",
+ "W",
+ "Y",
+ "Z",
+ "ZH",
+ )
+
+
+def _to_tensor(indices):
+ lengths = torch.tensor([len(i) for i in indices], dtype=torch.int32)
+ values = [torch.tensor(i) for i in indices]
+ values = torch.nn.utils.rnn.pad_sequence(values, batch_first=True)
+ return values, lengths
+
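+# Illustrative example (not part of the original module): ``_to_tensor`` zero-pads
+# the per-sample index lists and reports their original lengths. For instance,
+# _to_tensor([[1, 2, 3], [4, 5]]) returns the padded tensor [[1, 2, 3], [4, 5, 0]]
+# together with lengths tensor([3, 2], dtype=torch.int32).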
+
+def _load_phonemizer(file, dl_kwargs):
+ if not _mod_utils.is_module_available("dp"):
+ raise RuntimeError("DeepPhonemizer is not installed. Please install it.")
+
+ from dp.phonemizer import Phonemizer
+
+ # By default, dp emits DEBUG-level logs. Temporarily raise the level to reduce noise.
+ logger = logging.getLogger("dp")
+ orig_level = logger.level
+ logger.setLevel(logging.INFO)
+ try:
+ url = f"https://public-asai-dl-models.s3.eu-central-1.amazonaws.com/DeepPhonemizer/{file}"
+ directory = os.path.join(torch.hub.get_dir(), "checkpoints")
+ os.makedirs(directory, exist_ok=True)
+ path = os.path.join(directory, file)
+ if not os.path.exists(path):
+ dl_kwargs = {} if dl_kwargs is None else dl_kwargs
+ download_url_to_file(url, path, **dl_kwargs)
+ return Phonemizer.from_checkpoint(path)
+ finally:
+ logger.setLevel(orig_level)
+
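+# Illustrative sketch (not part of the original module): the phoneme-based TTS
+# pipeline loads the DeepPhonemizer checkpoint roughly as below. The checkpoint
+# file name is an assumption for illustration only.
+#
+#   >>> phonemizer = _load_phonemizer("en_us_cmudict_forward.pt", dl_kwargs=None)
+#   >>> phonemizer("Hello world", lang="en_us")  # grapheme-to-phoneme conversion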
+
+def _unnormalize_waveform(waveform: torch.Tensor, bits: int) -> torch.Tensor:
+ r"""Transform waveform [-1, 1] to label [0, 2 ** bits - 1]"""
+ waveform = torch.clamp(waveform, -1, 1)
+ waveform = (waveform + 1.0) * (2**bits - 1) / 2
+ return torch.clamp(waveform, 0, 2**bits - 1).int()
+
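+# Illustrative example (not part of the original module): with ``bits=8`` the
+# mapping is x -> int((x + 1) * 255 / 2) after clamping, so
+# _unnormalize_waveform(torch.tensor([-1.0, 0.0, 1.0]), 8) gives
+# tensor([0, 127, 255], dtype=torch.int32).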
+
+def _get_taco_params(n_symbols):
+ return {
+ "mask_padding": False,
+ "n_mels": 80,
+ "n_frames_per_step": 1,
+ "symbol_embedding_dim": 512,
+ "encoder_embedding_dim": 512,
+ "encoder_n_convolution": 3,
+ "encoder_kernel_size": 5,
+ "decoder_rnn_dim": 1024,
+ "decoder_max_step": 2000,
+ "decoder_dropout": 0.1,
+ "decoder_early_stopping": True,
+ "attention_rnn_dim": 1024,
+ "attention_hidden_dim": 128,
+ "attention_location_n_filter": 32,
+ "attention_location_kernel_size": 31,
+ "attention_dropout": 0.1,
+ "prenet_dim": 256,
+ "postnet_n_convolution": 5,
+ "postnet_kernel_size": 5,
+ "postnet_embedding_dim": 512,
+ "gate_threshold": 0.5,
+ "n_symbol": n_symbols,
+ }
+
+
+def _get_wrnn_params():
+ return {
+ "upsample_scales": [5, 5, 11],
+ "n_classes": 2**8, # n_bits = 8
+ "hop_length": 275,
+ "n_res_block": 10,
+ "n_rnn": 512,
+ "n_fc": 512,
+ "kernel_size": 5,
+ "n_freq": 80,
+ "n_hidden": 128,
+ "n_output": 128,
+ }
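+
+
+# Illustrative note (not part of the original module): these parameter dicts are
+# what the TTS bundles feed into the model constructors, roughly
+#   Tacotron2(**_get_taco_params(n_symbols=38)) and WaveRNN(**_get_wrnn_params())
+# (the actual construction code lives elsewhere, so treat this as an assumption).
+# Note that the WaveRNN ``upsample_scales`` multiply out to 5 * 5 * 11 = 275,
+# which matches ``hop_length``.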
diff --git a/MLPY/Lib/site-packages/torchaudio/pipelines/_wav2vec2/__init__.py b/MLPY/Lib/site-packages/torchaudio/pipelines/_wav2vec2/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/MLPY/Lib/site-packages/torchaudio/pipelines/_wav2vec2/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torchaudio/pipelines/_wav2vec2/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0ec01f0f710ea00f24249785e2fdeee491fa6609
Binary files /dev/null and b/MLPY/Lib/site-packages/torchaudio/pipelines/_wav2vec2/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torchaudio/pipelines/_wav2vec2/__pycache__/aligner.cpython-39.pyc b/MLPY/Lib/site-packages/torchaudio/pipelines/_wav2vec2/__pycache__/aligner.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0b29ff49ea646691eba282c444d4d2df3c7b92dc
Binary files /dev/null and b/MLPY/Lib/site-packages/torchaudio/pipelines/_wav2vec2/__pycache__/aligner.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torchaudio/pipelines/_wav2vec2/__pycache__/impl.cpython-39.pyc b/MLPY/Lib/site-packages/torchaudio/pipelines/_wav2vec2/__pycache__/impl.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4af01ae6561006745b8544e3fd854c4a9ec3fda7
Binary files /dev/null and b/MLPY/Lib/site-packages/torchaudio/pipelines/_wav2vec2/__pycache__/impl.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torchaudio/pipelines/_wav2vec2/__pycache__/utils.cpython-39.pyc b/MLPY/Lib/site-packages/torchaudio/pipelines/_wav2vec2/__pycache__/utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2ef4384c6503a414ed2df437de5bf742af30c051
Binary files /dev/null and b/MLPY/Lib/site-packages/torchaudio/pipelines/_wav2vec2/__pycache__/utils.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torchaudio/pipelines/_wav2vec2/aligner.py b/MLPY/Lib/site-packages/torchaudio/pipelines/_wav2vec2/aligner.py
new file mode 100644
index 0000000000000000000000000000000000000000..f23b9cf65c733d39f13524171474f324666e22dd
--- /dev/null
+++ b/MLPY/Lib/site-packages/torchaudio/pipelines/_wav2vec2/aligner.py
@@ -0,0 +1,87 @@
+from abc import ABC, abstractmethod
+from typing import Dict, List
+
+import torch
+import torchaudio.functional as F
+from torch import Tensor
+from torchaudio.functional import TokenSpan
+
+
+class ITokenizer(ABC):
+ @abstractmethod
+ def __call__(self, transcript: List[str]) -> List[List[int]]:
+ """Tokenize the given transcript (list of words)
+
+ .. note::
+
+ The transcript must be normalized.
+
+ Args:
+ transcript (list of str): Transcript (list of words).
+
+ Returns:
+ (list of lists of int): List of token sequences, one per word.
+ """
+
+
+class Tokenizer(ITokenizer):
+ def __init__(self, dictionary: Dict[str, int]):
+ self.dictionary = dictionary
+
+ def __call__(self, transcript: List[str]) -> List[List[int]]:
+ return [[self.dictionary[c] for c in word] for word in transcript]
+
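+# Illustrative example (not part of the original module): with a toy dictionary,
+# Tokenizer({"a": 1, "b": 2, "c": 3})(["ab", "c"]) returns [[1, 2], [3]],
+# i.e. one integer sequence per word of the transcript.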
+
+def _align_emission_and_tokens(emission: Tensor, tokens: List[int], blank: int = 0):
+ device = emission.device
+ emission = emission.unsqueeze(0)
+ targets = torch.tensor([tokens], dtype=torch.int32, device=device)
+
+ aligned_tokens, scores = F.forced_align(emission, targets, blank=blank)
+
+ scores = scores.exp() # convert back to probability
+ aligned_tokens, scores = aligned_tokens[0], scores[0] # remove batch dimension
+ return aligned_tokens, scores
+
+
+class IAligner(ABC):
+ @abstractmethod
+ def __call__(self, emission: Tensor, tokens: List[List[int]]) -> List[List[TokenSpan]]:
+ """Generate list of time-stamped token sequences
+
+ Args:
+ emission (Tensor): Sequence of token probability distributions in log-domain.
+ Shape: `(time, tokens)`.
+ tokens (list of integer sequence): Tokenized transcript.
+ Output from :py:class:`torchaudio.pipelines.Wav2Vec2FABundle.Tokenizer`.
+
+ Returns:
+ (list of TokenSpan sequence): Tokens with time stamps and scores.
+ """
+
+
+def _unflatten(list_, lengths):
+ assert len(list_) == sum(lengths)
+ i = 0
+ ret = []
+ for l in lengths:
+ ret.append(list_[i : i + l])
+ i += l
+ return ret
+
+
+def _flatten(nested_list):
+ return [item for list_ in nested_list for item in list_]
+
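+# Illustrative examples (not part of the original module):
+#   _unflatten([1, 2, 3, 4, 5], [2, 3]) -> [[1, 2], [3, 4, 5]]
+#   _flatten([[1, 2], [3, 4, 5]])       -> [1, 2, 3, 4, 5]
+# The Aligner below flattens the per-word token lists before forced alignment and
+# then regroups the resulting spans word by word.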
+
+class Aligner(IAligner):
+ def __init__(self, blank):
+ self.blank = blank
+
+ def __call__(self, emission: Tensor, tokens: List[List[int]]) -> List[List[TokenSpan]]:
+ if emission.ndim != 2:
+ raise ValueError(f"The input emission must be 2D. Found: {emission.shape}")
+
+ aligned_tokens, scores = _align_emission_and_tokens(emission, _flatten(tokens), self.blank)
+ spans = F.merge_tokens(aligned_tokens, scores)
+ return _unflatten(spans, [len(ts) for ts in tokens])
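+
+
+# Illustrative sketch (not part of the original module): end-to-end use of the
+# helpers above with a random CTC-style emission (blank index 0), showing the
+# expected shapes and types only.
+#
+#   >>> import torch
+#   >>> tokenizer = Tokenizer({"a": 1, "b": 2, "c": 3})
+#   >>> tokens = tokenizer(["ab", "c"])                  # [[1, 2], [3]]
+#   >>> emission = torch.randn(50, 4).log_softmax(dim=-1)
+#   >>> word_spans = Aligner(blank=0)(emission, tokens)  # one list of TokenSpan per word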
diff --git a/MLPY/Lib/site-packages/torchaudio/pipelines/_wav2vec2/impl.py b/MLPY/Lib/site-packages/torchaudio/pipelines/_wav2vec2/impl.py
new file mode 100644
index 0000000000000000000000000000000000000000..be21da436024275dae50e5b7fd22e351ab9b8e5d
--- /dev/null
+++ b/MLPY/Lib/site-packages/torchaudio/pipelines/_wav2vec2/impl.py
@@ -0,0 +1,1699 @@
+from dataclasses import dataclass
+from typing import Any, Dict, Optional, Tuple
+
+from torch.nn import Module
+
+from . import aligner, utils
+
+
+__all__ = [] # type: ignore
+
+
+@dataclass
+class Wav2Vec2Bundle:
+ """Data class that bundles associated information to use pretrained :py:class:`~torchaudio.models.Wav2Vec2Model`.
+
+ This class provides interfaces for instantiating the pretrained model along with
+ the information necessary to retrieve pretrained weights and additional data
+ to be used with the model.
+
+ Torchaudio library instantiates objects of this class, each of which represents
+ a different pretrained model. Client code should access pretrained models via these
+ instances.
+
+ Please see below for the usage and the available values.
+
+ Example - Feature Extraction
+ >>> import torchaudio
+ >>>
+ >>> bundle = torchaudio.pipelines.HUBERT_BASE
+ >>>
+ >>> # Build the model and load pretrained weight.
+ >>> model = bundle.get_model()
+ Downloading:
+ 100%|███████████████████████████████| 360M/360M [00:06<00:00, 60.6MB/s]
+ >>>
+ >>> # Resample audio to the expected sampling rate
+ >>> waveform = torchaudio.functional.resample(waveform, sample_rate, bundle.sample_rate)
+ >>>
+ >>> # Extract acoustic features
+ >>> features, _ = model.extract_features(waveform)
+ """ # noqa: E501
+
+ _path: str
+ _params: Dict[str, Any]
+ _sample_rate: float
+ _normalize_waveform: bool
+ _model_type: str
+
+ @property
+ def sample_rate(self) -> float:
+ """Sample rate of the audio that the model is trained on.
+
+ :type: float
+ """
+ return self._sample_rate
+
+ def _get_state_dict(self, dl_kwargs):
+ # Note: This method is overridden in ASR bundle
+ return utils._get_state_dict(self._path, dl_kwargs)
+
+ def get_model(self, *, dl_kwargs=None) -> Module:
+ """Construct the model and load the pretrained weight.
+
+ The weight file is downloaded from the internet and cached with
+ :func:`torch.hub.load_state_dict_from_url`
+
+ Args:
+ dl_kwargs (dictionary of keyword arguments): Passed to :func:`torch.hub.load_state_dict_from_url`.
+
+ Returns:
+ Variation of :py:class:`~torchaudio.models.Wav2Vec2Model`.
+
+ For the models listed below, an additional layer normalization is performed on the input:
+
+ - WAV2VEC2_LARGE_LV60K
+ - WAV2VEC2_ASR_LARGE_LV60K_10M
+ - WAV2VEC2_ASR_LARGE_LV60K_100H
+ - WAV2VEC2_ASR_LARGE_LV60K_960H
+ - WAV2VEC2_XLSR53
+ - WAV2VEC2_XLSR_300M
+ - WAV2VEC2_XLSR_1B
+ - WAV2VEC2_XLSR_2B
+ - HUBERT_LARGE
+ - HUBERT_XLARGE
+ - HUBERT_ASR_LARGE
+ - HUBERT_ASR_XLARGE
+ - WAVLM_LARGE
+
+ For all other models, a :py:class:`~torchaudio.models.Wav2Vec2Model` instance is returned.
+ """
+ model = utils._get_model(self._model_type, self._params)
+ state_dict = self._get_state_dict(dl_kwargs)
+ model.load_state_dict(state_dict)
+ if self._normalize_waveform:
+ model = utils._extend_model(model, normalize_waveform=True)
+ model.eval()
+ return model
+
+
+@dataclass
+class Wav2Vec2ASRBundle(Wav2Vec2Bundle):
+ """Data class that bundles associated information to use pretrained
+ :py:class:`~torchaudio.models.Wav2Vec2Model`.
+
+ This class provides interfaces for instantiating the pretrained model along with
+ the information necessary to retrieve pretrained weights and additional data
+ to be used with the model.
+
+ Torchaudio library instantiates objects of this class, each of which represents
+ a different pretrained model. Client code should access pretrained models via these
+ instances.
+
+ Please see below for the usage and the available values.
+
+ Example - ASR
+ >>> import torchaudio
+ >>>
+ >>> bundle = torchaudio.pipelines.HUBERT_ASR_LARGE
+ >>>
+ >>> # Build the model and load pretrained weight.
+ >>> model = bundle.get_model()
+ Downloading:
+ 100%|███████████████████████████████| 1.18G/1.18G [00:17<00:00, 73.8MB/s]
+ >>>
+ >>> # Check the corresponding labels of the output.
+ >>> labels = bundle.get_labels()
+ >>> print(labels)
+ ('-', '|', 'E', 'T', 'A', 'O', 'N', 'I', 'H', 'S', 'R', 'D', 'L', 'U', 'M', 'W', 'C', 'F', 'G', 'Y', 'P', 'B', 'V', 'K', "'", 'X', 'J', 'Q', 'Z')
+ >>>
+ >>> # Resample audio to the expected sampling rate
+ >>> waveform = torchaudio.functional.resample(waveform, sample_rate, bundle.sample_rate)
+ >>>
+ >>> # Infer the label probability distribution
+ >>> emissions, _ = model(waveform)
+ >>>
+ >>> # Pass emission to decoder
+ >>> # `ctc_decode` is for illustration purpose only
+ >>> transcripts = ctc_decode(emissions, labels)
+ """ # noqa: E501
+
+ _labels: Tuple[str, ...]
+ _remove_aux_axis: Tuple[int, ...] = (1, 2, 3)
+
+ def get_labels(
+ self,
+ *,
+ blank: str = "-",
+ ) -> Tuple[str, ...]:
+ """The output class labels.
+
+ The first is the blank token, and it is customizable.
+
+ Args:
+ blank (str, optional): Blank token. (default: ``'-'``)
+
+ Returns:
+ Tuple[str, ...]:
+ For models fine-tuned on ASR, returns the tuple of strings representing
+ the output class labels.
+
+ Example
+ >>> from torchaudio.pipelines import HUBERT_ASR_LARGE as bundle
+ >>> bundle.get_labels()
+ ('-', '|', 'E', 'T', 'A', 'O', 'N', 'I', 'H', 'S', 'R', 'D', 'L', 'U', 'M', 'W', 'C', 'F', 'G', 'Y', 'P', 'B', 'V', 'K', "'", 'X', 'J', 'Q', 'Z')
+ """ # noqa: E501
+ return (blank, *self._labels)
+
+ def _get_state_dict(self, dl_kwargs):
+ return utils._get_state_dict(self._path, dl_kwargs, self._remove_aux_axis)
+
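+# Illustrative sketch (not part of the original module): ``ctc_decode`` in the
+# docstring above is a placeholder. A minimal greedy CTC decoder for the emissions
+# of an ASR bundle could look like the following (it assumes the blank token is at
+# index 0 and that '|' marks word boundaries):
+#
+#   >>> import torch
+#   >>> def greedy_ctc_decode(emission, labels):
+#   ...     indices = torch.unique_consecutive(emission.argmax(dim=-1))
+#   ...     indices = indices[indices != 0]  # drop the blank token
+#   ...     return "".join(labels[i] for i in indices).replace("|", " ")
+#   >>> # transcript = greedy_ctc_decode(emissions[0], bundle.get_labels())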
+
+WAV2VEC2_BASE = Wav2Vec2Bundle(
+ _path="wav2vec2_fairseq_base_ls960.pth",
+ _params={
+ "extractor_mode": "group_norm",
+ "extractor_conv_layer_config": [
+ (512, 10, 5),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 2, 2),
+ (512, 2, 2),
+ ],
+ "extractor_conv_bias": False,
+ "encoder_embed_dim": 768,
+ "encoder_projection_dropout": 0.1,
+ "encoder_pos_conv_kernel": 128,
+ "encoder_pos_conv_groups": 16,
+ "encoder_num_layers": 12,
+ "encoder_num_heads": 12,
+ "encoder_attention_dropout": 0.1,
+ "encoder_ff_interm_features": 3072,
+ "encoder_ff_interm_dropout": 0.0,
+ "encoder_dropout": 0.1,
+ "encoder_layer_norm_first": False,
+ "encoder_layer_drop": 0.05,
+ "aux_num_out": None,
+ },
+ _sample_rate=16000,
+ _normalize_waveform=False,
+ _model_type="Wav2Vec2",
+)
+WAV2VEC2_BASE.__doc__ = """Wav2vec 2.0 model ("base" architecture),
+pre-trained on 960 hours of unlabeled audio from *LibriSpeech* dataset :cite:`7178964`
+(the combination of "train-clean-100", "train-clean-360", and "train-other-500"), not fine-tuned.
+
+Originally published by the authors of *wav2vec 2.0* :cite:`baevski2020wav2vec` under MIT License and
+redistributed with the same license.
+[`License `__,
+`Source `__]
+
+Please refer to :py:class:`torchaudio.pipelines.Wav2Vec2Bundle` for the usage.
+""" # noqa: E501
+
+WAV2VEC2_ASR_BASE_10M = Wav2Vec2ASRBundle(
+ _path="wav2vec2_fairseq_base_ls960_asr_ll10m.pth",
+ _params={
+ "extractor_mode": "group_norm",
+ "extractor_conv_layer_config": [
+ (512, 10, 5),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 2, 2),
+ (512, 2, 2),
+ ],
+ "extractor_conv_bias": False,
+ "encoder_embed_dim": 768,
+ "encoder_projection_dropout": 0.1,
+ "encoder_pos_conv_kernel": 128,
+ "encoder_pos_conv_groups": 16,
+ "encoder_num_layers": 12,
+ "encoder_num_heads": 12,
+ "encoder_attention_dropout": 0.1,
+ "encoder_ff_interm_features": 3072,
+ "encoder_ff_interm_dropout": 0.0,
+ "encoder_dropout": 0.1,
+ "encoder_layer_norm_first": False,
+ "encoder_layer_drop": 0.05,
+ "aux_num_out": 29,
+ },
+ _labels=utils._get_en_labels(),
+ _sample_rate=16000,
+ _normalize_waveform=False,
+ _model_type="Wav2Vec2",
+)
+WAV2VEC2_ASR_BASE_10M.__doc__ = """Wav2vec 2.0 model ("base" architecture with an extra linear module),
+pre-trained on 960 hours of unlabeled audio from *LibriSpeech* dataset :cite:`7178964`
+(the combination of "train-clean-100", "train-clean-360", and "train-other-500"), and
+fine-tuned for ASR on 10 minutes of transcribed audio from *Libri-Light* dataset
+:cite:`librilight` ("train-10min" subset).
+
+Originally published by the authors of *wav2vec 2.0* :cite:`baevski2020wav2vec` under MIT License and
+redistributed with the same license.
+[`License `__,
+`Source `__]
+
+Please refer to :py:class:`torchaudio.pipelines.Wav2Vec2ASRBundle` for the usage.
+""" # noqa: E501
+
+WAV2VEC2_ASR_BASE_100H = Wav2Vec2ASRBundle(
+ "wav2vec2_fairseq_base_ls960_asr_ls100.pth",
+ {
+ "extractor_mode": "group_norm",
+ "extractor_conv_layer_config": [
+ (512, 10, 5),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 2, 2),
+ (512, 2, 2),
+ ],
+ "extractor_conv_bias": False,
+ "encoder_embed_dim": 768,
+ "encoder_projection_dropout": 0.1,
+ "encoder_pos_conv_kernel": 128,
+ "encoder_pos_conv_groups": 16,
+ "encoder_num_layers": 12,
+ "encoder_num_heads": 12,
+ "encoder_attention_dropout": 0.1,
+ "encoder_ff_interm_features": 3072,
+ "encoder_ff_interm_dropout": 0.0,
+ "encoder_dropout": 0.1,
+ "encoder_layer_norm_first": False,
+ "encoder_layer_drop": 0.05,
+ "aux_num_out": 29,
+ },
+ _labels=utils._get_en_labels(),
+ _sample_rate=16000,
+ _normalize_waveform=False,
+ _model_type="Wav2Vec2",
+)
+
+WAV2VEC2_ASR_BASE_100H.__doc__ = """Wav2vec 2.0 model ("base" architecture with an extra linear module),
+pre-trained on 960 hours of unlabeled audio from *LibriSpeech* dataset :cite:`7178964`
+(the combination of "train-clean-100", "train-clean-360", and "train-other-500"), and
+fine-tuned for ASR on 100 hours of transcribed audio from the same dataset ("train-clean-100" subset).
+
+Originally published by the authors of *wav2vec 2.0* :cite:`baevski2020wav2vec` under MIT License and
+redistributed with the same license.
+[`License `__,
+`Source `__]
+
+Please refer to :py:class:`torchaudio.pipelines.Wav2Vec2ASRBundle` for the usage.
+""" # noqa: E501
+
+WAV2VEC2_ASR_BASE_960H = Wav2Vec2ASRBundle(
+ "wav2vec2_fairseq_base_ls960_asr_ls960.pth",
+ {
+ "extractor_mode": "group_norm",
+ "extractor_conv_layer_config": [
+ (512, 10, 5),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 2, 2),
+ (512, 2, 2),
+ ],
+ "extractor_conv_bias": False,
+ "encoder_embed_dim": 768,
+ "encoder_projection_dropout": 0.1,
+ "encoder_pos_conv_kernel": 128,
+ "encoder_pos_conv_groups": 16,
+ "encoder_num_layers": 12,
+ "encoder_num_heads": 12,
+ "encoder_attention_dropout": 0.1,
+ "encoder_ff_interm_features": 3072,
+ "encoder_ff_interm_dropout": 0.0,
+ "encoder_dropout": 0.1,
+ "encoder_layer_norm_first": False,
+ "encoder_layer_drop": 0.05,
+ "aux_num_out": 29,
+ },
+ _labels=utils._get_en_labels(),
+ _sample_rate=16000,
+ _normalize_waveform=False,
+ _model_type="Wav2Vec2",
+)
+WAV2VEC2_ASR_BASE_960H.__doc__ = """Wav2vec 2.0 model ("base" architecture with an extra linear module),
+pre-trained on 960 hours of unlabeled audio from *LibriSpeech* dataset :cite:`7178964`
+(the combination of "train-clean-100", "train-clean-360", and "train-other-500"), and
+fine-tuned for ASR on the same audio with the corresponding transcripts.
+
+Originally published by the authors of *wav2vec 2.0* :cite:`baevski2020wav2vec` under MIT License and
+redistributed with the same license.
+[`License `__,
+`Source `__]
+
+Please refer to :py:class:`torchaudio.pipelines.Wav2Vec2ASRBundle` for the usage.
+""" # noqa: E501
+
+WAV2VEC2_LARGE = Wav2Vec2Bundle(
+ "wav2vec2_fairseq_large_ls960.pth",
+ {
+ "extractor_mode": "group_norm",
+ "extractor_conv_layer_config": [
+ (512, 10, 5),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 2, 2),
+ (512, 2, 2),
+ ],
+ "extractor_conv_bias": False,
+ "encoder_embed_dim": 1024,
+ "encoder_projection_dropout": 0.1,
+ "encoder_pos_conv_kernel": 128,
+ "encoder_pos_conv_groups": 16,
+ "encoder_num_layers": 24,
+ "encoder_num_heads": 16,
+ "encoder_attention_dropout": 0.1,
+ "encoder_ff_interm_features": 4096,
+ "encoder_ff_interm_dropout": 0.0,
+ "encoder_dropout": 0.0,
+ "encoder_layer_norm_first": False,
+ "encoder_layer_drop": 0.2,
+ "aux_num_out": None,
+ },
+ _sample_rate=16000,
+ _normalize_waveform=False,
+ _model_type="Wav2Vec2",
+)
+WAV2VEC2_LARGE.__doc__ = """Wav2vec 2.0 model ("large" architecture),
+pre-trained on 960 hours of unlabeled audio from *LibriSpeech* dataset :cite:`7178964`
+(the combination of "train-clean-100", "train-clean-360", and "train-other-500"), not fine-tuned.
+
+Originally published by the authors of *wav2vec 2.0* :cite:`baevski2020wav2vec` under MIT License and
+redistributed with the same license.
+[`License `__,
+`Source `__]
+
+Please refer to :py:class:`torchaudio.pipelines.Wav2Vec2Bundle` for the usage.
+""" # noqa: E501
+
+WAV2VEC2_ASR_LARGE_10M = Wav2Vec2ASRBundle(
+ "wav2vec2_fairseq_large_ls960_asr_ll10m.pth",
+ {
+ "extractor_mode": "group_norm",
+ "extractor_conv_layer_config": [
+ (512, 10, 5),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 2, 2),
+ (512, 2, 2),
+ ],
+ "extractor_conv_bias": False,
+ "encoder_embed_dim": 1024,
+ "encoder_projection_dropout": 0.1,
+ "encoder_pos_conv_kernel": 128,
+ "encoder_pos_conv_groups": 16,
+ "encoder_num_layers": 24,
+ "encoder_num_heads": 16,
+ "encoder_attention_dropout": 0.1,
+ "encoder_ff_interm_features": 4096,
+ "encoder_ff_interm_dropout": 0.0,
+ "encoder_dropout": 0.0,
+ "encoder_layer_norm_first": False,
+ "encoder_layer_drop": 0.2,
+ "aux_num_out": 29,
+ },
+ _labels=utils._get_en_labels(),
+ _sample_rate=16000,
+ _normalize_waveform=False,
+ _model_type="Wav2Vec2",
+)
+WAV2VEC2_ASR_LARGE_10M.__doc__ = """Wav2vec 2.0 model ("large" architecture with an extra linear module),
+pre-trained on 960 hours of unlabeled audio from *LibriSpeech* dataset :cite:`7178964`
+(the combination of "train-clean-100", "train-clean-360", and "train-other-500"), and
+fine-tuned for ASR on 10 minutes of transcribed audio from *Libri-Light* dataset
+:cite:`librilight` ("train-10min" subset).
+
+Originally published by the authors of *wav2vec 2.0* :cite:`baevski2020wav2vec` under MIT License and
+redistributed with the same license.
+[`License `__,
+`Source `__]
+
+Please refer to :py:class:`torchaudio.pipelines.Wav2Vec2ASRBundle` for the usage.
+""" # noqa: E501
+
+WAV2VEC2_ASR_LARGE_100H = Wav2Vec2ASRBundle(
+ "wav2vec2_fairseq_large_ls960_asr_ls100.pth",
+ {
+ "extractor_mode": "group_norm",
+ "extractor_conv_layer_config": [
+ (512, 10, 5),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 2, 2),
+ (512, 2, 2),
+ ],
+ "extractor_conv_bias": False,
+ "encoder_embed_dim": 1024,
+ "encoder_projection_dropout": 0.1,
+ "encoder_pos_conv_kernel": 128,
+ "encoder_pos_conv_groups": 16,
+ "encoder_num_layers": 24,
+ "encoder_num_heads": 16,
+ "encoder_attention_dropout": 0.1,
+ "encoder_ff_interm_features": 4096,
+ "encoder_ff_interm_dropout": 0.0,
+ "encoder_dropout": 0.0,
+ "encoder_layer_norm_first": False,
+ "encoder_layer_drop": 0.2,
+ "aux_num_out": 29,
+ },
+ _labels=utils._get_en_labels(),
+ _sample_rate=16000,
+ _normalize_waveform=False,
+ _model_type="Wav2Vec2",
+)
+WAV2VEC2_ASR_LARGE_100H.__doc__ = """Wav2vec 2.0 model ("large" architecture with an extra linear module),
+pre-trained on 960 hours of unlabeled audio from *LibriSpeech* dataset :cite:`7178964`
+(the combination of "train-clean-100", "train-clean-360", and "train-other-500"), and
+fine-tuned for ASR on 100 hours of transcribed audio from
+the same dataset ("train-clean-100" subset).
+
+Originally published by the authors of *wav2vec 2.0* :cite:`baevski2020wav2vec` under MIT License and
+redistributed with the same license.
+[`License `__,
+`Source `__]
+
+Please refer to :py:class:`torchaudio.pipelines.Wav2Vec2ASRBundle` for the usage.
+""" # noqa: E501
+
+WAV2VEC2_ASR_LARGE_960H = Wav2Vec2ASRBundle(
+ "wav2vec2_fairseq_large_ls960_asr_ls960.pth",
+ {
+ "extractor_mode": "group_norm",
+ "extractor_conv_layer_config": [
+ (512, 10, 5),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 2, 2),
+ (512, 2, 2),
+ ],
+ "extractor_conv_bias": False,
+ "encoder_embed_dim": 1024,
+ "encoder_projection_dropout": 0.1,
+ "encoder_pos_conv_kernel": 128,
+ "encoder_pos_conv_groups": 16,
+ "encoder_num_layers": 24,
+ "encoder_num_heads": 16,
+ "encoder_attention_dropout": 0.1,
+ "encoder_ff_interm_features": 4096,
+ "encoder_ff_interm_dropout": 0.0,
+ "encoder_dropout": 0.0,
+ "encoder_layer_norm_first": False,
+ "encoder_layer_drop": 0.2,
+ "aux_num_out": 29,
+ },
+ _labels=utils._get_en_labels(),
+ _sample_rate=16000,
+ _normalize_waveform=False,
+ _model_type="Wav2Vec2",
+)
+WAV2VEC2_ASR_LARGE_960H.__doc__ = """Wav2vec 2.0 model ("large" architecture with an extra linear module),
+pre-trained on 960 hours of unlabeled audio from *LibriSpeech* dataset :cite:`7178964`
+(the combination of "train-clean-100", "train-clean-360", and "train-other-500"), and
+fine-tuned for ASR on the same audio with the corresponding transcripts.
+
+Originally published by the authors of *wav2vec 2.0* :cite:`baevski2020wav2vec` under MIT License and
+redistributed with the same license.
+[`License `__,
+`Source `__]
+
+Please refer to :py:class:`torchaudio.pipelines.Wav2Vec2ASRBundle` for the usage.
+""" # noqa: E501
+
+WAV2VEC2_LARGE_LV60K = Wav2Vec2Bundle(
+ "wav2vec2_fairseq_large_lv60k.pth",
+ {
+ "extractor_mode": "layer_norm",
+ "extractor_conv_layer_config": [
+ (512, 10, 5),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 2, 2),
+ (512, 2, 2),
+ ],
+ "extractor_conv_bias": True,
+ "encoder_embed_dim": 1024,
+ "encoder_projection_dropout": 0.1,
+ "encoder_pos_conv_kernel": 128,
+ "encoder_pos_conv_groups": 16,
+ "encoder_num_layers": 24,
+ "encoder_num_heads": 16,
+ "encoder_attention_dropout": 0.1,
+ "encoder_ff_interm_features": 4096,
+ "encoder_ff_interm_dropout": 0.0,
+ "encoder_dropout": 0.0,
+ "encoder_layer_norm_first": True,
+ "encoder_layer_drop": 0.0,
+ "aux_num_out": None,
+ },
+ _sample_rate=16000,
+ _normalize_waveform=True,
+ _model_type="Wav2Vec2",
+)
+WAV2VEC2_LARGE_LV60K.__doc__ = """Wav2vec 2.0 model ("large-lv60k" architecture),
+pre-trained on 60,000 hours of unlabeled audio from *Libri-Light* dataset :cite:`librilight`,
+not fine-tuned.
+
+Originally published by the authors of *wav2vec 2.0* :cite:`baevski2020wav2vec` under MIT License and
+redistributed with the same license.
+[`License `__,
+`Source `__]
+
+Please refer to :py:class:`torchaudio.pipelines.Wav2Vec2Bundle` for the usage.
+""" # noqa: E501
+
+WAV2VEC2_ASR_LARGE_LV60K_10M = Wav2Vec2ASRBundle(
+ "wav2vec2_fairseq_large_lv60k_asr_ll10m.pth",
+ {
+ "extractor_mode": "layer_norm",
+ "extractor_conv_layer_config": [
+ (512, 10, 5),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 2, 2),
+ (512, 2, 2),
+ ],
+ "extractor_conv_bias": True,
+ "encoder_embed_dim": 1024,
+ "encoder_projection_dropout": 0.1,
+ "encoder_pos_conv_kernel": 128,
+ "encoder_pos_conv_groups": 16,
+ "encoder_num_layers": 24,
+ "encoder_num_heads": 16,
+ "encoder_attention_dropout": 0.1,
+ "encoder_ff_interm_features": 4096,
+ "encoder_ff_interm_dropout": 0.0,
+ "encoder_dropout": 0.0,
+ "encoder_layer_norm_first": True,
+ "encoder_layer_drop": 0.0,
+ "aux_num_out": 29,
+ },
+ _labels=utils._get_en_labels(),
+ _sample_rate=16000,
+ _normalize_waveform=True,
+ _model_type="Wav2Vec2",
+)
+WAV2VEC2_ASR_LARGE_LV60K_10M.__doc__ = """Wav2vec 2.0 model ("large-lv60k" architecture with an extra linear module),
+pre-trained on 60,000 hours of unlabeled audio from *Libri-Light* dataset :cite:`librilight`, and
+fine-tuned for ASR on 10 minutes of transcribed audio from the same dataset ("train-10min" subset).
+
+Originally published by the authors of *wav2vec 2.0* :cite:`baevski2020wav2vec` under MIT License and
+redistributed with the same license.
+[`License `__,
+`Source `__]
+
+Please refer to :py:class:`torchaudio.pipelines.Wav2Vec2ASRBundle` for the usage.
+""" # noqa: E501
+
+WAV2VEC2_ASR_LARGE_LV60K_100H = Wav2Vec2ASRBundle(
+ "wav2vec2_fairseq_large_lv60k_asr_ls100.pth",
+ {
+ "extractor_mode": "layer_norm",
+ "extractor_conv_layer_config": [
+ (512, 10, 5),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 2, 2),
+ (512, 2, 2),
+ ],
+ "extractor_conv_bias": True,
+ "encoder_embed_dim": 1024,
+ "encoder_projection_dropout": 0.1,
+ "encoder_pos_conv_kernel": 128,
+ "encoder_pos_conv_groups": 16,
+ "encoder_num_layers": 24,
+ "encoder_num_heads": 16,
+ "encoder_attention_dropout": 0.1,
+ "encoder_ff_interm_features": 4096,
+ "encoder_ff_interm_dropout": 0.0,
+ "encoder_dropout": 0.0,
+ "encoder_layer_norm_first": True,
+ "encoder_layer_drop": 0.0,
+ "aux_num_out": 29,
+ },
+ _labels=utils._get_en_labels(),
+ _sample_rate=16000,
+ _normalize_waveform=True,
+ _model_type="Wav2Vec2",
+)
+WAV2VEC2_ASR_LARGE_LV60K_100H.__doc__ = """Wav2vec 2.0 model ("large-lv60k" architecture with an extra linear module),
+pre-trained on 60,000 hours of unlabeled audio from *Libri-Light* dataset :cite:`librilight`, and
+fine-tuned for ASR on 100 hours of transcribed audio from
+*LibriSpeech* dataset :cite:`7178964` ("train-clean-100" subset).
+
+Originally published by the authors of *wav2vec 2.0* :cite:`baevski2020wav2vec` under MIT License and
+redistributed with the same license.
+[`License `__,
+`Source `__]
+
+Please refer to :py:class:`torchaudio.pipelines.Wav2Vec2ASRBundle` for the usage.
+""" # noqa: E501
+
+WAV2VEC2_ASR_LARGE_LV60K_960H = Wav2Vec2ASRBundle(
+ "wav2vec2_fairseq_large_lv60k_asr_ls960.pth",
+ {
+ "extractor_mode": "layer_norm",
+ "extractor_conv_layer_config": [
+ (512, 10, 5),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 2, 2),
+ (512, 2, 2),
+ ],
+ "extractor_conv_bias": True,
+ "encoder_embed_dim": 1024,
+ "encoder_projection_dropout": 0.1,
+ "encoder_pos_conv_kernel": 128,
+ "encoder_pos_conv_groups": 16,
+ "encoder_num_layers": 24,
+ "encoder_num_heads": 16,
+ "encoder_attention_dropout": 0.1,
+ "encoder_ff_interm_features": 4096,
+ "encoder_ff_interm_dropout": 0.0,
+ "encoder_dropout": 0.0,
+ "encoder_layer_norm_first": True,
+ "encoder_layer_drop": 0.0,
+ "aux_num_out": 29,
+ },
+ _labels=utils._get_en_labels(),
+ _sample_rate=16000,
+ _normalize_waveform=True,
+ _model_type="Wav2Vec2",
+)
+WAV2VEC2_ASR_LARGE_LV60K_960H.__doc__ = """Wav2vec 2.0 model ("large-lv60k" architecture with an extra linear module),
+pre-trained on 60,000 hours of unlabeled audio from *Libri-Light* :cite:`librilight` dataset, and
+fine-tuned for ASR on 960 hours of transcribed audio from *LibriSpeech* dataset :cite:`7178964`
+(the combination of "train-clean-100", "train-clean-360", and "train-other-500").
+
+Originally published by the authors of *wav2vec 2.0* :cite:`baevski2020wav2vec` under MIT License and
+redistributed with the same license.
+[`License `__,
+`Source `__]
+
+Please refer to :py:class:`torchaudio.pipelines.Wav2Vec2ASRBundle` for the usage.
+""" # noqa: E501
+
+WAV2VEC2_XLSR53 = Wav2Vec2Bundle(
+ "wav2vec2_fairseq_large_xlsr53.pth",
+ {
+ "extractor_mode": "layer_norm",
+ "extractor_conv_layer_config": [
+ (512, 10, 5),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 2, 2),
+ (512, 2, 2),
+ ],
+ "extractor_conv_bias": True,
+ "encoder_embed_dim": 1024,
+ "encoder_projection_dropout": 0.0,
+ "encoder_pos_conv_kernel": 128,
+ "encoder_pos_conv_groups": 16,
+ "encoder_num_layers": 24,
+ "encoder_num_heads": 16,
+ "encoder_attention_dropout": 0.0,
+ "encoder_ff_interm_features": 4096,
+ "encoder_ff_interm_dropout": 0.0,
+ "encoder_dropout": 0.0,
+ "encoder_layer_norm_first": True,
+ "encoder_layer_drop": 0.0,
+ "aux_num_out": None,
+ },
+ _sample_rate=16000,
+ _normalize_waveform=True,
+ _model_type="Wav2Vec2",
+)
+WAV2VEC2_XLSR53.__doc__ = """Wav2vec 2.0 model ("base" architecture),
+pre-trained on 56,000 hours of unlabeled audio from multiple datasets (
+*Multilingual LibriSpeech* :cite:`Pratap_2020`,
+*CommonVoice* :cite:`ardila2020common` and
+*BABEL* :cite:`Gales2014SpeechRA`),
+not fine-tuned.
+
+Originally published by the authors of
+*Unsupervised Cross-lingual Representation Learning for Speech Recognition*
+:cite:`conneau2020unsupervised` under MIT License and redistributed with the same license.
+[`License `__,
+`Source `__]
+
+Please refer to :py:class:`torchaudio.pipelines.Wav2Vec2Bundle` for the usage.
+""" # noqa: E501
+
+HUBERT_BASE = Wav2Vec2Bundle(
+ "hubert_fairseq_base_ls960.pth",
+ {
+ "extractor_mode": "group_norm",
+ "extractor_conv_layer_config": [
+ (512, 10, 5),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 2, 2),
+ (512, 2, 2),
+ ],
+ "extractor_conv_bias": False,
+ "encoder_embed_dim": 768,
+ "encoder_projection_dropout": 0.1,
+ "encoder_pos_conv_kernel": 128,
+ "encoder_pos_conv_groups": 16,
+ "encoder_num_layers": 12,
+ "encoder_num_heads": 12,
+ "encoder_attention_dropout": 0.1,
+ "encoder_ff_interm_features": 3072,
+ "encoder_ff_interm_dropout": 0.0,
+ "encoder_dropout": 0.1,
+ "encoder_layer_norm_first": False,
+ "encoder_layer_drop": 0.05,
+ "aux_num_out": None,
+ },
+ _sample_rate=16000,
+ _normalize_waveform=False,
+ _model_type="Wav2Vec2",
+)
+HUBERT_BASE.__doc__ = """HuBERT model ("base" architecture),
+pre-trained on 960 hours of unlabeled audio from *LibriSpeech* dataset :cite:`7178964`
+(the combination of "train-clean-100", "train-clean-360", and "train-other-500"), not fine-tuned.
+
+Originally published by the authors of *HuBERT* :cite:`hsu2021hubert` under MIT License and
+redistributed with the same license.
+[`License `__,
+`Source `__]
+
+Please refer to :py:class:`torchaudio.pipelines.Wav2Vec2Bundle` for the usage.
+""" # noqa: E501
+
+HUBERT_LARGE = Wav2Vec2Bundle(
+ "hubert_fairseq_large_ll60k.pth",
+ {
+ "extractor_mode": "layer_norm",
+ "extractor_conv_layer_config": [
+ (512, 10, 5),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 2, 2),
+ (512, 2, 2),
+ ],
+ "extractor_conv_bias": False,
+ "encoder_embed_dim": 1024,
+ "encoder_projection_dropout": 0.0,
+ "encoder_pos_conv_kernel": 128,
+ "encoder_pos_conv_groups": 16,
+ "encoder_num_layers": 24,
+ "encoder_num_heads": 16,
+ "encoder_attention_dropout": 0.0,
+ "encoder_ff_interm_features": 4096,
+ "encoder_ff_interm_dropout": 0.0,
+ "encoder_dropout": 0.0,
+ "encoder_layer_norm_first": True,
+ "encoder_layer_drop": 0.0,
+ "aux_num_out": None,
+ },
+ _sample_rate=16000,
+ _normalize_waveform=True,
+ _model_type="Wav2Vec2",
+)
+HUBERT_LARGE.__doc__ = """HuBERT model ("large" architecture),
+pre-trained on 60,000 hours of unlabeled audio from *Libri-Light* dataset :cite:`librilight`,
+not fine-tuned.
+
+Originally published by the authors of *HuBERT* :cite:`hsu2021hubert` under MIT License and
+redistributed with the same license.
+[`License `__,
+`Source `__]
+
+Please refer to :py:class:`torchaudio.pipelines.Wav2Vec2Bundle` for the usage.
+""" # noqa: E501
+
+HUBERT_XLARGE = Wav2Vec2Bundle(
+ "hubert_fairseq_xlarge_ll60k.pth",
+ {
+ "extractor_mode": "layer_norm",
+ "extractor_conv_layer_config": [
+ (512, 10, 5),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 2, 2),
+ (512, 2, 2),
+ ],
+ "extractor_conv_bias": False,
+ "encoder_embed_dim": 1280,
+ "encoder_projection_dropout": 0.0,
+ "encoder_pos_conv_kernel": 128,
+ "encoder_pos_conv_groups": 16,
+ "encoder_num_layers": 48,
+ "encoder_num_heads": 16,
+ "encoder_attention_dropout": 0.0,
+ "encoder_ff_interm_features": 5120,
+ "encoder_ff_interm_dropout": 0.0,
+ "encoder_dropout": 0.0,
+ "encoder_layer_norm_first": True,
+ "encoder_layer_drop": 0.0,
+ "aux_num_out": None,
+ },
+ _sample_rate=16000,
+ _normalize_waveform=True,
+ _model_type="Wav2Vec2",
+)
+HUBERT_XLARGE.__doc__ = """HuBERT model ("extra large" architecture),
+pre-trained on 60,000 hours of unlabeled audio from *Libri-Light* dataset :cite:`librilight`,
+not fine-tuned.
+
+Originally published by the authors of *HuBERT* :cite:`hsu2021hubert` under MIT License and
+redistributed with the same license.
+[`License `__,
+`Source `__]
+
+Please refer to :py:class:`torchaudio.pipelines.Wav2Vec2Bundle` for the usage.
+""" # noqa: E501
+
+HUBERT_ASR_LARGE = Wav2Vec2ASRBundle(
+ "hubert_fairseq_large_ll60k_asr_ls960.pth",
+ {
+ "extractor_mode": "layer_norm",
+ "extractor_conv_layer_config": [
+ (512, 10, 5),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 2, 2),
+ (512, 2, 2),
+ ],
+ "extractor_conv_bias": False,
+ "encoder_embed_dim": 1024,
+ "encoder_projection_dropout": 0.0,
+ "encoder_pos_conv_kernel": 128,
+ "encoder_pos_conv_groups": 16,
+ "encoder_num_layers": 24,
+ "encoder_num_heads": 16,
+ "encoder_attention_dropout": 0.0,
+ "encoder_ff_interm_features": 4096,
+ "encoder_ff_interm_dropout": 0.1,
+ "encoder_dropout": 0.0,
+ "encoder_layer_norm_first": True,
+ "encoder_layer_drop": 0.1,
+ "aux_num_out": 29,
+ },
+ _labels=utils._get_en_labels(),
+ _sample_rate=16000,
+ _normalize_waveform=True,
+ _model_type="Wav2Vec2",
+)
+HUBERT_ASR_LARGE.__doc__ = """HuBERT model ("large" architecture),
+pre-trained on 60,000 hours of unlabeled audio from *Libri-Light* dataset :cite:`librilight`, and
+fine-tuned for ASR on 960 hours of transcribed audio from *LibriSpeech* dataset :cite:`7178964`
+(the combination of "train-clean-100", "train-clean-360", and "train-other-500").
+
+Originally published by the authors of *HuBERT* :cite:`hsu2021hubert` under MIT License and
+redistributed with the same license.
+[`License `__,
+`Source `__]
+
+Please refer to :py:class:`torchaudio.pipelines.Wav2Vec2ASRBundle` for the usage.
+""" # noqa: E501
+
+HUBERT_ASR_XLARGE = Wav2Vec2ASRBundle(
+ "hubert_fairseq_xlarge_ll60k_asr_ls960.pth",
+ {
+ "extractor_mode": "layer_norm",
+ "extractor_conv_layer_config": [
+ (512, 10, 5),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 2, 2),
+ (512, 2, 2),
+ ],
+ "extractor_conv_bias": False,
+ "encoder_embed_dim": 1280,
+ "encoder_projection_dropout": 0.0,
+ "encoder_pos_conv_kernel": 128,
+ "encoder_pos_conv_groups": 16,
+ "encoder_num_layers": 48,
+ "encoder_num_heads": 16,
+ "encoder_attention_dropout": 0.0,
+ "encoder_ff_interm_features": 5120,
+ "encoder_ff_interm_dropout": 0.1,
+ "encoder_dropout": 0.0,
+ "encoder_layer_norm_first": True,
+ "encoder_layer_drop": 0.1,
+ "aux_num_out": 29,
+ },
+ _labels=utils._get_en_labels(),
+ _sample_rate=16000,
+ _normalize_waveform=True,
+ _model_type="Wav2Vec2",
+)
+HUBERT_ASR_XLARGE.__doc__ = """HuBERT model ("extra large" architecture),
+pre-trained on 60,000 hours of unlabeled audio from
+*Libri-Light* dataset :cite:`librilight`, and
+fine-tuned for ASR on 960 hours of transcribed audio from
+*LibriSpeech* dataset :cite:`7178964`
+(the combination of "train-clean-100", "train-clean-360", and "train-other-500").
+
+Originally published by the authors of *HuBERT* :cite:`hsu2021hubert` under MIT License and
+redistributed with the same license.
+[`License `__,
+`Source `__]
+
+Please refer to :py:class:`torchaudio.pipelines.Wav2Vec2ASRBundle` for the usage.
+""" # noqa: E501
+
+
+VOXPOPULI_ASR_BASE_10K_DE = Wav2Vec2ASRBundle(
+ "wav2vec2_voxpopuli_base_10k_asr_de.pt",
+ {
+ "extractor_mode": "group_norm",
+ "extractor_conv_layer_config": [
+ (512, 10, 5),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 2, 2),
+ (512, 2, 2),
+ ],
+ "extractor_conv_bias": False,
+ "encoder_embed_dim": 768,
+ "encoder_projection_dropout": 0.0,
+ "encoder_pos_conv_kernel": 128,
+ "encoder_pos_conv_groups": 16,
+ "encoder_num_layers": 12,
+ "encoder_num_heads": 12,
+ "encoder_attention_dropout": 0.0,
+ "encoder_ff_interm_features": 3072,
+ "encoder_ff_interm_dropout": 0.1,
+ "encoder_dropout": 0.0,
+ "encoder_layer_norm_first": False,
+ "encoder_layer_drop": 0.1,
+ "aux_num_out": 32,
+ },
+ _labels=utils._get_de_labels(),
+ _sample_rate=16000,
+ _normalize_waveform=False,
+ _remove_aux_axis=(1, 2, 3, 35),
+ _model_type="Wav2Vec2",
+)
+VOXPOPULI_ASR_BASE_10K_DE.__doc__ = """wav2vec 2.0 model ("base" architecture),
+pre-trained on 10k hours of unlabeled audio from *VoxPopuli* dataset :cite:`voxpopuli`
+("10k" subset, consisting of 23 languages), and
+fine-tuned for ASR on 282 hours of transcribed audio from "de" subset.
+
+Originally published by the authors of *VoxPopuli* :cite:`voxpopuli` under CC BY-NC 4.0 and
+redistributed with the same license.
+[`License `__,
+`Source `__]
+
+Please refer to :py:class:`torchaudio.pipelines.Wav2Vec2ASRBundle` for the usage.
+""" # noqa: E501
+
+
+VOXPOPULI_ASR_BASE_10K_EN = Wav2Vec2ASRBundle(
+ "wav2vec2_voxpopuli_base_10k_asr_en.pt",
+ {
+ "extractor_mode": "group_norm",
+ "extractor_conv_layer_config": [
+ (512, 10, 5),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 2, 2),
+ (512, 2, 2),
+ ],
+ "extractor_conv_bias": False,
+ "encoder_embed_dim": 768,
+ "encoder_projection_dropout": 0.0,
+ "encoder_pos_conv_kernel": 128,
+ "encoder_pos_conv_groups": 16,
+ "encoder_num_layers": 12,
+ "encoder_num_heads": 12,
+ "encoder_attention_dropout": 0.0,
+ "encoder_ff_interm_features": 3072,
+ "encoder_ff_interm_dropout": 0.1,
+ "encoder_dropout": 0.0,
+ "encoder_layer_norm_first": False,
+ "encoder_layer_drop": 0.1,
+ "aux_num_out": 28,
+ },
+ _labels=utils._get_vp_en_labels(),
+ _sample_rate=16000,
+ _normalize_waveform=False,
+ _remove_aux_axis=(1, 2, 3, 31),
+ _model_type="Wav2Vec2",
+)
+VOXPOPULI_ASR_BASE_10K_EN.__doc__ = """wav2vec 2.0 model ("base" architecture),
+pre-trained on 10k hours of unlabeled audio from *VoxPopuli* dataset :cite:`voxpopuli`
+("10k" subset, consisting of 23 languages), and
+fine-tuned for ASR on 543 hours of transcribed audio from "en" subset.
+
+Originally published by the authors of *VoxPopuli* :cite:`voxpopuli` under CC BY-NC 4.0 and
+redistributed with the same license.
+[`License `__,
+`Source `__]
+
+Please refer to :py:class:`torchaudio.pipelines.Wav2Vec2ASRBundle` for the usage.
+""" # noqa: E501
+
+
+VOXPOPULI_ASR_BASE_10K_ES = Wav2Vec2ASRBundle(
+ "wav2vec2_voxpopuli_base_10k_asr_es.pt",
+ {
+ "extractor_mode": "group_norm",
+ "extractor_conv_layer_config": [
+ (512, 10, 5),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 2, 2),
+ (512, 2, 2),
+ ],
+ "extractor_conv_bias": False,
+ "encoder_embed_dim": 768,
+ "encoder_projection_dropout": 0.0,
+ "encoder_pos_conv_kernel": 128,
+ "encoder_pos_conv_groups": 16,
+ "encoder_num_layers": 12,
+ "encoder_num_heads": 12,
+ "encoder_attention_dropout": 0.0,
+ "encoder_ff_interm_features": 3072,
+ "encoder_ff_interm_dropout": 0.1,
+ "encoder_dropout": 0.0,
+ "encoder_layer_norm_first": False,
+ "encoder_layer_drop": 0.1,
+ "aux_num_out": 35,
+ },
+ _labels=utils._get_es_labels(),
+ _sample_rate=16000,
+ _normalize_waveform=False,
+ _remove_aux_axis=(1, 2, 3, 35),
+ _model_type="Wav2Vec2",
+)
+VOXPOPULI_ASR_BASE_10K_ES.__doc__ = """wav2vec 2.0 model ("base" architecture),
+pre-trained on 10k hours of unlabeled audio from *VoxPopuli* dataset :cite:`voxpopuli`
+("10k" subset, consisting of 23 languages), and
+fine-tuned for ASR on 166 hours of transcribed audio from "es" subset.
+
+Originally published by the authors of *VoxPopuli* :cite:`voxpopuli` under CC BY-NC 4.0 and
+redistributed with the same license.
+[`License `__,
+`Source `__]
+
+Please refer to :py:class:`torchaudio.pipelines.Wav2Vec2ASRBundle` for the usage.
+""" # noqa: E501
+
+VOXPOPULI_ASR_BASE_10K_FR = Wav2Vec2ASRBundle(
+ "wav2vec2_voxpopuli_base_10k_asr_fr.pt",
+ {
+ "extractor_mode": "group_norm",
+ "extractor_conv_layer_config": [
+ (512, 10, 5),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 2, 2),
+ (512, 2, 2),
+ ],
+ "extractor_conv_bias": False,
+ "encoder_embed_dim": 768,
+ "encoder_projection_dropout": 0.0,
+ "encoder_pos_conv_kernel": 128,
+ "encoder_pos_conv_groups": 16,
+ "encoder_num_layers": 12,
+ "encoder_num_heads": 12,
+ "encoder_attention_dropout": 0.0,
+ "encoder_ff_interm_features": 3072,
+ "encoder_ff_interm_dropout": 0.1,
+ "encoder_dropout": 0.0,
+ "encoder_layer_norm_first": False,
+ "encoder_layer_drop": 0.1,
+ "aux_num_out": 43,
+ },
+ _labels=utils._get_fr_labels(),
+ _sample_rate=16000,
+ _normalize_waveform=False,
+ _model_type="Wav2Vec2",
+)
+VOXPOPULI_ASR_BASE_10K_FR.__doc__ = """wav2vec 2.0 model ("base" architecture),
+pre-trained on 10k hours of unlabeled audio from *VoxPopuli* dataset :cite:`voxpopuli`
+("10k" subset, consisting of 23 languages), and
+fine-tuned for ASR on 211 hours of transcribed audio from "fr" subset.
+
+Originally published by the authors of *VoxPopuli* :cite:`voxpopuli` under CC BY-NC 4.0 and
+redistributed with the same license.
+[`License `__,
+`Source `__]
+
+Please refer to :py:class:`torchaudio.pipelines.Wav2Vec2ASRBundle` for the usage.
+""" # noqa: E501
+
+
+VOXPOPULI_ASR_BASE_10K_IT = Wav2Vec2ASRBundle(
+ "wav2vec2_voxpopuli_base_10k_asr_it.pt",
+ {
+ "extractor_mode": "group_norm",
+ "extractor_conv_layer_config": [
+ (512, 10, 5),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 2, 2),
+ (512, 2, 2),
+ ],
+ "extractor_conv_bias": False,
+ "encoder_embed_dim": 768,
+ "encoder_projection_dropout": 0.0,
+ "encoder_pos_conv_kernel": 128,
+ "encoder_pos_conv_groups": 16,
+ "encoder_num_layers": 12,
+ "encoder_num_heads": 12,
+ "encoder_attention_dropout": 0.0,
+ "encoder_ff_interm_features": 3072,
+ "encoder_ff_interm_dropout": 0.1,
+ "encoder_dropout": 0.0,
+ "encoder_layer_norm_first": False,
+ "encoder_layer_drop": 0.1,
+ "aux_num_out": 37,
+ },
+ _labels=utils._get_it_labels(),
+ _sample_rate=16000,
+ _normalize_waveform=False,
+ _remove_aux_axis=(1, 2, 3),
+ _model_type="Wav2Vec2",
+)
+VOXPOPULI_ASR_BASE_10K_IT.__doc__ = """wav2vec 2.0 model ("base" architecture),
+pre-trained on 10k hours of unlabeled audio from *VoxPopuli* dataset :cite:`voxpopuli`
+("10k" subset, consisting of 23 languages), and
+fine-tuned for ASR on 91 hours of transcribed audio from "it" subset.
+
+Originally published by the authors of *VoxPopuli* :cite:`voxpopuli` under CC BY-NC 4.0 and
+redistributed with the same license.
+[`License `__,
+`Source `__]
+
+Please refer to :py:class:`torchaudio.pipelines.Wav2Vec2ASRBundle` for the usage.
+""" # noqa: E501
+
+
+WAVLM_BASE = Wav2Vec2Bundle(
+ "wavlm_base.pth",
+ {
+ "extractor_mode": "group_norm",
+ "extractor_conv_layer_config": [
+ (512, 10, 5),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 2, 2),
+ (512, 2, 2),
+ ],
+ "extractor_conv_bias": False,
+ "encoder_embed_dim": 768,
+ "encoder_projection_dropout": 0.1,
+ "encoder_pos_conv_kernel": 128,
+ "encoder_pos_conv_groups": 16,
+ "encoder_num_layers": 12,
+ "encoder_num_heads": 12,
+ "encoder_max_distance": 800,
+ "encoder_num_buckets": 320,
+ "encoder_attention_dropout": 0.1,
+ "encoder_ff_interm_features": 3072,
+ "encoder_ff_interm_dropout": 0.0,
+ "encoder_dropout": 0.1,
+ "encoder_layer_norm_first": False,
+ "encoder_layer_drop": 0.05,
+ "aux_num_out": None,
+ },
+ _model_type="WavLM",
+ _sample_rate=16000,
+ _normalize_waveform=False,
+)
+WAVLM_BASE.__doc__ = """WavLM Base model ("base" architecture),
+pre-trained on 960 hours of unlabeled audio from *LibriSpeech* dataset :cite:`7178964`, not fine-tuned.
+
+Originally published by the authors of *WavLM* :cite:`chen2022wavlm` under MIT License and
+redistributed with the same license.
+[`License `__,
+`Source `__]
+
+Please refer to :py:class:`torchaudio.pipelines.Wav2Vec2Bundle` for the usage.
+""" # noqa: E501
+
+
+WAVLM_BASE_PLUS = Wav2Vec2Bundle(
+ "wavlm_base_plus.pth",
+ {
+ "extractor_mode": "group_norm",
+ "extractor_conv_layer_config": [
+ (512, 10, 5),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 2, 2),
+ (512, 2, 2),
+ ],
+ "extractor_conv_bias": False,
+ "encoder_embed_dim": 768,
+ "encoder_projection_dropout": 0.1,
+ "encoder_pos_conv_kernel": 128,
+ "encoder_pos_conv_groups": 16,
+ "encoder_num_layers": 12,
+ "encoder_num_heads": 12,
+ "encoder_max_distance": 800,
+ "encoder_num_buckets": 320,
+ "encoder_attention_dropout": 0.1,
+ "encoder_ff_interm_features": 3072,
+ "encoder_ff_interm_dropout": 0.0,
+ "encoder_dropout": 0.1,
+ "encoder_layer_norm_first": False,
+ "encoder_layer_drop": 0.05,
+ "aux_num_out": None,
+ },
+ _model_type="WavLM",
+ _sample_rate=16000,
+ _normalize_waveform=False,
+)
+WAVLM_BASE_PLUS.__doc__ = """WavLM Base+ model ("base" architecture),
+pre-trained on 60,000 hours of Libri-Light dataset :cite:`librilight`, 10,000 hours of GigaSpeech :cite:`GigaSpeech2021`,
+and 24,000 hours of *VoxPopuli* :cite:`voxpopuli`, not fine-tuned.
+
+Originally published by the authors of *WavLM* :cite:`chen2022wavlm` under MIT License and
+redistributed with the same license.
+[`License `__,
+`Source `__]
+
+Please refer to :py:class:`torchaudio.pipelines.Wav2Vec2Bundle` for the usage.
+""" # noqa: E501
+
+
+WAVLM_LARGE = Wav2Vec2Bundle(
+ "wavlm_large.pth",
+ {
+ "extractor_mode": "layer_norm",
+ "extractor_conv_layer_config": [
+ (512, 10, 5),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 2, 2),
+ (512, 2, 2),
+ ],
+ "extractor_conv_bias": False,
+ "encoder_embed_dim": 1024,
+ "encoder_projection_dropout": 0.1,
+ "encoder_pos_conv_kernel": 128,
+ "encoder_pos_conv_groups": 16,
+ "encoder_num_layers": 24,
+ "encoder_num_heads": 16,
+ "encoder_max_distance": 800,
+ "encoder_num_buckets": 320,
+ "encoder_attention_dropout": 0.1,
+ "encoder_ff_interm_features": 4096,
+ "encoder_ff_interm_dropout": 0.0,
+ "encoder_dropout": 0.1,
+ "encoder_layer_norm_first": True,
+ "encoder_layer_drop": 0.05,
+ "aux_num_out": None,
+ },
+ _model_type="WavLM",
+ _sample_rate=16000,
+ _normalize_waveform=True,
+)
+WAVLM_LARGE.__doc__ = """WavLM Large model ("large" architecture),
+pre-trained on 60,000 hours of Libri-Light dataset :cite:`librilight`, 10,000 hours of GigaSpeech :cite:`GigaSpeech2021`,
+and 24,000 hours of *VoxPopuli* :cite:`voxpopuli`, not fine-tuned.
+
+Originally published by the authors of *WavLM* :cite:`chen2022wavlm` under MIT License and
+redistributed with the same license.
+[`License `__,
+`Source `__]
+
+Please refer to :py:class:`torchaudio.pipelines.Wav2Vec2Bundle` for the usage.
+""" # noqa: E501
+
+
+WAV2VEC2_XLSR_300M = Wav2Vec2Bundle(
+ "wav2vec2_xlsr_300m.pth",
+ {
+ "extractor_mode": "layer_norm",
+ "extractor_conv_layer_config": [
+ (512, 10, 5),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 2, 2),
+ (512, 2, 2),
+ ],
+ "extractor_conv_bias": True,
+ "encoder_embed_dim": 1024,
+ "encoder_projection_dropout": 0.0,
+ "encoder_pos_conv_kernel": 128,
+ "encoder_pos_conv_groups": 16,
+ "encoder_num_layers": 24,
+ "encoder_num_heads": 16,
+ "encoder_attention_dropout": 0.0,
+ "encoder_ff_interm_features": 4096,
+ "encoder_ff_interm_dropout": 0.0,
+ "encoder_dropout": 0.0,
+ "encoder_layer_norm_first": True,
+ "encoder_layer_drop": 0.0,
+ "aux_num_out": None,
+ },
+ _model_type="Wav2Vec2",
+ _sample_rate=16000,
+ _normalize_waveform=True,
+)
+WAV2VEC2_XLSR_300M.__doc__ = """XLS-R model with 300 million parameters,
+pre-trained on 436,000 hours of unlabeled audio from multiple datasets (
+*Multilingual LibriSpeech* :cite:`Pratap_2020`,
+*CommonVoice* :cite:`ardila2020common`,
+*VoxLingua107* :cite:`valk2021voxlingua107`,
+*BABEL* :cite:`Gales2014SpeechRA`, and
+*VoxPopuli* :cite:`voxpopuli`) in 128 languages,
+not fine-tuned.
+
+Originally published by the authors of *XLS-R* :cite:`babu2021xls` under MIT License and
+redistributed with the same license.
+[`License `__,
+`Source `__]
+
+Please refer to :py:class:`torchaudio.pipelines.Wav2Vec2Bundle` for usage details.
+""" # noqa: E501
+
+
+WAV2VEC2_XLSR_1B = Wav2Vec2Bundle(
+ "wav2vec2_xlsr_1b.pth",
+ {
+ "extractor_mode": "layer_norm",
+ "extractor_conv_layer_config": [
+ (512, 10, 5),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 2, 2),
+ (512, 2, 2),
+ ],
+ "extractor_conv_bias": True,
+ "encoder_embed_dim": 1280,
+ "encoder_projection_dropout": 0.1,
+ "encoder_pos_conv_kernel": 128,
+ "encoder_pos_conv_groups": 16,
+ "encoder_num_layers": 48,
+ "encoder_num_heads": 16,
+ "encoder_attention_dropout": 0.0,
+ "encoder_ff_interm_features": 5120,
+ "encoder_ff_interm_dropout": 0.0,
+ "encoder_dropout": 0.0,
+ "encoder_layer_norm_first": True,
+ "encoder_layer_drop": 0.0,
+ "aux_num_out": None,
+ },
+ _model_type="Wav2Vec2",
+ _sample_rate=16000,
+ _normalize_waveform=True,
+)
+WAV2VEC2_XLSR_1B.__doc__ = """XLS-R model with 1 billion parameters,
+pre-trained on 436,000 hours of unlabeled audio from multiple datasets (
+*Multilingual LibriSpeech* :cite:`Pratap_2020`,
+*CommonVoice* :cite:`ardila2020common`,
+*VoxLingua107* :cite:`valk2021voxlingua107`,
+*BABEL* :cite:`Gales2014SpeechRA`, and
+*VoxPopuli* :cite:`voxpopuli`) in 128 languages,
+not fine-tuned.
+
+Originally published by the authors of *XLS-R* :cite:`babu2021xls` under MIT License and
+redistributed with the same license.
+[`License `__,
+`Source `__]
+
+Please refer to :py:class:`torchaudio.pipelines.Wav2Vec2Bundle` for usage details.
+""" # noqa: E501
+
+WAV2VEC2_XLSR_2B = Wav2Vec2Bundle(
+ "wav2vec2_xlsr_2b.pth",
+ {
+ "extractor_mode": "layer_norm",
+ "extractor_conv_layer_config": [
+ (512, 10, 5),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 2, 2),
+ (512, 2, 2),
+ ],
+ "extractor_conv_bias": True,
+ "encoder_embed_dim": 1920,
+ "encoder_projection_dropout": 0.1,
+ "encoder_pos_conv_kernel": 128,
+ "encoder_pos_conv_groups": 16,
+ "encoder_num_layers": 48,
+ "encoder_num_heads": 16,
+ "encoder_attention_dropout": 0.0,
+ "encoder_ff_interm_features": 7680,
+ "encoder_ff_interm_dropout": 0.0,
+ "encoder_dropout": 0.0,
+ "encoder_layer_norm_first": True,
+ "encoder_layer_drop": 0.0,
+ "aux_num_out": None,
+ },
+ _model_type="Wav2Vec2",
+ _sample_rate=16000,
+ _normalize_waveform=True,
+)
+WAV2VEC2_XLSR_2B.__doc__ = """XLS-R model with 2 billion parameters,
+pre-trained on 436,000 hours of unlabeled audio from multiple datasets (
+*Multilingual LibriSpeech* :cite:`Pratap_2020`,
+*CommonVoice* :cite:`ardila2020common`,
+*VoxLingua107* :cite:`valk2021voxlingua107`,
+*BABEL* :cite:`Gales2014SpeechRA`, and
+*VoxPopuli* :cite:`voxpopuli`) in 128 languages,
+not fine-tuned.
+
+Originally published by the authors of *XLS-R* :cite:`babu2021xls` under MIT License and
+redistributed with the same license.
+[`License `__,
+`Source `__]
+
+Please refer to :py:class:`torchaudio.pipelines.Wav2Vec2Bundle` for usage details.
+""" # noqa: E501
+
+
+@dataclass
+class Wav2Vec2FABundle(Wav2Vec2ASRBundle):
+ """Data class that bundles associated information to use pretrained :py:class:`~torchaudio.models.Wav2Vec2Model` for forced alignment.
+
+ This class provides interfaces for instantiating the pretrained model along with
+ the information necessary to retrieve pretrained weights and additional data
+ to be used with the model.
+
+ Torchaudio library instantiates objects of this class, each of which represents
+ a different pretrained model. Client code should access pretrained models via these
+ instances.
+
+ Please see below for the usage and the available values.
+
+ Example - Feature Extraction
+ >>> import torchaudio
+ >>>
+ >>> bundle = torchaudio.pipelines.MMS_FA
+ >>>
+ >>> # Build the model and load pretrained weight.
+ >>> model = bundle.get_model()
+ Downloading:
+ 100%|███████████████████████████████| 1.18G/1.18G [00:05<00:00, 216MB/s]
+ >>>
+ >>> # Resample audio to the expected sampling rate
+ >>> waveform = torchaudio.functional.resample(waveform, sample_rate, bundle.sample_rate)
+ >>>
+ >>> # Estimate the probability of token distribution
+ >>> emission, _ = model(waveform)
+ >>>
+ >>> # Generate frame-wise alignment
+ >>> alignment, scores = torchaudio.functional.forced_align(
+ >>> emission, targets, input_lengths, target_lengths, blank=0)
+ >>>
+ """ # noqa: E501
+
+ class Tokenizer(aligner.ITokenizer):
+ """Interface of the tokenizer"""
+
+ class Aligner(aligner.IAligner):
+ """Interface of the aligner"""
+
+ def get_labels(self, star: Optional[str] = "*", blank: str = "-") -> Tuple[str, ...]:
+ """Get the labels corresponding to the feature dimension of emission.
+
+        The first label is the blank token, which is customizable.
+
+ Args:
+ star (str or None, optional): Change or disable star token. (default: ``"*"``)
+ blank (str, optional): Change the blank token. (default: ``'-'``)
+
+ Returns:
+ Tuple[str, ...]:
+ For models fine-tuned on ASR, returns the tuple of strings representing
+ the output class labels.
+
+ Example
+ >>> from torchaudio.pipelines import MMS_FA as bundle
+ >>> bundle.get_labels()
+ ('-', 'a', 'i', 'e', 'n', 'o', 'u', 't', 's', 'r', 'm', 'k', 'l', 'd', 'g', 'h', 'y', 'b', 'p', 'w', 'c', 'v', 'j', 'z', 'f', "'", 'q', 'x', '*')
+ >>> bundle.get_labels(star=None)
+ ('-', 'a', 'i', 'e', 'n', 'o', 'u', 't', 's', 'r', 'm', 'k', 'l', 'd', 'g', 'h', 'y', 'b', 'p', 'w', 'c', 'v', 'j', 'z', 'f', "'", 'q', 'x')
+ """ # noqa: E501
+ labels = super().get_labels(blank=blank)
+ return labels if star is None else (*labels, star)
+
+ def get_model(self, with_star: bool = True, *, dl_kwargs=None) -> Module:
+ """Construct the model and load the pretrained weight.
+
+ The weight file is downloaded from the internet and cached with
+ :func:`torch.hub.load_state_dict_from_url`
+
+ Args:
+            with_star (bool, optional): If enabled, the last dimension of the output layer is
+                extended by one, which corresponds to the `star` token.
+ dl_kwargs (dictionary of keyword arguments): Passed to :func:`torch.hub.load_state_dict_from_url`.
+
+ Returns:
+ Variation of :py:class:`~torchaudio.models.Wav2Vec2Model`.
+
+ .. note::
+
+            The model created with this method returns probabilities in the log domain
+            (i.e. :py:func:`torch.nn.functional.log_softmax` is applied), whereas
+            the other Wav2Vec2 models return logits.
+ """
+ model = utils._get_model(self._model_type, self._params)
+ state_dict = utils._get_state_dict(self._path, dl_kwargs, self._remove_aux_axis)
+ model.load_state_dict(state_dict)
+ model = utils._extend_model(
+ model, normalize_waveform=self._normalize_waveform, apply_log_softmax=True, append_star=with_star
+ )
+ model.eval()
+ return model
+
+ def get_dict(self, star: Optional[str] = "*", blank: str = "-") -> Dict[str, int]:
+ """Get the mapping from token to index (in emission feature dim)
+
+ Args:
+ star (str or None, optional): Change or disable star token. (default: ``"*"``)
+ blank (str, optional): Change the blank token. (default: ``'-'``)
+
+ Returns:
+            Dict[str, int]:
+                A dictionary mapping each token to its index along the emission feature dimension.
+
+ Example
+ >>> from torchaudio.pipelines import MMS_FA as bundle
+ >>> bundle.get_dict()
+ {'-': 0, 'a': 1, 'i': 2, 'e': 3, 'n': 4, 'o': 5, 'u': 6, 't': 7, 's': 8, 'r': 9, 'm': 10, 'k': 11, 'l': 12, 'd': 13, 'g': 14, 'h': 15, 'y': 16, 'b': 17, 'p': 18, 'w': 19, 'c': 20, 'v': 21, 'j': 22, 'z': 23, 'f': 24, "'": 25, 'q': 26, 'x': 27, '*': 28}
+ >>> bundle.get_dict(star=None)
+ {'-': 0, 'a': 1, 'i': 2, 'e': 3, 'n': 4, 'o': 5, 'u': 6, 't': 7, 's': 8, 'r': 9, 'm': 10, 'k': 11, 'l': 12, 'd': 13, 'g': 14, 'h': 15, 'y': 16, 'b': 17, 'p': 18, 'w': 19, 'c': 20, 'v': 21, 'j': 22, 'z': 23, 'f': 24, "'": 25, 'q': 26, 'x': 27}
+ """ # noqa: E501
+ return {k: i for i, k in enumerate(self.get_labels(star=star, blank=blank))}
+
+ def get_tokenizer(self) -> Tokenizer:
+ """Instantiate a Tokenizer.
+
+ Returns:
+ Tokenizer
+ """
+ return aligner.Tokenizer(self.get_dict())
+
+ def get_aligner(self) -> Aligner:
+ """Instantiate an Aligner.
+
+ Returns:
+ Aligner
+ """
+ return aligner.Aligner(blank=0)
+
+
+MMS_FA = Wav2Vec2FABundle(
+ "https://dl.fbaipublicfiles.com/mms/torchaudio/ctc_alignment_mling_uroman/model.pt",
+ {
+ "extractor_mode": "layer_norm",
+ "extractor_conv_layer_config": [
+ (512, 10, 5),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 2, 2),
+ (512, 2, 2),
+ ],
+ "extractor_conv_bias": True,
+ "encoder_embed_dim": 1024,
+ "encoder_projection_dropout": 0.0,
+ "encoder_pos_conv_kernel": 128,
+ "encoder_pos_conv_groups": 16,
+ "encoder_num_layers": 24,
+ "encoder_num_heads": 16,
+ "encoder_attention_dropout": 0.0,
+ "encoder_ff_interm_features": 4096,
+ "encoder_ff_interm_dropout": 0.1,
+ "encoder_dropout": 0.0,
+ "encoder_layer_norm_first": True,
+ "encoder_layer_drop": 0.1,
+ "aux_num_out": 28,
+ },
+ _labels=utils._get_mms_labels(),
+ _sample_rate=16000,
+ _normalize_waveform=True,
+ _model_type="Wav2Vec2",
+)
+MMS_FA.__doc__ = """
+Trained on 31K hours of data in 1,130 languages from *Scaling Speech Technology to 1,000+ Languages* :cite:`pratap2023scaling`.
+
+Published by the authors of *Scaling Speech Technology to 1,000+ Languages* :cite:`pratap2023scaling` under [`CC-BY-NC 4.0 License `__].
+
+Please refer to :py:class:`torchaudio.pipelines.Wav2Vec2FABundle` for usage details.
+
+.. note::
+
+ Unlike other Wav2Vec2 bundles, this model does not have a token for word boundary (like `|`). This makes the post-processing of alignments slightly different.
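+
+Example
+    A minimal sketch (``waveform`` is assumed to be a 16 kHz waveform Tensor of shape `(1, time)`,
+    and the transcript a hypothetical list of romanized words):
+
+    >>> bundle = torchaudio.pipelines.MMS_FA
+    >>> model = bundle.get_model()
+    >>> tokenizer = bundle.get_tokenizer()
+    >>> aligner = bundle.get_aligner()
+    >>> emission, _ = model(waveform)
+    >>> token_spans = aligner(emission[0], tokenizer(["hello", "world"]))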
+""" # noqa: E501
diff --git a/MLPY/Lib/site-packages/torchaudio/pipelines/_wav2vec2/utils.py b/MLPY/Lib/site-packages/torchaudio/pipelines/_wav2vec2/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..6deab5606e7f0332fb182ffd8d7711ef6b366b0f
--- /dev/null
+++ b/MLPY/Lib/site-packages/torchaudio/pipelines/_wav2vec2/utils.py
@@ -0,0 +1,346 @@
+from typing import List, Optional, Tuple
+
+import torch
+from torch import nn, Tensor
+
+from torchaudio._internal import load_state_dict_from_url
+from torchaudio.models import wav2vec2_model, Wav2Vec2Model, wavlm_model
+
+
+def _get_model(type_, params):
+ factories = {
+ "Wav2Vec2": wav2vec2_model,
+ "WavLM": wavlm_model,
+ }
+ if type_ not in factories:
+ raise ValueError(f"Supported model types are {tuple(factories.keys())}. Found: {type_}")
+ factory = factories[type_]
+ return factory(**params)
+
+
+class _Wav2Vec2Model(nn.Module):
+ """Wrapper class for :py:class:`~torchaudio.models.Wav2Vec2Model`.
+
+    This wrapper optionally applies layer normalization to the input waveform,
+    log-softmax to the output, and an extra "star" score column to the output.
+ """
+
+ def __init__(self, model: Wav2Vec2Model, normalize_waveform: bool, apply_log_softmax: bool, append_star: bool):
+ super().__init__()
+ self.model = model
+ self.normalize_waveform = normalize_waveform
+ self.apply_log_softmax = apply_log_softmax
+ self.append_star = append_star
+
+ def forward(self, waveforms: Tensor, lengths: Optional[Tensor] = None) -> Tuple[Tensor, Optional[Tensor]]:
+ if self.normalize_waveform:
+ waveforms = nn.functional.layer_norm(waveforms, waveforms.shape)
+ output, output_lengths = self.model(waveforms, lengths)
+ if self.apply_log_softmax:
+ output = torch.nn.functional.log_softmax(output, dim=-1)
+ if self.append_star:
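+            # Append a zero-valued (log-domain) score column for the "star" token.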
+ star_dim = torch.zeros((1, output.size(1), 1), dtype=output.dtype, device=output.device)
+ output = torch.cat((output, star_dim), dim=-1)
+ return output, output_lengths
+
+ @torch.jit.export
+ def extract_features(
+ self,
+ waveforms: Tensor,
+ lengths: Optional[Tensor] = None,
+ num_layers: Optional[int] = None,
+ ) -> Tuple[List[Tensor], Optional[Tensor]]:
+ if self.normalize_waveform:
+ waveforms = nn.functional.layer_norm(waveforms, waveforms.shape)
+ return self.model.extract_features(waveforms, lengths, num_layers)
+
+
+def _extend_model(module, normalize_waveform, apply_log_softmax=False, append_star=False):
+ """Add extra transformations to the model"""
+ return _Wav2Vec2Model(module, normalize_waveform, apply_log_softmax, append_star)
+
+
+def _remove_aux_axes(state_dict, axes):
+    # Remove the seemingly unnecessary axes.
+    # For the ASR task, the pretrained weights originating from fairseq have unrelated dimensions at indices 1, 2, 3.
+    # They come from fairseq's Dictionary implementation, which was intended for NLP tasks
+    # and is not used during ASR training.
+    # https://github.com/pytorch/fairseq/blob/c5ff181125c7e6126b49a85e5ebdd5f5b6a07914/fairseq/data/dictionary.py#L21-L37
+    # https://github.com/pytorch/fairseq/blob/c5ff181125c7e6126b49a85e5ebdd5f5b6a07914/fairseq/criterions/ctc.py#L126-L129
+    #
+    # Also, some pretrained weights originating from VoxPopuli have an extra dimension that is almost never used
+    # and appears to be a mistake.
+    # The label `1` shows up in the training dataset of German (1 out of 16M),
+    # English (1 / 28M), Spanish (1 / 9.4M), Romanian (1 / 4.7M) and Polish (6 / 5.8M)
+ for key in ["aux.weight", "aux.bias"]:
+ mat = state_dict[key]
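+        # Keep only the rows whose indices are not in `axes` (drops the unused output classes).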
+ state_dict[key] = torch.stack([mat[i] for i in range(mat.size(0)) if i not in axes])
+
+
+def _get_state_dict(url, dl_kwargs, remove_axes=None):
+ if not url.startswith("https"):
+ url = f"https://download.pytorch.org/torchaudio/models/{url}"
+ dl_kwargs = {} if dl_kwargs is None else dl_kwargs
+ state_dict = load_state_dict_from_url(url, **dl_kwargs)
+ if remove_axes:
+ _remove_aux_axes(state_dict, remove_axes)
+ return state_dict
+
+
+def _get_en_labels():
+ return (
+ "|",
+ "E",
+ "T",
+ "A",
+ "O",
+ "N",
+ "I",
+ "H",
+ "S",
+ "R",
+ "D",
+ "L",
+ "U",
+ "M",
+ "W",
+ "C",
+ "F",
+ "G",
+ "Y",
+ "P",
+ "B",
+ "V",
+ "K",
+ "'",
+ "X",
+ "J",
+ "Q",
+ "Z",
+ )
+
+
+def _get_de_labels():
+ return (
+ "|",
+ "e",
+ "n",
+ "i",
+ "r",
+ "s",
+ "t",
+ "a",
+ "d",
+ "h",
+ "u",
+ "l",
+ "g",
+ "c",
+ "m",
+ "o",
+ "b",
+ "w",
+ "f",
+ "k",
+ "z",
+ "p",
+ "v",
+ "ü",
+ "ä",
+ "ö",
+ "j",
+ "ß",
+ "y",
+ "x",
+ "q",
+ )
+
+
+def _get_vp_en_labels():
+ return (
+ "|",
+ "e",
+ "t",
+ "o",
+ "i",
+ "a",
+ "n",
+ "s",
+ "r",
+ "h",
+ "l",
+ "d",
+ "c",
+ "u",
+ "m",
+ "p",
+ "f",
+ "g",
+ "w",
+ "y",
+ "b",
+ "v",
+ "k",
+ "x",
+ "j",
+ "q",
+ "z",
+ )
+
+
+def _get_es_labels():
+ return (
+ "|",
+ "e",
+ "a",
+ "o",
+ "s",
+ "n",
+ "r",
+ "i",
+ "l",
+ "d",
+ "c",
+ "t",
+ "u",
+ "p",
+ "m",
+ "b",
+ "q",
+ "y",
+ "g",
+ "v",
+ "h",
+ "ó",
+ "f",
+ "í",
+ "á",
+ "j",
+ "z",
+ "ñ",
+ "é",
+ "x",
+ "ú",
+ "k",
+ "w",
+ "ü",
+ )
+
+
+def _get_fr_labels():
+ return (
+ "|",
+ "e",
+ "s",
+ "n",
+ "i",
+ "t",
+ "r",
+ "a",
+ "o",
+ "u",
+ "l",
+ "d",
+ "c",
+ "p",
+ "m",
+ "é",
+ "v",
+ "q",
+ "f",
+ "g",
+ "b",
+ "h",
+ "x",
+ "à",
+ "j",
+ "è",
+ "y",
+ "ê",
+ "z",
+ "ô",
+ "k",
+ "ç",
+ "œ",
+ "û",
+ "ù",
+ "î",
+ "â",
+ "w",
+ "ï",
+ "ë",
+ "ü",
+ "æ",
+ )
+
+
+def _get_it_labels():
+ return (
+ "|",
+ "e",
+ "i",
+ "a",
+ "o",
+ "n",
+ "t",
+ "r",
+ "l",
+ "s",
+ "c",
+ "d",
+ "u",
+ "p",
+ "m",
+ "g",
+ "v",
+ "h",
+ "z",
+ "f",
+ "b",
+ "q",
+ "à",
+ "è",
+ "ù",
+ "é",
+ "ò",
+ "ì",
+ "k",
+ "y",
+ "x",
+ "w",
+ "j",
+ "ó",
+ "í",
+ "ï",
+ )
+
+
+def _get_mms_labels():
+ return (
+ "a",
+ "i",
+ "e",
+ "n",
+ "o",
+ "u",
+ "t",
+ "s",
+ "r",
+ "m",
+ "k",
+ "l",
+ "d",
+ "g",
+ "h",
+ "y",
+ "b",
+ "p",
+ "w",
+ "c",
+ "v",
+ "j",
+ "z",
+ "f",
+ "'",
+ "q",
+ "x",
+ )
diff --git a/MLPY/Lib/site-packages/torchaudio/pipelines/rnnt_pipeline.py b/MLPY/Lib/site-packages/torchaudio/pipelines/rnnt_pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..4a6f50941113d181e9950b3a3c7eadb9c1359a01
--- /dev/null
+++ b/MLPY/Lib/site-packages/torchaudio/pipelines/rnnt_pipeline.py
@@ -0,0 +1,380 @@
+import json
+import math
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from functools import partial
+from typing import Callable, List, Tuple
+
+import torch
+import torchaudio
+from torchaudio._internal import module_utils
+from torchaudio.models import emformer_rnnt_base, RNNT, RNNTBeamSearch
+
+
+__all__ = []
+
+_decibel = 2 * 20 * math.log10(torch.iinfo(torch.int16).max)
+_gain = pow(10, 0.05 * _decibel)
+
+
+def _piecewise_linear_log(x):
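+    # Apply natural log above x = e and scale linearly below; the mapping is continuous at x = e.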
+ x[x > math.e] = torch.log(x[x > math.e])
+ x[x <= math.e] = x[x <= math.e] / math.e
+ return x
+
+
+class _FunctionalModule(torch.nn.Module):
+ def __init__(self, functional):
+ super().__init__()
+ self.functional = functional
+
+ def forward(self, input):
+ return self.functional(input)
+
+
+class _GlobalStatsNormalization(torch.nn.Module):
+ def __init__(self, global_stats_path):
+ super().__init__()
+
+ with open(global_stats_path) as f:
+ blob = json.loads(f.read())
+
+ self.register_buffer("mean", torch.tensor(blob["mean"]))
+ self.register_buffer("invstddev", torch.tensor(blob["invstddev"]))
+
+ def forward(self, input):
+ return (input - self.mean) * self.invstddev
+
+
+class _FeatureExtractor(ABC):
+ @abstractmethod
+ def __call__(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Generates features and length output from the given input tensor.
+
+ Args:
+ input (torch.Tensor): input tensor.
+
+ Returns:
+ (torch.Tensor, torch.Tensor):
+ torch.Tensor:
+ Features, with shape `(length, *)`.
+ torch.Tensor:
+ Length, with shape `(1,)`.
+ """
+
+
+class _TokenProcessor(ABC):
+ @abstractmethod
+ def __call__(self, tokens: List[int], **kwargs) -> str:
+ """Decodes given list of tokens to text sequence.
+
+ Args:
+ tokens (List[int]): list of tokens to decode.
+
+ Returns:
+ str:
+ Decoded text sequence.
+ """
+
+
+class _ModuleFeatureExtractor(torch.nn.Module, _FeatureExtractor):
+ """``torch.nn.Module``-based feature extraction pipeline.
+
+ Args:
+ pipeline (torch.nn.Module): module that implements feature extraction logic.
+ """
+
+ def __init__(self, pipeline: torch.nn.Module) -> None:
+ super().__init__()
+ self.pipeline = pipeline
+
+ def forward(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Generates features and length output from the given input tensor.
+
+ Args:
+ input (torch.Tensor): input tensor.
+
+ Returns:
+ (torch.Tensor, torch.Tensor):
+ torch.Tensor:
+ Features, with shape `(length, *)`.
+ torch.Tensor:
+ Length, with shape `(1,)`.
+ """
+ features = self.pipeline(input)
+ length = torch.tensor([features.shape[0]])
+ return features, length
+
+
+class _SentencePieceTokenProcessor(_TokenProcessor):
+ """SentencePiece-model-based token processor.
+
+ Args:
+ sp_model_path (str): path to SentencePiece model.
+ """
+
+ def __init__(self, sp_model_path: str) -> None:
+ if not module_utils.is_module_available("sentencepiece"):
+ raise RuntimeError("SentencePiece is not available. Please install it.")
+
+ import sentencepiece as spm
+
+ self.sp_model = spm.SentencePieceProcessor(model_file=sp_model_path)
+ self.post_process_remove_list = {
+ self.sp_model.unk_id(),
+ self.sp_model.eos_id(),
+ self.sp_model.pad_id(),
+ }
+
+ def __call__(self, tokens: List[int], lstrip: bool = True) -> str:
+ """Decodes given list of tokens to text sequence.
+
+ Args:
+ tokens (List[int]): list of tokens to decode.
+ lstrip (bool, optional): if ``True``, returns text sequence with leading whitespace
+ removed. (Default: ``True``).
+
+ Returns:
+ str:
+ Decoded text sequence.
+ """
+ filtered_hypo_tokens = [
+ token_index for token_index in tokens[1:] if token_index not in self.post_process_remove_list
+ ]
+ output_string = "".join(self.sp_model.id_to_piece(filtered_hypo_tokens)).replace("\u2581", " ")
+
+ if lstrip:
+ return output_string.lstrip()
+ else:
+ return output_string
+
+
+@dataclass
+class RNNTBundle:
+ """Dataclass that bundles components for performing automatic speech recognition (ASR, speech-to-text)
+ inference with an RNN-T model.
+
+ More specifically, the class provides methods that produce the featurization pipeline,
+ decoder wrapping the specified RNN-T model, and output token post-processor that together
+ constitute a complete end-to-end ASR inference pipeline that produces a text sequence
+ given a raw waveform.
+
+    It supports both non-streaming (full-context) and streaming inference.
+
+ Users should not directly instantiate objects of this class; rather, users should use the
+ instances (representing pre-trained models) that exist within the module,
+ e.g. :data:`torchaudio.pipelines.EMFORMER_RNNT_BASE_LIBRISPEECH`.
+
+ Example
+ >>> import torchaudio
+ >>> from torchaudio.pipelines import EMFORMER_RNNT_BASE_LIBRISPEECH
+ >>> import torch
+ >>>
+ >>> # Non-streaming inference.
+ >>> # Build feature extractor, decoder with RNN-T model, and token processor.
+ >>> feature_extractor = EMFORMER_RNNT_BASE_LIBRISPEECH.get_feature_extractor()
+ 100%|███████████████████████████████| 3.81k/3.81k [00:00<00:00, 4.22MB/s]
+ >>> decoder = EMFORMER_RNNT_BASE_LIBRISPEECH.get_decoder()
+ Downloading: "https://download.pytorch.org/torchaudio/models/emformer_rnnt_base_librispeech.pt"
+ 100%|███████████████████████████████| 293M/293M [00:07<00:00, 42.1MB/s]
+ >>> token_processor = EMFORMER_RNNT_BASE_LIBRISPEECH.get_token_processor()
+ 100%|███████████████████████████████| 295k/295k [00:00<00:00, 25.4MB/s]
+ >>>
+ >>> # Instantiate LibriSpeech dataset; retrieve waveform for first sample.
+ >>> dataset = torchaudio.datasets.LIBRISPEECH("/home/librispeech", url="test-clean")
+ >>> waveform = next(iter(dataset))[0].squeeze()
+ >>>
+ >>> with torch.no_grad():
+ >>> # Produce mel-scale spectrogram features.
+ >>> features, length = feature_extractor(waveform)
+ >>>
+ >>> # Generate top-10 hypotheses.
+ >>> hypotheses = decoder(features, length, 10)
+ >>>
+ >>> # For top hypothesis, convert predicted tokens to text.
+ >>> text = token_processor(hypotheses[0][0])
+ >>> print(text)
+ he hoped there would be stew for dinner turnips and carrots and bruised potatoes and fat mutton pieces to [...]
+ >>>
+ >>>
+ >>> # Streaming inference.
+ >>> hop_length = EMFORMER_RNNT_BASE_LIBRISPEECH.hop_length
+ >>> num_samples_segment = EMFORMER_RNNT_BASE_LIBRISPEECH.segment_length * hop_length
+ >>> num_samples_segment_right_context = (
+ >>> num_samples_segment + EMFORMER_RNNT_BASE_LIBRISPEECH.right_context_length * hop_length
+ >>> )
+ >>>
+ >>> # Build streaming inference feature extractor.
+ >>> streaming_feature_extractor = EMFORMER_RNNT_BASE_LIBRISPEECH.get_streaming_feature_extractor()
+ >>>
+ >>> # Process same waveform as before, this time sequentially across overlapping segments
+ >>> # to simulate streaming inference. Note the usage of ``streaming_feature_extractor`` and ``decoder.infer``.
+ >>> state, hypothesis = None, None
+ >>> for idx in range(0, len(waveform), num_samples_segment):
+ >>> segment = waveform[idx: idx + num_samples_segment_right_context]
+ >>> segment = torch.nn.functional.pad(segment, (0, num_samples_segment_right_context - len(segment)))
+ >>> with torch.no_grad():
+ >>> features, length = streaming_feature_extractor(segment)
+ >>> hypotheses, state = decoder.infer(features, length, 10, state=state, hypothesis=hypothesis)
+ >>> hypothesis = hypotheses[0]
+ >>> transcript = token_processor(hypothesis[0])
+ >>> if transcript:
+ >>> print(transcript, end=" ", flush=True)
+ he hoped there would be stew for dinner turn ips and car rots and bru 'd oes and fat mut ton pieces to [...]
+ """
+
+ class FeatureExtractor(_FeatureExtractor):
+ """Interface of the feature extraction part of RNN-T pipeline"""
+
+ class TokenProcessor(_TokenProcessor):
+ """Interface of the token processor part of RNN-T pipeline"""
+
+ _rnnt_path: str
+ _rnnt_factory_func: Callable[[], RNNT]
+ _global_stats_path: str
+ _sp_model_path: str
+ _right_padding: int
+ _blank: int
+ _sample_rate: int
+ _n_fft: int
+ _n_mels: int
+ _hop_length: int
+ _segment_length: int
+ _right_context_length: int
+
+ def _get_model(self) -> RNNT:
+ model = self._rnnt_factory_func()
+ path = torchaudio.utils.download_asset(self._rnnt_path)
+ state_dict = torch.load(path)
+ model.load_state_dict(state_dict)
+ model.eval()
+ return model
+
+ @property
+ def sample_rate(self) -> int:
+ """Sample rate (in cycles per second) of input waveforms.
+
+ :type: int
+ """
+ return self._sample_rate
+
+ @property
+ def n_fft(self) -> int:
+ """Size of FFT window to use.
+
+ :type: int
+ """
+ return self._n_fft
+
+ @property
+ def n_mels(self) -> int:
+ """Number of mel spectrogram features to extract from input waveforms.
+
+ :type: int
+ """
+ return self._n_mels
+
+ @property
+ def hop_length(self) -> int:
+ """Number of samples between successive frames in input expected by model.
+
+ :type: int
+ """
+ return self._hop_length
+
+ @property
+ def segment_length(self) -> int:
+ """Number of frames in segment in input expected by model.
+
+ :type: int
+ """
+ return self._segment_length
+
+ @property
+ def right_context_length(self) -> int:
+ """Number of frames in right contextual block in input expected by model.
+
+ :type: int
+ """
+ return self._right_context_length
+
+ def get_decoder(self) -> RNNTBeamSearch:
+ """Constructs RNN-T decoder.
+
+ Returns:
+ RNNTBeamSearch
+ """
+ model = self._get_model()
+ return RNNTBeamSearch(model, self._blank)
+
+ def get_feature_extractor(self) -> FeatureExtractor:
+ """Constructs feature extractor for non-streaming (full-context) ASR.
+
+ Returns:
+ FeatureExtractor
+ """
+ local_path = torchaudio.utils.download_asset(self._global_stats_path)
+ return _ModuleFeatureExtractor(
+ torch.nn.Sequential(
+ torchaudio.transforms.MelSpectrogram(
+ sample_rate=self.sample_rate, n_fft=self.n_fft, n_mels=self.n_mels, hop_length=self.hop_length
+ ),
+ _FunctionalModule(lambda x: x.transpose(1, 0)),
+ _FunctionalModule(lambda x: _piecewise_linear_log(x * _gain)),
+ _GlobalStatsNormalization(local_path),
+ _FunctionalModule(lambda x: torch.nn.functional.pad(x, (0, 0, 0, self._right_padding))),
+ )
+ )
+
+ def get_streaming_feature_extractor(self) -> FeatureExtractor:
+ """Constructs feature extractor for streaming (simultaneous) ASR.
+
+ Returns:
+ FeatureExtractor
+ """
+ local_path = torchaudio.utils.download_asset(self._global_stats_path)
+ return _ModuleFeatureExtractor(
+ torch.nn.Sequential(
+ torchaudio.transforms.MelSpectrogram(
+ sample_rate=self.sample_rate, n_fft=self.n_fft, n_mels=self.n_mels, hop_length=self.hop_length
+ ),
+ _FunctionalModule(lambda x: x.transpose(1, 0)),
+ _FunctionalModule(lambda x: _piecewise_linear_log(x * _gain)),
+ _GlobalStatsNormalization(local_path),
+ )
+ )
+
+ def get_token_processor(self) -> TokenProcessor:
+ """Constructs token processor.
+
+ Returns:
+ TokenProcessor
+ """
+ local_path = torchaudio.utils.download_asset(self._sp_model_path)
+ return _SentencePieceTokenProcessor(local_path)
+
+
+EMFORMER_RNNT_BASE_LIBRISPEECH = RNNTBundle(
+ _rnnt_path="models/emformer_rnnt_base_librispeech.pt",
+ _rnnt_factory_func=partial(emformer_rnnt_base, num_symbols=4097),
+ _global_stats_path="pipeline-assets/global_stats_rnnt_librispeech.json",
+ _sp_model_path="pipeline-assets/spm_bpe_4096_librispeech.model",
+ _right_padding=4,
+ _blank=4096,
+ _sample_rate=16000,
+ _n_fft=400,
+ _n_mels=80,
+ _hop_length=160,
+ _segment_length=16,
+ _right_context_length=4,
+)
+EMFORMER_RNNT_BASE_LIBRISPEECH.__doc__ = """ASR pipeline based on Emformer-RNNT,
+pretrained on *LibriSpeech* dataset :cite:`7178964`,
+capable of performing both streaming and non-streaming inference.
+
+The underlying model is constructed by :py:func:`torchaudio.models.emformer_rnnt_base`
+and utilizes weights trained on LibriSpeech using training script ``train.py``
+`here `__ with default arguments.
+
+Please refer to :py:class:`RNNTBundle` for usage instructions.
+"""
diff --git a/MLPY/Lib/site-packages/torchaudio/prototype/__init__.py b/MLPY/Lib/site-packages/torchaudio/prototype/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/MLPY/Lib/site-packages/torchaudio/prototype/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torchaudio/prototype/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5ef33317053f3989a3590c081785439c031f103e
Binary files /dev/null and b/MLPY/Lib/site-packages/torchaudio/prototype/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torchaudio/prototype/datasets/__init__.py b/MLPY/Lib/site-packages/torchaudio/prototype/datasets/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..66123a85ed4d8dd27e0eedaa8a0c8e5552da6348
--- /dev/null
+++ b/MLPY/Lib/site-packages/torchaudio/prototype/datasets/__init__.py
@@ -0,0 +1,4 @@
+from .musan import Musan
+
+
+__all__ = ["Musan"]
diff --git a/MLPY/Lib/site-packages/torchaudio/prototype/datasets/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torchaudio/prototype/datasets/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..69e89fcaa26f3c4b7f744684089177b62252dc48
Binary files /dev/null and b/MLPY/Lib/site-packages/torchaudio/prototype/datasets/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torchaudio/prototype/datasets/__pycache__/musan.cpython-39.pyc b/MLPY/Lib/site-packages/torchaudio/prototype/datasets/__pycache__/musan.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a5c7b7f7e7e85e86be3e25526974cf392d13dca9
Binary files /dev/null and b/MLPY/Lib/site-packages/torchaudio/prototype/datasets/__pycache__/musan.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torchaudio/prototype/datasets/musan.py b/MLPY/Lib/site-packages/torchaudio/prototype/datasets/musan.py
new file mode 100644
index 0000000000000000000000000000000000000000..5e8fe3d6a1342378c70108cad7679d8fd64b7c7a
--- /dev/null
+++ b/MLPY/Lib/site-packages/torchaudio/prototype/datasets/musan.py
@@ -0,0 +1,67 @@
+from pathlib import Path
+from typing import Tuple, Union
+
+import torch
+from torch.utils.data import Dataset
+from torchaudio.datasets.utils import _load_waveform
+
+
+_SUBSETS = ["music", "noise", "speech"]
+_SAMPLE_RATE = 16_000
+
+
+class Musan(Dataset):
+ r"""*MUSAN* :cite:`musan2015` dataset.
+
+ Args:
+ root (str or Path): Root directory where the dataset's top-level directory exists.
+ subset (str): Subset of the dataset to use. Options: [``"music"``, ``"noise"``, ``"speech"``].
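+
+    Example
+        A minimal sketch (the dataset root below is a hypothetical path):
+
+        >>> dataset = Musan("/path/to/musan", subset="noise")
+        >>> waveform, sample_rate, filename = dataset[0]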
+ """
+
+ def __init__(self, root: Union[str, Path], subset: str):
+ if subset not in _SUBSETS:
+ raise ValueError(f"Invalid subset '{subset}' given. Please provide one of {_SUBSETS}")
+
+ subset_path = Path(root) / subset
+ self._walker = [str(p) for p in subset_path.glob("*/*.*")]
+
+ def get_metadata(self, n: int) -> Tuple[str, int, str]:
+ r"""Get metadata for the n-th sample in the dataset. Returns filepath instead of waveform,
+ but otherwise returns the same fields as :py:func:`__getitem__`.
+
+ Args:
+ n (int): Index of sample to be loaded.
+
+ Returns:
+ (str, int, str):
+ str
+ Path to audio.
+ int
+ Sample rate.
+ str
+ File name.
+ """
+ audio_path = self._walker[n]
+ return audio_path, _SAMPLE_RATE, Path(audio_path).name
+
+ def __getitem__(self, n: int) -> Tuple[torch.Tensor, int, str]:
+ r"""Return the n-th sample in the dataset.
+
+ Args:
+ n (int): Index of sample to be loaded.
+
+ Returns:
+ (torch.Tensor, int, str):
+ torch.Tensor
+ Waveform.
+ int
+ Sample rate.
+ str
+ File name.
+ """
+ audio_path, sample_rate, filename = self.get_metadata(n)
+ path = Path(audio_path)
+ return _load_waveform(path.parent, path.name, sample_rate), sample_rate, filename
+
+ def __len__(self) -> int:
+ return len(self._walker)
diff --git a/MLPY/Lib/site-packages/torchaudio/prototype/functional/__init__.py b/MLPY/Lib/site-packages/torchaudio/prototype/functional/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..11d7daa6a6069c0caca0b191fe174af78e1dab72
--- /dev/null
+++ b/MLPY/Lib/site-packages/torchaudio/prototype/functional/__init__.py
@@ -0,0 +1,26 @@
+from ._dsp import (
+ adsr_envelope,
+ exp_sigmoid,
+ extend_pitch,
+ filter_waveform,
+ frequency_impulse_response,
+ oscillator_bank,
+ sinc_impulse_response,
+)
+from ._rir import ray_tracing, simulate_rir_ism
+from .functional import barkscale_fbanks, chroma_filterbank
+
+
+__all__ = [
+ "adsr_envelope",
+ "exp_sigmoid",
+ "barkscale_fbanks",
+ "chroma_filterbank",
+ "extend_pitch",
+ "filter_waveform",
+ "frequency_impulse_response",
+ "oscillator_bank",
+ "ray_tracing",
+ "sinc_impulse_response",
+ "simulate_rir_ism",
+]
diff --git a/MLPY/Lib/site-packages/torchaudio/prototype/functional/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torchaudio/prototype/functional/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3da9ea9b38294448a38ba8ad09359a1b25680e22
Binary files /dev/null and b/MLPY/Lib/site-packages/torchaudio/prototype/functional/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torchaudio/prototype/functional/__pycache__/_dsp.cpython-39.pyc b/MLPY/Lib/site-packages/torchaudio/prototype/functional/__pycache__/_dsp.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..945b0aae5e8f2bd733a02998935fd3c4be9d8702
Binary files /dev/null and b/MLPY/Lib/site-packages/torchaudio/prototype/functional/__pycache__/_dsp.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torchaudio/prototype/functional/__pycache__/_rir.cpython-39.pyc b/MLPY/Lib/site-packages/torchaudio/prototype/functional/__pycache__/_rir.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..63922e59535d22d938464f59ce12907b2ec7395d
Binary files /dev/null and b/MLPY/Lib/site-packages/torchaudio/prototype/functional/__pycache__/_rir.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torchaudio/prototype/functional/__pycache__/functional.cpython-39.pyc b/MLPY/Lib/site-packages/torchaudio/prototype/functional/__pycache__/functional.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a521c7e98378dc3a167ba2db1ea5d3c8bbb1a2e7
Binary files /dev/null and b/MLPY/Lib/site-packages/torchaudio/prototype/functional/__pycache__/functional.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torchaudio/prototype/functional/_dsp.py b/MLPY/Lib/site-packages/torchaudio/prototype/functional/_dsp.py
new file mode 100644
index 0000000000000000000000000000000000000000..c590374b3fe8ce58479ed0c0388a551ad8765004
--- /dev/null
+++ b/MLPY/Lib/site-packages/torchaudio/prototype/functional/_dsp.py
@@ -0,0 +1,433 @@
+import warnings
+from typing import List, Optional, Union
+
+import torch
+
+from torchaudio.functional import fftconvolve
+
+
+def oscillator_bank(
+ frequencies: torch.Tensor,
+ amplitudes: torch.Tensor,
+ sample_rate: float,
+ reduction: str = "sum",
+ dtype: Optional[torch.dtype] = torch.float64,
+) -> torch.Tensor:
+ """Synthesize waveform from the given instantaneous frequencies and amplitudes.
+
+ .. devices:: CPU CUDA
+
+ .. properties:: Autograd TorchScript
+
+ Note:
+ The phase information of the output waveform is found by taking the cumulative sum
+ of the given instantaneous frequencies (``frequencies``).
+ This incurs roundoff error when the data type does not have enough precision.
+ Using ``torch.float64`` can work around this.
+
+ The following figure shows the difference between ``torch.float32`` and
+ ``torch.float64`` when generating a sin wave of constant frequency and amplitude
+ with sample rate 8000 [Hz].
+ Notice that ``torch.float32`` version shows artifacts that are not seen in
+ ``torch.float64`` version.
+
+ .. image:: https://download.pytorch.org/torchaudio/doc-assets/oscillator_precision.png
+
+ Args:
+ frequencies (Tensor): Sample-wise oscillator frequencies (Hz). Shape `(..., time, N)`.
+ amplitudes (Tensor): Sample-wise oscillator amplitude. Shape: `(..., time, N)`.
+ sample_rate (float): Sample rate
+ reduction (str): Reduction to perform.
+ Valid values are ``"sum"``, ``"mean"`` or ``"none"``. Default: ``"sum"``
+ dtype (torch.dtype or None, optional): The data type on which cumulative sum operation is performed.
+ Default: ``torch.float64``. Pass ``None`` to disable the casting.
+
+ Returns:
+ Tensor:
+ The resulting waveform.
+
+ If ``reduction`` is ``"none"``, then the shape is
+ `(..., time, N)`, otherwise the shape is `(..., time)`.
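+
+    Example
+        A minimal sketch synthesizing one second of a 440 Hz tone (the 8000 Hz sample rate
+        below is chosen arbitrarily for illustration):
+
+        >>> sample_rate = 8000
+        >>> num_frames = sample_rate  # one second of frames
+        >>> freq = torch.full((num_frames, 1), 440.0)
+        >>> amp = torch.ones((num_frames, 1))
+        >>> waveform = oscillator_bank(freq, amp, sample_rate=sample_rate)
+        >>> waveform.shape
+        torch.Size([8000])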
+ """
+ if frequencies.shape != amplitudes.shape:
+ raise ValueError(
+ "The shapes of `frequencies` and `amplitudes` must match. "
+ f"Found: {frequencies.shape} and {amplitudes.shape} respectively."
+ )
+ reductions = ["sum", "mean", "none"]
+ if reduction not in reductions:
+ raise ValueError(f"The value of reduction must be either {reductions}. Found: {reduction}")
+
+ invalid = torch.abs(frequencies) >= sample_rate / 2
+ if torch.any(invalid):
+ warnings.warn(
+ "Some frequencies are above nyquist frequency. "
+ "Setting the corresponding amplitude to zero. "
+ "This might cause numerically unstable gradient."
+ )
+ amplitudes = torch.where(invalid, 0.0, amplitudes)
+
+ pi2 = 2.0 * torch.pi
+ freqs = frequencies * pi2 / sample_rate % pi2
+ phases = torch.cumsum(freqs, dim=-2, dtype=dtype)
+ if dtype is not None and freqs.dtype != dtype:
+ phases = phases.to(freqs.dtype)
+
+ waveform = amplitudes * torch.sin(phases)
+ if reduction == "sum":
+ return waveform.sum(-1)
+ if reduction == "mean":
+ return waveform.mean(-1)
+ return waveform
+
+
+def adsr_envelope(
+ num_frames: int,
+ *,
+ attack: float = 0.0,
+ hold: float = 0.0,
+ decay: float = 0.0,
+ sustain: float = 1.0,
+ release: float = 0.0,
+ n_decay: int = 2,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+):
+ """Generate ADSR Envelope
+
+ .. devices:: CPU CUDA
+
+ Args:
+ num_frames (int): The number of output frames.
+ attack (float, optional):
+ The relative *time* it takes to reach the maximum level from
+ the start. (Default: ``0.0``)
+ hold (float, optional):
+ The relative *time* the maximum level is held before
+ it starts to decay. (Default: ``0.0``)
+ decay (float, optional):
+ The relative *time* it takes to sustain from
+ the maximum level. (Default: ``0.0``)
+ sustain (float, optional): The relative *level* at which
+ the sound should sustain. (Default: ``1.0``)
+
+ .. Note::
+ The duration of sustain is derived as `1.0 - (The sum of attack, hold, decay and release)`.
+
+ release (float, optional): The relative *time* it takes for the sound level to
+ reach zero after the sustain. (Default: ``0.0``)
+ n_decay (int, optional): The degree of polynomial decay. Default: ``2``.
+ dtype (torch.dtype, optional): the desired data type of returned tensor.
+ Default: if ``None``, uses a global default
+ (see :py:func:`torch.set_default_tensor_type`).
+ device (torch.device, optional): the desired device of returned tensor.
+ Default: if ``None``, uses the current device for the default tensor type
+ (see :py:func:`torch.set_default_tensor_type`).
+ device will be the CPU for CPU tensor types and the current CUDA
+ device for CUDA tensor types.
+
+ Returns:
+ Tensor: ADSR Envelope. Shape: `(num_frames, )`
+
+ Example
+ .. image:: https://download.pytorch.org/torchaudio/doc-assets/adsr_examples.png
+
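+        A minimal sketch (parameter values below are chosen arbitrarily for illustration):
+
+        >>> env = adsr_envelope(
+        >>>     16, attack=0.25, hold=0.25, decay=0.25, sustain=0.5, release=0.25)
+        >>> env.shape
+        torch.Size([16])
+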
+ """
+ if not 0 <= attack <= 1:
+ raise ValueError(f"The value of `attack` must be within [0, 1]. Found: {attack}")
+ if not 0 <= decay <= 1:
+ raise ValueError(f"The value of `decay` must be within [0, 1]. Found: {decay}")
+ if not 0 <= sustain <= 1:
+ raise ValueError(f"The value of `sustain` must be within [0, 1]. Found: {sustain}")
+ if not 0 <= hold <= 1:
+ raise ValueError(f"The value of `hold` must be within [0, 1]. Found: {hold}")
+ if not 0 <= release <= 1:
+ raise ValueError(f"The value of `release` must be within [0, 1]. Found: {release}")
+ if attack + decay + release + hold > 1:
+ raise ValueError("The sum of `attack`, `hold`, `decay` and `release` must not exceed 1.")
+
+ nframes = num_frames - 1
+ num_a = int(nframes * attack)
+ num_h = int(nframes * hold)
+ num_d = int(nframes * decay)
+ num_r = int(nframes * release)
+
+ # Initialize with sustain
+ out = torch.full((num_frames,), float(sustain), device=device, dtype=dtype)
+
+ # attack
+ if num_a > 0:
+ torch.linspace(0.0, 1.0, num_a + 1, out=out[: num_a + 1])
+
+ # hold
+ if num_h > 0:
+ out[num_a : num_a + num_h + 1] = 1.0
+
+ # decay
+ if num_d > 0:
+ # Compute: sustain + (1.0 - sustain) * (linspace[1, 0] ** n_decay)
+ i = num_a + num_h
+ decay = out[i : i + num_d + 1]
+ torch.linspace(1.0, 0.0, num_d + 1, out=decay)
+ decay **= n_decay
+ decay *= 1.0 - sustain
+ decay += sustain
+
+ # sustain is handled by initialization
+
+ # release
+ if num_r > 0:
+ torch.linspace(sustain, 0, num_r + 1, out=out[-num_r - 1 :])
+
+ return out
+
+
+def extend_pitch(
+ base: torch.Tensor,
+ pattern: Union[int, List[float], torch.Tensor],
+):
+ """Extend the given time series values with multipliers of them.
+
+ .. devices:: CPU CUDA
+
+ .. properties:: Autograd TorchScript
+
+ Given a series of fundamental frequencies (pitch), this function appends
+ its harmonic overtones or inharmonic partials.
+
+ Args:
+ base (torch.Tensor):
+ Base time series, like fundamental frequencies (Hz). Shape: `(..., time, 1)`.
+ pattern (int, list of floats or torch.Tensor):
+ If ``int``, the number of pitch series after the operation.
+ `pattern - 1` tones are added, so that the resulting Tensor contains
+ up to `pattern`-th overtones of the given series.
+
+ If list of float or ``torch.Tensor``, it must be one dimensional,
+ representing the custom multiplier of the fundamental frequency.
+
+ Returns:
+ Tensor: Oscillator frequencies (Hz). Shape: `(..., time, num_tones)`.
+
+ Example
+ >>> # fundamental frequency
+ >>> f0 = torch.linspace(1, 5, 5).unsqueeze(-1)
+ >>> f0
+ tensor([[1.],
+ [2.],
+ [3.],
+ [4.],
+ [5.]])
+ >>> # Add harmonic overtones, up to 3rd.
+ >>> f = extend_pitch(f0, 3)
+ >>> f.shape
+ torch.Size([5, 3])
+ >>> f
+ tensor([[ 1., 2., 3.],
+ [ 2., 4., 6.],
+ [ 3., 6., 9.],
+ [ 4., 8., 12.],
+ [ 5., 10., 15.]])
+ >>> # Add custom (inharmonic) partials.
+ >>> f = extend_pitch(f0, torch.tensor([1, 2.1, 3.3, 4.5]))
+ >>> f.shape
+ torch.Size([5, 4])
+ >>> f
+ tensor([[ 1.0000, 2.1000, 3.3000, 4.5000],
+ [ 2.0000, 4.2000, 6.6000, 9.0000],
+ [ 3.0000, 6.3000, 9.9000, 13.5000],
+ [ 4.0000, 8.4000, 13.2000, 18.0000],
+ [ 5.0000, 10.5000, 16.5000, 22.5000]])
+ """
+ if isinstance(pattern, torch.Tensor):
+ mult = pattern
+ elif isinstance(pattern, int):
+ mult = torch.linspace(1.0, float(pattern), pattern, device=base.device, dtype=base.dtype)
+ else:
+ mult = torch.tensor(pattern, dtype=base.dtype, device=base.device)
+ h_freq = base @ mult.unsqueeze(0)
+ return h_freq
+
+
+def sinc_impulse_response(cutoff: torch.Tensor, window_size: int = 513, high_pass: bool = False):
+ """Create windowed-sinc impulse response for given cutoff frequencies.
+
+ .. devices:: CPU CUDA
+
+ .. properties:: Autograd TorchScript
+
+ Args:
+ cutoff (Tensor): Cutoff frequencies for low-pass sinc filter.
+
+ window_size (int, optional): Size of the Hamming window to apply. Must be odd.
+ (Default: 513)
+
+ high_pass (bool, optional):
+ If ``True``, convert the resulting filter to high-pass.
+ Otherwise low-pass filter is returned. Default: ``False``.
+
+ Returns:
+ Tensor: A series of impulse responses. Shape: `(..., window_size)`.
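+
+    Example
+        A minimal sketch showing the output shape (the cutoff values below are arbitrary):
+
+        >>> cutoff = torch.tensor([0.1, 0.4])
+        >>> ir = sinc_impulse_response(cutoff, window_size=513)
+        >>> ir.shape
+        torch.Size([2, 513])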
+ """
+ if window_size % 2 == 0:
+ raise ValueError(f"`window_size` must be odd. Given: {window_size}")
+
+ half = window_size // 2
+ device, dtype = cutoff.device, cutoff.dtype
+ idx = torch.linspace(-half, half, window_size, device=device, dtype=dtype)
+
+ filt = torch.special.sinc(cutoff.unsqueeze(-1) * idx.unsqueeze(0))
+ filt = filt * torch.hamming_window(window_size, device=device, dtype=dtype, periodic=False).unsqueeze(0)
+ filt = filt / filt.sum(dim=-1, keepdim=True).abs()
+
+ # High pass IR is obtained by subtracting low_pass IR from delta function.
+ # https://courses.engr.illinois.edu/ece401/fa2020/slides/lec10.pdf
+ if high_pass:
+ filt = -filt
+ filt[..., half] = 1.0 + filt[..., half]
+ return filt
+
+
+def frequency_impulse_response(magnitudes):
+ """Create filter from desired frequency response
+
+ Args:
+        magnitudes (Tensor): The desired frequency responses. Shape: `(..., num_fft_bins)`
+
+ Returns:
+ Tensor: Impulse response. Shape `(..., 2 * (num_fft_bins - 1))`
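+
+    Example
+        A minimal sketch showing the output shape (the magnitudes below are random placeholders):
+
+        >>> magnitudes = torch.rand(3, 257)
+        >>> ir = frequency_impulse_response(magnitudes)
+        >>> ir.shape
+        torch.Size([3, 512])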
+ """
+ if magnitudes.min() < 0.0:
+        # A negative magnitude does not make sense, but it is allowed so that autograd works
+        # around 0.
+        # Should we raise an error?
+ warnings.warn("The input frequency response should not contain negative values.")
+ ir = torch.fft.fftshift(torch.fft.irfft(magnitudes), dim=-1)
+ device, dtype = magnitudes.device, magnitudes.dtype
+ window = torch.hann_window(ir.size(-1), periodic=False, device=device, dtype=dtype).expand_as(ir)
+ return ir * window
+
+
+def _overlap_and_add(waveform, stride):
+ num_frames, frame_size = waveform.shape[-2:]
+ numel = (num_frames - 1) * stride + frame_size
+ buffer = torch.zeros(waveform.shape[:-2] + (numel,), device=waveform.device, dtype=waveform.dtype)
+ for i in range(num_frames):
+ start = i * stride
+ end = start + frame_size
+ buffer[..., start:end] += waveform[..., i, :]
+ return buffer
+
+
+def filter_waveform(waveform: torch.Tensor, kernels: torch.Tensor, delay_compensation: int = -1):
+ """Applies filters along time axis of the given waveform.
+
+ This function applies the given filters along time axis in the following manner:
+
+ 1. Split the given waveform into chunks. The number of chunks is equal to the number of given filters.
+ 2. Filter each chunk with corresponding filter.
+ 3. Place the filtered chunks at the original indices while adding up the overlapping parts.
+ 4. Crop the resulting waveform so that delay introduced by the filter is removed and its length
+ matches that of the input waveform.
+
+ The following figure illustrates this.
+
+ .. image:: https://download.pytorch.org/torchaudio/doc-assets/filter_waveform.png
+
+ .. note::
+
+ If the number of filters is one, then the operation becomes stationary.
+ i.e. the same filtering is applied across the time axis.
+
+ Args:
+ waveform (Tensor): Shape `(..., time)`.
+ kernels (Tensor): Impulse responses.
+ Valid inputs are 2D tensor with shape `(num_filters, filter_length)` or
+ `(N+1)`-D tensor with shape `(..., num_filters, filter_length)`, where `N` is
+ the dimension of waveform.
+
+ In case of 2D input, the same set of filters is used across channels and batches.
+ Otherwise, different sets of filters are applied. In this case, the shape of
+ the first `N-1` dimensions of filters must match (or be broadcastable to) that of waveform.
+
+ delay_compensation (int): Control how the waveform is cropped after full convolution.
+ If the value is zero or positive, it is interpreted as the length of crop at the
+ beginning of the waveform. The value cannot be larger than the size of filter kernel.
+ Otherwise the initial crop is ``filter_size // 2``.
+ When cropping happens, the waveform is also cropped from the end so that the
+ length of the resulting waveform matches the input waveform.
+
+ Returns:
+ Tensor: `(..., time)`.
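+
+    Example
+        A minimal sketch showing the output shape (the waveform and kernels below are random
+        placeholders; four filters are applied over four consecutive chunks):
+
+        >>> waveform = torch.randn(2, 8000)
+        >>> kernels = torch.randn(4, 101)
+        >>> filter_waveform(waveform, kernels).shape
+        torch.Size([2, 8000])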
+ """
+ if kernels.ndim not in [2, waveform.ndim + 1]:
+ raise ValueError(
+ "`kernels` must be 2 or N+1 dimension where "
+ f"N is the dimension of waveform. Found: {kernels.ndim} (N={waveform.ndim})"
+ )
+
+ num_filters, filter_size = kernels.shape[-2:]
+ num_frames = waveform.size(-1)
+
+ if delay_compensation > filter_size:
+        raise ValueError(
+            "When `delay_compensation` is provided, it cannot be larger than the size of filters. "
+            f"Found: delay_compensation={delay_compensation}, filter_size={filter_size}"
+        )
+
+ # Transform waveform's time axis into (num_filters x chunk_length) with optional padding
+ chunk_length = num_frames // num_filters
+ if num_frames % num_filters > 0:
+ chunk_length += 1
+ num_pad = chunk_length * num_filters - num_frames
+ waveform = torch.nn.functional.pad(waveform, [0, num_pad], "constant", 0)
+ chunked = waveform.unfold(-1, chunk_length, chunk_length)
+ assert chunked.numel() >= waveform.numel()
+
+ # Broadcast kernels
+ if waveform.ndim + 1 > kernels.ndim:
+ expand_shape = waveform.shape[:-1] + kernels.shape
+ kernels = kernels.expand(expand_shape)
+
+ convolved = fftconvolve(chunked, kernels)
+ restored = _overlap_and_add(convolved, chunk_length)
+
+ # Trim in a way that the number of samples are same as input,
+ # and the filter delay is compensated
+ if delay_compensation >= 0:
+ start = delay_compensation
+ else:
+ start = filter_size // 2
+ num_crops = restored.size(-1) - num_frames
+ end = num_crops - start
+ result = restored[..., start:-end]
+ return result
+
+
+def exp_sigmoid(
+ input: torch.Tensor, exponent: float = 10.0, max_value: float = 2.0, threshold: float = 1e-7
+) -> torch.Tensor:
+ """Exponential Sigmoid pointwise nonlinearity.
+ Implements the equation:
+ ``max_value`` * sigmoid(``input``) ** (log(``exponent``)) + ``threshold``
+
+ The output has a range of [``threshold``, ``max_value``].
+ ``exponent`` controls the slope of the output.
+
+ .. devices:: CPU CUDA
+
+ Args:
+ input (Tensor): Input Tensor
+ exponent (float, optional): Exponent. Controls the slope of the output
+ max_value (float, optional): Maximum value of the output
+ threshold (float, optional): Minimum value of the output
+
+ Returns:
+ Tensor: Exponential Sigmoid output. Shape: same as input
+
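+    Example
+        A minimal sketch with default parameters (outputs approach ``threshold`` for very
+        negative inputs and ``max_value`` for very positive inputs):
+
+        >>> gains = exp_sigmoid(torch.linspace(-10.0, 10.0, 5))
+        >>> gains.shape
+        torch.Size([5])
+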
+ """
+
+ return max_value * torch.pow(
+ torch.nn.functional.sigmoid(input),
+ torch.log(torch.tensor(exponent, device=input.device, dtype=input.dtype)),
+ ) + torch.tensor(threshold, device=input.device, dtype=input.dtype)
diff --git a/MLPY/Lib/site-packages/torchaudio/prototype/functional/_rir.py b/MLPY/Lib/site-packages/torchaudio/prototype/functional/_rir.py
new file mode 100644
index 0000000000000000000000000000000000000000..e6c54974e4ee9396c8f76bd6fec9ceb1992e4426
--- /dev/null
+++ b/MLPY/Lib/site-packages/torchaudio/prototype/functional/_rir.py
@@ -0,0 +1,379 @@
+import math
+from typing import Optional, Tuple, Union
+
+import torch
+import torchaudio
+from torch import Tensor
+
+
+def _compute_image_sources(
+ room: torch.Tensor,
+ source: torch.Tensor,
+ max_order: int,
+ absorption: torch.Tensor,
+ scatter: Optional[torch.Tensor] = None,
+) -> Tuple[Tensor, Tensor]:
+ """Compute image sources in a shoebox-like room.
+
+ Args:
+ room (torch.Tensor): The 1D Tensor to determine the room size. The shape is
+ `(D,)`, where ``D`` is 2 if room is a 2D room, or 3 if room is a 3D room.
+ source (torch.Tensor): The coordinate of the sound source. Tensor with dimensions
+ `(D)`.
+ max_order (int): The maximum number of reflections of the source.
+ absorption (torch.Tensor): The absorption coefficients of wall materials.
+ ``absorption`` is a Tensor with dimensions `(num_band, num_wall)`.
+ The shape options are ``[(1, 4), (1, 6), (7, 4), (7, 6)]``.
+            ``num_band`` is `1` if the coefficients are the same for all frequencies, or `7`
+            if the coefficients differ across frequencies. `7` refers to the default number
+ of octave bands. (See note in `simulate_rir_ism` method).
+ ``num_wall`` is `4` if the room is a 2D room, representing absorption coefficients
+ of ``"west"``, ``"east"``, ``"south"``, and ``"north"`` walls, respectively.
+ Or it is `6` if the room is a 3D room, representing absorption coefficients
+ of ``"west"``, ``"east"``, ``"south"``, ``"north"``, ``"floor"``, and ``"ceiling"``, respectively.
+ scatter (torch.Tensor): The scattering coefficients of wall materials.
+ The shape of ``scatter`` must match that of ``absorption``. If ``None``, it is not
+ used in image source computation. (Default: ``None``)
+
+ Returns:
+ (torch.Tensor): The coordinates of all image sources within ``max_order`` number of reflections.
+ Tensor with dimensions `(num_image_source, D)`.
+ (torch.Tensor): The attenuation of corresponding image sources. Tensor with dimensions
+ `(num_band, num_image_source)`.
+ """
+ if scatter is None:
+ tr = torch.sqrt(1 - absorption)
+ else:
+ tr = torch.sqrt(1 - absorption) * torch.sqrt(1 - scatter)
+
+ ind = torch.arange(-max_order, max_order + 1, device=source.device)
+ if room.shape[0] == 2:
+ XYZ = torch.meshgrid(ind, ind, indexing="ij")
+ else:
+ XYZ = torch.meshgrid(ind, ind, ind, indexing="ij")
+ XYZ = torch.stack([c.reshape((-1,)) for c in XYZ], dim=-1)
+ XYZ = XYZ[XYZ.abs().sum(dim=-1) <= max_order]
+
+ # compute locations of image sources
+ d = room[None, :]
+ s = source[None, :]
+ img_loc = torch.where(XYZ % 2 == 1, d * (XYZ + 1) - s, d * XYZ + s)
+
+ # attenuation
+ exp_lo = abs(torch.floor((XYZ / 2)))
+ exp_hi = abs(torch.floor((XYZ + 1) / 2))
+ t_lo = tr[:, ::2].unsqueeze(1).repeat(1, XYZ.shape[0], 1) # (num_band, left walls)
+ t_hi = tr[:, 1::2].unsqueeze(1).repeat(1, XYZ.shape[0], 1) # (num_band, right walls)
+ att = torch.prod((t_lo**exp_lo) * (t_hi**exp_hi), dim=-1) # (num_band, num_image_source)
+ return img_loc, att
+
+
+def _hann(x: torch.Tensor, T: int):
+ """Compute the Hann window where the values are truncated based on window length.
+ Unlike ``torch.hann_window``, which can only sample the window function at integer points,
+ this helper samples the continuous window function at non-integer points.
+
+ Args:
+ x (torch.Tensor): The fractional component of time delay Tensor.
+ T (int): The window length of the sinc function.
+
+ Returns:
+ (torch.Tensor): The Hann window Tensor where values outside
+ the sinc window (`T`) are set to zero.
+ """
+ y = torch.where(
+ torch.abs(x) <= T / 2,
+ 0.5 * (1 + torch.cos(2 * math.pi * x / T)),
+ x.new_zeros(1),
+ )
+ return y
+
+
+def _frac_delay(delay: torch.Tensor, delay_i: torch.Tensor, delay_filter_length: int):
+ """Compute fractional delay of impulse response signal.
+
+ Args:
+ delay (torch.Tensor): The time delay Tensor in samples.
+ delay_i (torch.Tensor): The integer part of delay.
+ delay_filter_length (int): The window length for sinc function.
+
+ Returns:
+ (torch.Tensor): The impulse response Tensor for all image sources.
+ """
+ if delay_filter_length % 2 != 1:
+ raise ValueError("The filter length must be odd")
+
+ pad = delay_filter_length // 2
+ n = torch.arange(-pad, pad + 1, device=delay.device) + delay_i[..., None]
+ delay = delay[..., None]
+
+ return torch.special.sinc(n - delay) * _hann(n - delay, 2 * pad)
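A small sketch of the filters this helper produces; the delays below are toy values.

import torch

delay = torch.tensor([10.3, 25.7])   # fractional delays in samples (toy values)
delay_i = torch.ceil(delay)          # integer parts, as computed in simulate_rir_ism below
filters = _frac_delay(delay, delay_i, delay_filter_length=81)
print(filters.shape)   # torch.Size([2, 81]); one Hann-windowed sinc filter per delay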
+
+
+def _adjust_coeff(coeffs: Union[float, torch.Tensor], name: str) -> torch.Tensor:
+ """Validates and converts absorption or scattering parameters to a tensor with appropriate shape
+
+ Args:
+ coeff (float or torch.Tensor): The absorption coefficients of wall materials.
+
+ If the dtype is ``float``, the absorption coefficient is identical for all walls and
+ all frequencies.
+
+ If the input is a 1D Tensor, the shape must be `(6,)`,
+ where the values represent coefficients of ``"west"``, ``"east"``,
+ ``"south"``, ``"north"``, ``"floor"``, and ``"ceiling"``, respectively.
+
+ If the input is a 2D Tensor, the shape must be `(7, 6)`,
+ where 7 represents the number of octave bands.
+ name (str): The coefficient name (``"absorption"`` or ``"scattering"``), used in error messages.
+
+ Returns:
+ (torch.Tensor): The expanded coefficient.
+ The shape is `(1, 6)` for single octave band case, and
+ `(7, 6)` for multi octave band case.
+ """
+ num_walls = 6
+ if isinstance(coeffs, float):
+ if coeffs < 0:
+ raise ValueError(f"`{name}` must be non-negative. Found: {coeffs}")
+ return torch.full((1, num_walls), coeffs)
+ if isinstance(coeffs, Tensor):
+ if torch.any(coeffs < 0):
+ raise ValueError(f"`{name}` must be non-negative. Found: {coeffs}")
+ if coeffs.ndim == 1:
+ if coeffs.numel() != num_walls:
+ raise ValueError(
+ f"The shape of `{name}` must be ({num_walls},) when it is a 1D Tensor. "
+ f"Found the shape {coeffs.shape}."
+ )
+ return coeffs.unsqueeze(0)
+ if coeffs.ndim == 2:
+ if coeffs.shape[1] != num_walls:
+ raise ValueError(
+ f"The shape of `{name}` must be (NUM_BANDS, {num_walls}) when it "
+ f"is a 2D Tensor. Found: {coeffs.shape}."
+ )
+ return coeffs
+ raise TypeError(f"`{name}` must be float or Tensor.")
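The three accepted formats and the shapes they are normalized to, with illustrative values:

import torch

print(_adjust_coeff(0.2, "absorption").shape)               # (1, 6): scalar applied to all six walls
print(_adjust_coeff(torch.rand(6), "absorption").shape)     # (1, 6): per-wall, single frequency band
print(_adjust_coeff(torch.rand(7, 6), "absorption").shape)  # (7, 6): per-wall, per octave band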
+
+
+def _validate_inputs(
+ room: torch.Tensor,
+ source: torch.Tensor,
+ mic_array: torch.Tensor,
+):
+ """Validate dimensions of input arguments, and normalize different kinds of absorption into the same dimension.
+
+ Args:
+ room (torch.Tensor): The size of the room. width, length (and height)
+ source (torch.Tensor): Sound source coordinates. Tensor with dimensions `(dim,)`.
+ mic_array (torch.Tensor): Microphone coordinates. Tensor with dimensions `(channel, dim)`.
+ """
+ if not (room.ndim == 1 and room.numel() == 3):
+ raise ValueError(f"`room` must be a 1D Tensor with 3 elements. Found {room.shape}.")
+ if not (source.ndim == 1 and source.numel() == 3):
+ raise ValueError(f"`source` must be 1D Tensor with 3 elements. Found {source.shape}.")
+ if not (mic_array.ndim == 2 and mic_array.shape[1] == 3):
+ raise ValueError(f"`mic_array` must be a 2D Tensor with shape (num_channels, 3). Found {mic_array.shape}.")
+
+
+def simulate_rir_ism(
+ room: torch.Tensor,
+ source: torch.Tensor,
+ mic_array: torch.Tensor,
+ max_order: int,
+ absorption: Union[float, torch.Tensor],
+ output_length: Optional[int] = None,
+ delay_filter_length: int = 81,
+ center_frequency: Optional[torch.Tensor] = None,
+ sound_speed: float = 343.0,
+ sample_rate: float = 16000.0,
+) -> Tensor:
+ r"""Compute Room Impulse Response (RIR) based on the *image source method* :cite:`allen1979image`.
+ The implementation is based on *pyroomacoustics* :cite:`scheibler2018pyroomacoustics`.
+
+ .. devices:: CPU
+
+ .. properties:: TorchScript
+
+ Args:
+ room (torch.Tensor): Room coordinates. The shape of `room` must be `(3,)` which represents
+ three dimensions of the room.
+ source (torch.Tensor): Sound source coordinates. Tensor with dimensions `(3,)`.
+ mic_array (torch.Tensor): Microphone coordinates. Tensor with dimensions `(channel, 3)`.
+ max_order (int): The maximum number of reflections of the source.
+ absorption (float or torch.Tensor): The *absorption* :cite:`wiki:Absorption_(acoustics)`
+ coefficients of wall materials for sound energy.
+ If the dtype is ``float``, the absorption coefficient is identical for all walls and
+ all frequencies.
+ If ``absorption`` is a 1D Tensor, the shape must be `(6,)`, where the values represent
+ absorption coefficients of ``"west"``, ``"east"``, ``"south"``, ``"north"``, ``"floor"``,
+ and ``"ceiling"``, respectively.
+ If ``absorption`` is a 2D Tensor, the shape must be `(7, 6)`, where 7 represents the number of octave bands.
+ output_length (int or None, optional): The output length of simulated RIR signal. If ``None``,
+ the length is defined as
+
+ .. math::
+ \frac{\text{max\_d} \cdot \text{sample\_rate}}{\text{sound\_speed}} + \text{delay\_filter\_length}
+
+ where ``max_d`` is the maximum distance between image sources and microphones.
+ delay_filter_length (int, optional): The filter length for computing sinc function. (Default: ``81``)
+ center_frequency (torch.Tensor, optional): The center frequencies of octave bands for multi-band walls.
+ Only used when ``absorption`` is a 2D Tensor.
+ sound_speed (float, optional): The speed of sound. (Default: ``343.0``)
+ sample_rate (float, optional): The sample rate of the generated room impulse response signal.
+ (Default: ``16000.0``)
+
+ Returns:
+ (torch.Tensor): The simulated room impulse response waveform. Tensor with dimensions
+ `(channel, rir_length)`.
+
+ Note:
+ If ``absorption`` is a 2D Tensor and ``center_frequency`` is set to ``None``, the center frequencies
+ of octave bands are fixed to ``[125.0, 250.0, 500.0, 1000.0, 2000.0, 4000.0, 8000.0]``.
+ Users need to tune the values of ``absorption`` to the corresponding frequencies.
+ """
+ _validate_inputs(room, source, mic_array)
+ absorption = _adjust_coeff(absorption, "absorption")
+ img_location, att = _compute_image_sources(room, source, max_order, absorption)
+
+ # compute distances between image sources and microphones
+ vec = img_location[:, None, :] - mic_array[None, :, :]
+ dist = torch.linalg.norm(vec, dim=-1) # (image_source, channel)
+
+ img_src_att = att[..., None] / dist[None, ...] # (band, image_source, channel)
+
+ # separate delays in integer / frac part
+ delay = dist * sample_rate / sound_speed # distance to delay in samples
+ delay_i = torch.ceil(delay) # integer part
+
+ # compute the short IRs corresponding to each image source
+ irs = img_src_att[..., None] * _frac_delay(delay, delay_i, delay_filter_length)[None, ...]
+
+ rir_length = int(delay_i.max() + irs.shape[-1])
+ rir = torch.ops.torchaudio._simulate_rir(irs, delay_i.type(torch.int32), rir_length)
+
+ # multi-band processing
+ if absorption.shape[0] > 1:
+ if center_frequency is None:
+ center = torch.tensor(
+ [125.0, 250.0, 500.0, 1000.0, 2000.0, 4000.0, 8000.0], dtype=room.dtype, device=room.device
+ )
+ else:
+ center = center_frequency
+ # n_fft is set to 512 by default.
+ filters = torch.ops.torchaudio._make_rir_filter(center, sample_rate, n_fft=512)
+ rir = torchaudio.functional.fftconvolve(rir, filters.unsqueeze(1).repeat(1, rir.shape[1], 1), mode="same")
+
+ # sum up rir signals of all image sources into one waveform.
+ rir = rir.sum(0)
+
+ if output_length is not None:
+ if output_length > rir.shape[-1]:
+ rir = torch.nn.functional.pad(rir, (0, output_length - rir.shape[-1]), "constant", 0.0)
+ else:
+ rir = rir[..., :output_length]
+
+ return rir
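A hedged usage sketch: it assumes a torchaudio build that registers the ``torchaudio._simulate_rir`` extension op used above, and the room geometry is illustrative.

import torch

room = torch.tensor([8.0, 6.0, 3.0])
source = torch.tensor([2.0, 1.5, 1.2])
mic_array = torch.tensor([[4.0, 3.0, 1.5], [4.2, 3.0, 1.5]])  # two microphones
rir = simulate_rir_ism(room, source, mic_array, max_order=3, absorption=0.2)
print(rir.shape)  # (2, rir_length): one impulse response per microphone channel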
+
+
+def ray_tracing(
+ room: torch.Tensor,
+ source: torch.Tensor,
+ mic_array: torch.Tensor,
+ num_rays: int,
+ absorption: Union[float, torch.Tensor] = 0.0,
+ scattering: Union[float, torch.Tensor] = 0.0,
+ mic_radius: float = 0.5,
+ sound_speed: float = 343.0,
+ energy_thres: float = 1e-7,
+ time_thres: float = 10.0,
+ hist_bin_size: float = 0.004,
+) -> torch.Tensor:
+ r"""Compute energy histogram via ray tracing.
+
+ The implementation is based on *pyroomacoustics* :cite:`scheibler2018pyroomacoustics`.
+
+ ``num_rays`` rays are cast uniformly in all directions from the source;
+ when a ray intersects a wall, it is reflected and part of its energy is absorbed.
+ It is also scattered (sent directly to the microphone(s)) according to the ``scattering``
+ coefficient.
+ When a ray is close to the microphone, its current energy is recorded in the output
+ histogram for that given time slot.
+
+ .. devices:: CPU
+
+ .. properties:: TorchScript
+
+ Args:
+ room (torch.Tensor): Room coordinates. The shape of `room` must be `(3,)` which represents
+ three dimensions of the room.
+ source (torch.Tensor): Sound source coordinates. Tensor with dimensions `(3,)`.
+ mic_array (torch.Tensor): Microphone coordinates. Tensor with dimensions `(channel, 3)`.
+ num_rays (int): The number of rays to cast from the source.
+ absorption (float or torch.Tensor, optional): The absorption coefficients of wall materials.
+ (Default: ``0.0``).
+ If the type is ``float``, the absorption coefficient is identical for all walls and
+ all frequencies.
+ If ``absorption`` is a 1D Tensor, the shape must be `(6,)`, representing absorption
+ coefficients of ``"west"``, ``"east"``, ``"south"``, ``"north"``, ``"floor"``, and
+ ``"ceiling"``, respectively.
+ If ``absorption`` is a 2D Tensor, the shape must be `(num_bands, 6)`.
+ ``num_bands`` is the number of frequency bands (usually 7).
+ scattering (float or torch.Tensor, optional): The scattering coefficients of wall materials. (Default: ``0.0``)
+ The shape and type of this parameter are the same as for ``absorption``.
+ mic_radius (float, optional): The radius of the microphone in meters. (Default: ``0.5``)
+ sound_speed (float, optional): The speed of sound in meters per second. (Default: ``343.0``)
+ energy_thres (float, optional): The energy level below which we stop tracing a ray. (Default: ``1e-7``)
+ The initial energy of each ray is ``2 / num_rays``.
+ time_thres (float, optional): The maximal duration for which rays are traced. (Unit: seconds) (Default: 10.0)
+ hist_bin_size (float, optional): The size of each bin in the output histogram. (Unit: seconds) (Default: 0.004)
+
+ Returns:
+ (torch.Tensor): The 3D histogram(s) where the energy of the traced ray is recorded.
+ Each bin corresponds to a given time slot.
+ The shape is `(channel, num_bands, num_bins)`, where
+ ``num_bins = ceil(time_thres / hist_bin_size)``.
+ If both ``absorption`` and ``scattering`` are floats, then ``num_bands == 1``.
+ """
+ if time_thres < hist_bin_size:
+ raise ValueError(
+ "`time_thres` must be greater than `hist_bin_size`. "
+ f"Found: hist_bin_size={hist_bin_size}, time_thres={time_thres}."
+ )
+
+ if room.dtype != source.dtype or source.dtype != mic_array.dtype:
+ raise ValueError(
+ "dtype of `room`, `source` and `mic_array` must match. "
+ f"Found: `room` ({room.dtype}), `source` ({source.dtype}) and "
+ f"`mic_array` ({mic_array.dtype})"
+ )
+
+ _validate_inputs(room, source, mic_array)
+ absorption = _adjust_coeff(absorption, "absorption").to(room.dtype)
+ scattering = _adjust_coeff(scattering, "scattering").to(room.dtype)
+
+ # Bring absorption and scattering to the same shape
+ if absorption.shape[0] == 1 and scattering.shape[0] > 1:
+ absorption = absorption.expand(scattering.shape)
+ if scattering.shape[0] == 1 and absorption.shape[0] > 1:
+ scattering = scattering.expand(absorption.shape)
+ if absorption.shape != scattering.shape:
+ raise ValueError(
+ "`absorption` and `scattering` must be broadcastable to the same number of bands and walls. "
+ f"Inferred shapes absorption={absorption.shape} and scattering={scattering.shape}"
+ )
+
+ histograms = torch.ops.torchaudio.ray_tracing(
+ room,
+ source,
+ mic_array,
+ num_rays,
+ absorption,
+ scattering,
+ mic_radius,
+ sound_speed,
+ energy_thres,
+ time_thres,
+ hist_bin_size,
+ )
+
+ return histograms
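A hedged usage sketch: the ray casting itself happens in the ``torchaudio.ray_tracing`` extension op, so a compatible torchaudio build is assumed and the geometry is illustrative.

import torch

room = torch.tensor([8.0, 6.0, 3.0])
source = torch.tensor([2.0, 1.5, 1.2])
mic_array = torch.tensor([[4.0, 3.0, 1.5]])
hist = ray_tracing(room, source, mic_array, num_rays=10_000, absorption=0.1, scattering=0.1)
# Scalar absorption/scattering give a single band; with the default thresholds
# num_bins = ceil(10.0 / 0.004) = 2500, so the shape is (1, 1, 2500).
print(hist.shape)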
diff --git a/MLPY/Lib/site-packages/torchaudio/prototype/functional/functional.py b/MLPY/Lib/site-packages/torchaudio/prototype/functional/functional.py
new file mode 100644
index 0000000000000000000000000000000000000000..5d14d7af29c5b72249a4b015e6bf6609a6acba78
--- /dev/null
+++ b/MLPY/Lib/site-packages/torchaudio/prototype/functional/functional.py
@@ -0,0 +1,190 @@
+import math
+import warnings
+from typing import Optional
+
+import torch
+from torchaudio.functional.functional import _create_triangular_filterbank
+
+
+def _hz_to_bark(freqs: float, bark_scale: str = "traunmuller") -> float:
+ r"""Convert Hz to Barks.
+
+ Args:
+ freqs (float): Frequencies in Hz
+ bark_scale (str, optional): Scale to use: ``traunmuller``, ``schroeder`` or ``wang``. (Default: ``traunmuller``)
+
+ Returns:
+ barks (float): Frequency in Barks
+ """
+
+ if bark_scale not in ["schroeder", "traunmuller", "wang"]:
+ raise ValueError('bark_scale should be one of "schroeder", "traunmuller" or "wang".')
+
+ if bark_scale == "wang":
+ return 6.0 * math.asinh(freqs / 600.0)
+ elif bark_scale == "schroeder":
+ return 7.0 * math.asinh(freqs / 650.0)
+ # Traunmuller Bark scale
+ barks = ((26.81 * freqs) / (1960.0 + freqs)) - 0.53
+ # Bark value correction
+ if barks < 2:
+ barks += 0.15 * (2 - barks)
+ elif barks > 20.1:
+ barks += 0.22 * (barks - 20.1)
+
+ return barks
+
+
+def _bark_to_hz(barks: torch.Tensor, bark_scale: str = "traunmuller") -> torch.Tensor:
+ """Convert bark bin numbers to frequencies.
+
+ Args:
+ barks (torch.Tensor): Bark frequencies
+ bark_scale (str, optional): Scale to use: ``traunmuller``, ``schroeder`` or ``wang``. (Default: ``traunmuller``)
+
+ Returns:
+ freqs (torch.Tensor): Barks converted in Hz
+ """
+
+ if bark_scale not in ["schroeder", "traunmuller", "wang"]:
+ raise ValueError('bark_scale should be one of "traunmuller", "schroeder" or "wang".')
+
+ if bark_scale == "wang":
+ return 600.0 * torch.sinh(barks / 6.0)
+ elif bark_scale == "schroeder":
+ return 650.0 * torch.sinh(barks / 7.0)
+ # Bark value correction
+ if any(barks < 2):
+ idx = barks < 2
+ barks[idx] = (barks[idx] - 0.3) / 0.85
+ elif any(barks > 20.1):
+ idx = barks > 20.1
+ barks[idx] = (barks[idx] + 4.422) / 1.22
+
+ # Traunmuller Bark scale
+ freqs = 1960 * ((barks + 0.53) / (26.28 - barks))
+
+ return freqs
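A quick round-trip check of the conversion pair; the ``wang`` scale is used here because its forward and inverse formulas are exact inverses of each other.

import torch

b = _hz_to_bark(1000.0, bark_scale="wang")             # float in, float out: ~7.70 Bark
f = _bark_to_hz(torch.tensor([b]), bark_scale="wang")  # Tensor of Bark values in
print(b, f)  # ~7.70  tensor([1000.])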
+
+
+def _hz_to_octs(freqs, tuning=0.0, bins_per_octave=12):
+ a440 = 440.0 * 2.0 ** (tuning / bins_per_octave)
+ return torch.log2(freqs / (a440 / 16))
+
+
+def barkscale_fbanks(
+ n_freqs: int,
+ f_min: float,
+ f_max: float,
+ n_barks: int,
+ sample_rate: int,
+ bark_scale: str = "traunmuller",
+) -> torch.Tensor:
+ r"""Create a frequency bin conversion matrix.
+
+ .. devices:: CPU
+
+ .. properties:: TorchScript
+
+ .. image:: https://download.pytorch.org/torchaudio/doc-assets/bark_fbanks.png
+ :alt: Visualization of generated filter bank
+
+ Args:
+ n_freqs (int): Number of frequencies to highlight/apply
+ f_min (float): Minimum frequency (Hz)
+ f_max (float): Maximum frequency (Hz)
+ n_barks (int): Number of Bark filterbanks
+ sample_rate (int): Sample rate of the audio waveform
+ bark_scale (str, optional): Scale to use: ``traunmuller``, ``schroeder`` or ``wang``. (Default: ``traunmuller``)
+
+ Returns:
+ torch.Tensor: Triangular filter banks (fb matrix) of size (``n_freqs``, ``n_barks``),
+ i.e. the number of frequency bins by the number of filter banks.
+ Each column is a filter bank, so that given a matrix A of
+ size (..., ``n_freqs``), the applied result would be
+ ``A @ barkscale_fbanks(A.size(-1), ...)``.
+
+ """
+
+ # freq bins
+ all_freqs = torch.linspace(0, sample_rate // 2, n_freqs)
+
+ # calculate bark freq bins
+ m_min = _hz_to_bark(f_min, bark_scale=bark_scale)
+ m_max = _hz_to_bark(f_max, bark_scale=bark_scale)
+
+ m_pts = torch.linspace(m_min, m_max, n_barks + 2)
+ f_pts = _bark_to_hz(m_pts, bark_scale=bark_scale)
+
+ # create filterbank
+ fb = _create_triangular_filterbank(all_freqs, f_pts)
+
+ if (fb.max(dim=0).values == 0.0).any():
+ warnings.warn(
+ "At least one bark filterbank has all zero values. "
+ f"The value for `n_barks` ({n_barks}) may be set too high. "
+ f"Or, the value for `n_freqs` ({n_freqs}) may be set too low."
+ )
+
+ return fb
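A sketch of applying the filter bank to a magnitude spectrogram via the matrix product described in the docstring above; the parameter values are illustrative.

import torch

n_fft = 1024
fb = barkscale_fbanks(n_freqs=n_fft // 2 + 1, f_min=0.0, f_max=8000.0, n_barks=32, sample_rate=16000)
spec = torch.rand(1, n_fft // 2 + 1, 100)             # (channel, n_freqs, time)
bark_spec = torch.matmul(spec.transpose(-1, -2), fb)  # (channel, time, n_barks)
print(fb.shape, bark_spec.shape)                      # torch.Size([513, 32]) torch.Size([1, 100, 32])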
+
+
+def chroma_filterbank(
+ sample_rate: int,
+ n_freqs: int,
+ n_chroma: int,
+ *,
+ tuning: float = 0.0,
+ ctroct: float = 5.0,
+ octwidth: Optional[float] = 2.0,
+ norm: int = 2,
+ base_c: bool = True,
+):
+ """Create a frequency-to-chroma conversion matrix. Implementation adapted from librosa.
+
+ Args:
+ sample_rate (int): Sample rate.
+ n_freqs (int): Number of input frequencies.
+ n_chroma (int): Number of output chroma.
+ tuning (float, optional): Tuning deviation from A440 in fractions of a chroma bin. (Default: 0.0)
+ ctroct (float, optional): Center of Gaussian dominance window to weight filters by, in octaves. (Default: 5.0)
+ octwidth (float or None, optional): Width of Gaussian dominance window to weight filters by, in octaves.
+ If ``None``, then disable weighting altogether. (Default: 2.0)
+ norm (int, optional): Order of the norm used to normalize the filter bank. (Default: 2)
+ base_c (bool, optional): If True, then start filter bank at C. Otherwise, start at A. (Default: True)
+
+ Returns:
+ torch.Tensor: Chroma filter bank, with shape `(n_freqs, n_chroma)`.
+ """
+ # Skip redundant upper half of frequency range.
+ freqs = torch.linspace(0, sample_rate // 2, n_freqs)[1:]
+ freq_bins = n_chroma * _hz_to_octs(freqs, bins_per_octave=n_chroma, tuning=tuning)
+ freq_bins = torch.cat((torch.tensor([freq_bins[0] - 1.5 * n_chroma]), freq_bins))
+ freq_bin_widths = torch.cat(
+ (
+ torch.maximum(freq_bins[1:] - freq_bins[:-1], torch.tensor(1.0)),
+ torch.tensor([1]),
+ )
+ )
+
+ # (n_freqs, n_chroma)
+ D = freq_bins.unsqueeze(1) - torch.arange(0, n_chroma)
+
+ n_chroma2 = round(n_chroma / 2)
+
+ # Project to range [-n_chroma/2, n_chroma/2 - 1]
+ D = torch.remainder(D + n_chroma2, n_chroma) - n_chroma2
+
+ fb = torch.exp(-0.5 * (2 * D / torch.tile(freq_bin_widths.unsqueeze(1), (1, n_chroma))) ** 2)
+ fb = torch.nn.functional.normalize(fb, p=norm, dim=1)
+
+ if octwidth is not None:
+ fb *= torch.tile(
+ torch.exp(-0.5 * (((freq_bins.unsqueeze(1) / n_chroma - ctroct) / octwidth) ** 2)),
+ (1, n_chroma),
+ )
+
+ if base_c:
+ fb = torch.roll(fb, -3 * (n_chroma // 12), dims=1)
+
+ return fb
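A sketch of pairing the chroma filter bank with a spectrogram to form a chromagram; shapes only, the values are illustrative.

import torch

n_fft = 2048
fb = chroma_filterbank(sample_rate=22050, n_freqs=n_fft // 2 + 1, n_chroma=12)
spec = torch.rand(1, n_fft // 2 + 1, 200)  # (channel, n_freqs, time)
chroma = torch.matmul(fb.T, spec)          # (channel, n_chroma, time)
print(fb.shape, chroma.shape)              # torch.Size([1025, 12]) torch.Size([1, 12, 200])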
diff --git a/MLPY/Lib/site-packages/torchaudio/prototype/models/__init__.py b/MLPY/Lib/site-packages/torchaudio/prototype/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..4f2af31c07de6c025b795ba4b07dd7deeb3c3283
--- /dev/null
+++ b/MLPY/Lib/site-packages/torchaudio/prototype/models/__init__.py
@@ -0,0 +1,36 @@
+from ._conformer_wav2vec2 import (
+ conformer_wav2vec2_base,
+ conformer_wav2vec2_model,
+ conformer_wav2vec2_pretrain_base,
+ conformer_wav2vec2_pretrain_large,
+ conformer_wav2vec2_pretrain_model,
+ ConformerWav2Vec2PretrainModel,
+)
+from ._emformer_hubert import emformer_hubert_base, emformer_hubert_model
+from .conv_emformer import ConvEmformer
+from .hifi_gan import hifigan_vocoder, hifigan_vocoder_v1, hifigan_vocoder_v2, hifigan_vocoder_v3, HiFiGANVocoder
+from .rnnt import conformer_rnnt_base, conformer_rnnt_biasing, conformer_rnnt_biasing_base, conformer_rnnt_model
+from .rnnt_decoder import Hypothesis, RNNTBeamSearchBiasing
+
+__all__ = [
+ "conformer_rnnt_base",
+ "conformer_rnnt_model",
+ "conformer_rnnt_biasing",
+ "conformer_rnnt_biasing_base",
+ "ConvEmformer",
+ "conformer_wav2vec2_model",
+ "conformer_wav2vec2_base",
+ "conformer_wav2vec2_pretrain_model",
+ "conformer_wav2vec2_pretrain_base",
+ "conformer_wav2vec2_pretrain_large",
+ "ConformerWav2Vec2PretrainModel",
+ "emformer_hubert_base",
+ "emformer_hubert_model",
+ "Hypothesis",
+ "RNNTBeamSearchBiasing",
+ "HiFiGANVocoder",
+ "hifigan_vocoder_v1",
+ "hifigan_vocoder_v2",
+ "hifigan_vocoder_v3",
+ "hifigan_vocoder",
+]
diff --git a/MLPY/Lib/site-packages/torchaudio/prototype/models/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torchaudio/prototype/models/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cea7f4280252456808a6bc5e9b8aedbaae2f8bcd
Binary files /dev/null and b/MLPY/Lib/site-packages/torchaudio/prototype/models/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torchaudio/prototype/models/__pycache__/_conformer_wav2vec2.cpython-39.pyc b/MLPY/Lib/site-packages/torchaudio/prototype/models/__pycache__/_conformer_wav2vec2.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6b065dc45a5b33c96a3989beb7792d43b6aaf7c4
Binary files /dev/null and b/MLPY/Lib/site-packages/torchaudio/prototype/models/__pycache__/_conformer_wav2vec2.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torchaudio/prototype/models/__pycache__/_emformer_hubert.cpython-39.pyc b/MLPY/Lib/site-packages/torchaudio/prototype/models/__pycache__/_emformer_hubert.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..276156ea340b649771209c646c33040f5179e9df
Binary files /dev/null and b/MLPY/Lib/site-packages/torchaudio/prototype/models/__pycache__/_emformer_hubert.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torchaudio/prototype/models/__pycache__/conv_emformer.cpython-39.pyc b/MLPY/Lib/site-packages/torchaudio/prototype/models/__pycache__/conv_emformer.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f7bba3dd73aebcb430387a73b16170858911e832
Binary files /dev/null and b/MLPY/Lib/site-packages/torchaudio/prototype/models/__pycache__/conv_emformer.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torchaudio/prototype/models/__pycache__/hifi_gan.cpython-39.pyc b/MLPY/Lib/site-packages/torchaudio/prototype/models/__pycache__/hifi_gan.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0814217ce187d54000ef800af72c68339890b75a
Binary files /dev/null and b/MLPY/Lib/site-packages/torchaudio/prototype/models/__pycache__/hifi_gan.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torchaudio/prototype/models/__pycache__/rnnt.cpython-39.pyc b/MLPY/Lib/site-packages/torchaudio/prototype/models/__pycache__/rnnt.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2d61ed5a71713a2d8bca15eebbd974abe16c13da
Binary files /dev/null and b/MLPY/Lib/site-packages/torchaudio/prototype/models/__pycache__/rnnt.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torchaudio/prototype/models/__pycache__/rnnt_decoder.cpython-39.pyc b/MLPY/Lib/site-packages/torchaudio/prototype/models/__pycache__/rnnt_decoder.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..52dd176325ed8e20876a582cdc23eda860f8e70a
Binary files /dev/null and b/MLPY/Lib/site-packages/torchaudio/prototype/models/__pycache__/rnnt_decoder.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torchaudio/prototype/models/_conformer_wav2vec2.py b/MLPY/Lib/site-packages/torchaudio/prototype/models/_conformer_wav2vec2.py
new file mode 100644
index 0000000000000000000000000000000000000000..3d079d553cd991c712aea78e7794099747723fda
--- /dev/null
+++ b/MLPY/Lib/site-packages/torchaudio/prototype/models/_conformer_wav2vec2.py
@@ -0,0 +1,794 @@
+from typing import List, Optional, Tuple, Union
+
+import torch
+from torch import nn, Tensor
+from torch.nn import Module, ModuleList
+from torchaudio.models import Wav2Vec2Model
+from torchaudio.models.conformer import ConformerLayer
+from torchaudio.models.rnnt import _TimeReduction
+from torchaudio.models.wav2vec2 import components
+
+
+def _buffered_arange(max) -> Tensor:
+ """Compute arange using a buffered tensor across function calls.
+ Produces same result as torch.arange(end=max).
+
+ Args:
+ max (int): Ending value for arange.
+ """
+ if not hasattr(_buffered_arange, "buf"):
+ _buffered_arange.buf = torch.LongTensor()
+ if max > _buffered_arange.buf.numel():
+ _buffered_arange.buf.resize_(max)
+ torch.arange(max, out=_buffered_arange.buf)
+ return _buffered_arange.buf[:max]
+
+
+def _sample_negatives(input: Tensor, num_negatives: int, cross_sample_negatives: int) -> Tuple[Tensor, Tensor]:
+ """Sample negative examples from masked input.
+
+ Args:
+ input (Tensor): Tensor of dimension `(batch, frame, dim)`.
+ num_negatives (int): Number of negative examples to sample.
+ cross_sample_negatives (int): Number of negative examples to cross sample.
+
+ Returns:
+ (Tensor, Tensor):
+ Tensor
+ The negative samples.
+ Tensor
+ The indices of the negative samples.
+ """
+ if num_negatives == 0 and cross_sample_negatives == 0:
+ return (
+ torch.zeros(0).to(input.device, input.dtype),
+ torch.zeros(0).to(input.device, input.dtype),
+ )
+
+ B, T, D = input.shape
+ input = input.view(-1, D)
+
+ cross_high = T * B
+ high = T
+
+ assert high > 1
+
+ if num_negatives > 0:
+ tszs = _buffered_arange(T).unsqueeze(-1).expand(-1, num_negatives).flatten()
+
+ neg_idxs = torch.randint(low=0, high=high - 1, size=(B, num_negatives * T))
+ neg_idxs[neg_idxs >= tszs] += 1
+
+ if cross_sample_negatives > 0:
+ tszs = _buffered_arange(T).unsqueeze(-1).expand(-1, cross_sample_negatives).flatten()
+
+ cross_neg_idxs = torch.randint(low=0, high=cross_high - 1, size=(B, cross_sample_negatives * T))
+ cross_neg_idxs[cross_neg_idxs >= tszs] += 1
+
+ if num_negatives > 0:
+ neg_idxs = neg_idxs + (torch.arange(B).unsqueeze(1) * high)
+ else:
+ neg_idxs = cross_neg_idxs
+
+ if cross_sample_negatives > 0 and num_negatives > 0:
+ neg_idxs = torch.cat([neg_idxs, cross_neg_idxs], dim=1)
+
+ negs = input[neg_idxs.view(-1)]
+ negs = negs.view(B, T, num_negatives + cross_sample_negatives, D).permute(2, 0, 1, 3) # (N, B, T, D)
+
+ return negs, neg_idxs
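A shape sketch for the sampler above with toy dimensions:

import torch

B, T, D = 2, 50, 256
masked_features = torch.randn(B, T, D)
negs, neg_idxs = _sample_negatives(masked_features, num_negatives=10, cross_sample_negatives=0)
print(negs.shape)      # torch.Size([10, 2, 50, 256]): (num_negatives, batch, frame, dim)
print(neg_idxs.shape)  # torch.Size([2, 500]): ten sampled frame indices per target frame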
+
+
+class NegativeSampler(Module):
+ r"""Applies preprocessing to input and then computes negative sampling.
+
+ Args:
+ preprocessor (nn.Module): Transforms input tensor prior to negative sampling.
+ num_negatives (int): Number of negative examples to sample.
+ cross_sample_negatives (int): Number of negative examples to cross sample.
+ """
+
+ def __init__(
+ self,
+ preprocessor: Module,
+ num_negatives: int,
+ cross_sample_negatives: int,
+ ):
+ super().__init__()
+ self.preprocessor = preprocessor
+ self.num_negatives = num_negatives
+ self.cross_sample_negatives = cross_sample_negatives
+
+ def forward(self, input: Tensor) -> Tuple[Tensor, Tensor, Optional[Tensor]]:
+ """
+ Args:
+ input (Tensor): Tensor of dimension `(B, T, D)`.
+
+ Returns:
+ (Tensor, Tensor, Optional[Tensor]):
+ Tensor
+ The input tensor after preprocessing, prior to being sampled.
+ Tensor
+ The negative samples.
+ Tensor
+ The indices of the negative samples.
+ """
+ preprocessed = self.preprocessor(input)
+ negs, neg_idxs = _sample_negatives(preprocessed, self.num_negatives, self.cross_sample_negatives)
+ return preprocessed, negs, neg_idxs
+
+
+class FeatureEncoder(Module):
+ """Feature Encoder class, consisting of time reduction and linear layer.
+
+ Args:
+ stride (int): Number of frames to merge for the output frame.
+ input_dim (int): Input dimension of the tensor.
+ output_dim (int): Output dimension of the tensor.
+ """
+
+ def __init__(self, input_dim: int, output_dim: int, stride: int):
+ super().__init__()
+ self.time_reduction_layer = _TimeReduction(stride=stride)
+ self.linear_layer = nn.Linear(input_dim * stride, output_dim)
+
+ def forward(
+ self,
+ x: Tensor,
+ lengths: Optional[Tensor],
+ ) -> Tuple[Tensor, Optional[Tensor]]:
+ """
+ Args:
+ x (Tensor): Feature Tensor representing log Mel Spectrogram output. shape ``(B, T, D)``.
+ lengths (Tensor or None):
+ Valid length of each input sample. shape: ``(B, )``.
+
+ Returns:
+ (Tensor, Optional[Tensor]):
+ Tensor: output sequence after undergoing time reduction and linear projection.
+ Shape ``(B, T // stride, output_dim)``.
+ Optional[Tensor]: output lengths of shape ``(B,)`` if lengths parameter is provided,
+ otherwise `None`.
+ """
+ if lengths is None:
+ B, T, D = x.shape
+ dummy_lengths = torch.full((B,), T)
+ x, _ = self.time_reduction_layer(x, dummy_lengths)
+ x = self.linear_layer(x)
+ return x, None
+
+ x, lengths = self.time_reduction_layer(x, lengths)
+ x = self.linear_layer(x)
+ return x, lengths
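A shape sketch of the stride-4 time reduction followed by the linear projection, with toy dimensions:

import torch

encoder = FeatureEncoder(input_dim=64, output_dim=256, stride=4)
feats = torch.randn(2, 400, 64)   # (batch, frame, dim) log-mel features
lengths = torch.tensor([400, 380])
out, out_lengths = encoder(feats, lengths)
print(out.shape)    # torch.Size([2, 100, 256]): four frames merged per output frame, then projected
print(out_lengths)  # tensor([100,  95])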
+
+
+class ConformerEncoder(Module):
+ """Conformer Encoder class, consisting of feature projection and conformer modules.
+
+ Args:
+ feature_projection (nn.Module):
+ Projects feature to encoder dimension.
+ conformer (nn.ModuleList):
+ List of Conformer layers.
+ """
+
+ def __init__(
+ self,
+ feature_projection: Module,
+ conformer: ModuleList,
+ ):
+ super().__init__()
+ self.feature_projection = feature_projection
+ self.conformer = conformer
+
+ def _preprocess(
+ self,
+ features: Tensor,
+ lengths: Optional[Tensor] = None,
+ ) -> Tuple[Tensor, Optional[Tensor]]:
+ x = self.feature_projection(features)
+ if lengths is not None:
+ mask = components._get_padding_mask(x, lengths)
+ else:
+ mask = None
+ return x, mask
+
+ def _get_intermediate_outputs(
+ self,
+ x: Tensor,
+ mask: Optional[Tensor] = None,
+ num_layers: Optional[int] = None,
+ ) -> List[Tensor]:
+ if num_layers is not None:
+ if not 0 < num_layers <= len(self.conformer):
+ raise ValueError(f"`num_layers` must be between [1, {len(self.conformer)}]")
+
+ ret: List[Tensor] = []
+
+ x = x.transpose(0, 1)
+ for layer in self.conformer:
+ x = layer(x, mask)
+ ret.append(x.transpose(0, 1))
+ if num_layers is not None and len(ret) >= num_layers:
+ return ret
+ return ret
+
+ def forward(
+ self,
+ features: Tensor,
+ lengths: Optional[Tensor] = None,
+ ) -> Tensor:
+ """
+ Args:
+ features (Tensor): Tensor of features of shape ``(B, T, D)``.
+ lengths (Tensor or None, optional): Valid length of each input sample. shape: ``(B, )``.
+
+ Returns:
+ Tensor: result after applying conformer encoder to features.
+ """
+ x, mask = self._preprocess(features, lengths)
+ x = x.transpose(0, 1)
+ for layer in self.conformer:
+ x = layer(x, mask)
+ return x.transpose(0, 1)
+
+ def extract_features(
+ self,
+ features: Tensor,
+ lengths: Optional[Tensor] = None,
+ num_layers: Optional[int] = None,
+ ) -> List[Tensor]:
+ """Returns the list of outputs from the intermediate layers of conformer block in the encoder.
+
+ Args:
+ features (Tensor): Tensor of features of shape ``(B, T, D)``.
+ lengths (Tensor or None, optional): Valid length of each input sample. shape: ``(B, )``.
+ num_layers (int or None, optional): If not ``None``, returns the outputs of the first
+ `num_layers` conformer layers; otherwise returns the outputs of all layers.
+
+ Returns:
+ List[Tensor]:
+ Features from requested layers. Each Tensor is of shape: `(batch, time frame, feature dimension)`.
+ """
+ x, masks = self._preprocess(features, lengths)
+ return self._get_intermediate_outputs(x, mask=masks, num_layers=num_layers)
+
+
+class ConformerWav2Vec2PretrainModel(Module):
+ """Conformer Wav2Vec2 pre-train model for training from scratch.
+
+ Note:
+ To build the model, please use one of the factory functions,
+ :py:func:`conformer_wav2vec2_pretrain_base` or :py:func:`conformer_wav2vec2_pretrain_large`
+
+ Args:
+ wav2vec2 (nn.Module):
+ Conformer based Wav2Vec2 model, including feature extractor and conformer encoder components.
+ mask_generator (nn.Module):
+ Mask generator that generates the mask for masked prediction during training.
+ negative_sampler (nn.Module):
+ Negative sampler to apply after masking.
+
+ """
+
+ def __init__(
+ self,
+ wav2vec2: Wav2Vec2Model,
+ mask_generator: Module,
+ negative_sampler: Module,
+ ):
+ super().__init__()
+ self.wav2vec2 = wav2vec2
+ self.mask_generator = mask_generator
+ self.negative_sampler = negative_sampler
+
+ def forward(
+ self,
+ features: Tensor,
+ audio_lengths: Optional[Tensor] = None,
+ ) -> Tuple[Tensor, Optional[Tensor], Tensor, Tensor, Tensor, Tensor]:
+ """
+ Args:
+ features (Tensor):
+ Tensor of audio features of shape `(batch, frame, dim)`.
+ audio_lengths (Tensor or None, optional):
+ Tensor of valid lengths of each audio sample in the batch.
+ shape: `(batch, )` (Default: ``None``)
+
+ Returns:
+ (Tensor, Optional[Tensor], Tensor, Tensor, Tensor, Tensor):
+ Tensor
+ The masked sequences of probability distribution of shape `(batch, frame, dim)`.
+ Tensor or None
+ If ``lengths`` argument was provided, a Tensor of shape `(batch, )` representing
+ valid length in time axis is returned.
+ Tensor
+ The mask indices.
+ Tensor
+ The targets, prior to negative sampling.
+ Tensor
+ The negative samples.
+ Tensor
+ The indices of the negative samples.
+ """
+ x, lengths = self.wav2vec2.feature_extractor(features, audio_lengths)
+
+ if lengths is not None:
+ padding_mask = components._get_padding_mask(x, lengths)
+ else:
+ padding_mask = None
+
+ x = self.wav2vec2.encoder.feature_projection.layer_norm(x)
+ x = self.wav2vec2.encoder.feature_projection.dropout(x)
+
+ # Unmasked feature is used to generate positive and negative samples.
+ unmasked_x = x.clone()
+ # Apply masking to x before passing it to Conformer layers.
+ x, mask_idxs = self.mask_generator(x, padding_mask)
+ # Select the frames from masked indices for negative sampling.
+ unmasked_x = unmasked_x[mask_idxs].view(x.shape[0], -1, x.shape[-1])
+ targets, negs, neg_idxs = self.negative_sampler(unmasked_x)
+
+ x = self.wav2vec2.encoder.feature_projection.projection(x)
+ x = x.transpose(0, 1)
+ for conformer_layer in self.wav2vec2.encoder.conformer:
+ x = conformer_layer(x, padding_mask)
+ x = x.transpose(0, 1)
+
+ return x, lengths, mask_idxs, targets, negs, neg_idxs
+
+
+################################################################################
+def _get_conformer_feature_extractor(
+ input_dim: int,
+ output_dim: int,
+ stride: int,
+) -> FeatureEncoder:
+ """Construct Feature Extractor
+
+ Args:
+ input_dim (int): Input dimension of features.
+ output_dim (int): Output dimension after feature extraction.
+ stride (int): Stride used in Time Reduction layer of feature extractor.
+
+ Returns:
+ FeatureEncoder: The resulting feature extraction.
+ """
+ return FeatureEncoder(input_dim, output_dim, stride)
+
+
+def _get_conformer_encoder(
+ in_features: int,
+ embed_dim: int,
+ dropout_input: float,
+ num_layers: int,
+ num_heads: int,
+ ff_interm_features: int,
+ dropout: float,
+ depthwise_conv_kernel_size: Union[int, List[int]],
+ convolution_first: bool,
+ use_group_norm: bool,
+) -> ConformerEncoder:
+ """Construct Conformer Encoder
+
+ Args:
+ in_features (int): The number of input features.
+ embed_dim (int): The dimension of the embedding in the feature projection.
+ dropout_input (float): The dropout probability applied after the input feature
+ is projected to ``embed_dim``.
+ num_layers (int): Number of Conformer layers in the encoder.
+ num_heads (int): Number of heads in each Conformer layer.
+ ff_interm_features (int): Hidden layer dimension of the feedforward network in
+ each Conformer layer.
+ dropout (float): Dropout probability in each Conformer layer.
+ depthwise_conv_kernel_size (int or List[int]): List of kernel sizes corresponding
+ to each of the Conformer layers. If int is provided, all layers will have the
+ same kernel size.
+ convolution_first (bool): Whether to apply the convolution module ahead of the
+ attention module in each Conformer layer.
+ use_group_norm (bool): Whether to use ``GroupNorm`` rather than ``BatchNorm1d`` in
+ the convolution module in each Conformer layer.
+
+ Returns:
+ ConformerEncoder:
+ The resulting conformer encoder module.
+ """
+ feature_projection = components.FeatureProjection(in_features, embed_dim, dropout_input)
+
+ if type(depthwise_conv_kernel_size) == int:
+ depthwise_conv_kernel_size = [depthwise_conv_kernel_size] * num_layers
+
+ assert len(depthwise_conv_kernel_size) == num_layers
+
+ conformer_layers = []
+ for l in range(num_layers):
+ layer = ConformerLayer(
+ input_dim=embed_dim,
+ ffn_dim=ff_interm_features,
+ num_attention_heads=num_heads,
+ depthwise_conv_kernel_size=depthwise_conv_kernel_size[l],
+ dropout=dropout,
+ use_group_norm=use_group_norm,
+ convolution_first=convolution_first,
+ )
+ conformer_layers.append(layer)
+
+ return ConformerEncoder(feature_projection, ModuleList(conformer_layers))
+
+
+def _get_conformer_negativer_sampler(
+ input_dim: int,
+ output_dim: int,
+ num_negatives: int,
+ cross_sample_negatives: int,
+) -> NegativeSampler:
+ """Build custom NegativeSampler module, including linear layer and negative sampling.
+
+ Args:
+ input_dim (int): Dimension of input after feature extraction.
+ output_dim (int): Dimension of embedding for use in negative sampling. Same as the
+ embedding in the feature projection.
+ num_negatives (int): Number of negatives to sample.
+ cross_sample_negatives (int): Number of cross sampled negatives.
+
+ Returns:
+ NegativeSampler:
+ The resulting negative sampler module.
+ """
+ preprocessor = nn.Linear(input_dim, output_dim)
+ return NegativeSampler(preprocessor, num_negatives, cross_sample_negatives)
+
+
+def conformer_wav2vec2_model(
+ extractor_input_dim: int,
+ extractor_output_dim: int,
+ extractor_stride: int,
+ encoder_embed_dim: int,
+ encoder_projection_dropout: float,
+ encoder_num_layers: int,
+ encoder_num_heads: int,
+ encoder_ff_interm_features: int,
+ encoder_depthwise_conv_kernel_size: Union[int, List[int]],
+ encoder_dropout: float,
+ encoder_convolution_first: bool,
+ encoder_use_group_norm: bool,
+) -> Wav2Vec2Model:
+ """Build a custom Conformer Wav2Vec2Model
+
+ Args:
+ extractor_input_dim (int): Input dimension of the features.
+ extractor_output_dim (int): Output dimension after feature extraction.
+ extractor_stride (int): Stride used in time reduction layer of feature extraction.
+ encoder_embed_dim (int): The dimension of the embedding in the feature projection.
+ encoder_projection_dropout (float):
+ The dropout probability applied after the input feature is projected to ``embed_dim``
+ encoder_num_layers (int): Number of Conformer layers in the encoder.
+ encoder_num_heads (int): Number of heads in each Conformer layer.
+ encoder_ff_interm_features (int):
+ Hidden layer dimension of the feedforward network in each Conformer layer.
+ encoder_depthwise_conv_kernel_size (int or List[int]):
+ List of kernel sizes corresponding to each of the Conformer layers.
+ If int is provided, all layers will have the same kernel size.
+ encoder_dropout (float): Dropout probability in each Conformer layer.
+ encoder_convolution_first (bool):
+ Whether to apply the convolution module ahead of the attention module
+ in each Conformer layer.
+ encoder_use_group_norm (bool):
+ Whether to use ``GroupNorm`` rather than ``BatchNorm1d`` in the convolution
+ module in each Conformer layer.
+
+ Returns:
+ Wav2Vec2Model:
+ The resulting wav2vec2 model with a conformer encoder.
+ """
+ feature_extractor = _get_conformer_feature_extractor(
+ extractor_input_dim,
+ extractor_output_dim,
+ extractor_stride,
+ )
+
+ encoder = _get_conformer_encoder(
+ in_features=extractor_output_dim,
+ embed_dim=encoder_embed_dim,
+ dropout_input=encoder_projection_dropout,
+ num_layers=encoder_num_layers,
+ num_heads=encoder_num_heads,
+ ff_interm_features=encoder_ff_interm_features,
+ depthwise_conv_kernel_size=encoder_depthwise_conv_kernel_size,
+ dropout=encoder_dropout,
+ convolution_first=encoder_convolution_first,
+ use_group_norm=encoder_use_group_norm,
+ )
+
+ return Wav2Vec2Model(feature_extractor, encoder)
+
+
+def conformer_wav2vec2_base(
+ extractor_input_dim: int = 64,
+ extractor_output_dim: int = 256,
+ encoder_projection_dropout: float = 0.0,
+) -> Wav2Vec2Model:
+ """
+ Build Conformer Wav2Vec2 Model with "small" architecture from
+ *Conformer-Based Self-Supervised Learning for Non-Speech Audio Tasks* :cite:`9746490`
+
+ Args:
+ extractor_input_dim (int, optional): Input dimension of feature extractor. (Default: 64)
+ extractor_output_dim (int, optional): Output dimension of feature extractor. (Default: 256)
+ encoder_projection_dropout (float, optional):
+ Dropout probability applied after feature projection. (Default: 0.0)
+
+ Returns:
+ Wav2Vec2Model:
+ The resulting wav2vec2 model with a conformer encoder and ``base`` configuration.
+ """
+ return conformer_wav2vec2_model(
+ extractor_input_dim=extractor_input_dim,
+ extractor_output_dim=extractor_output_dim,
+ extractor_stride=4,
+ encoder_embed_dim=256,
+ encoder_projection_dropout=encoder_projection_dropout,
+ encoder_num_layers=12,
+ encoder_num_heads=8,
+ encoder_ff_interm_features=1024,
+ encoder_depthwise_conv_kernel_size=[31] + [15] * 11,
+ encoder_dropout=0.1,
+ encoder_convolution_first=True,
+ encoder_use_group_norm=True,
+ )
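A hedged usage sketch of the ``base`` configuration: unlike waveform-based wav2vec 2.0 models, this one consumes pre-computed features (e.g. 64-dimensional log-mel frames); the batch below is random toy data.

import torch

model = conformer_wav2vec2_base()
features = torch.randn(2, 400, 64)  # (batch, frame, extractor_input_dim)
lengths = torch.tensor([400, 360])
out, out_lengths = model(features, lengths)
print(out.shape)    # torch.Size([2, 100, 256]): stride-4 time reduction, 256-dim encoder output
print(out_lengths)  # tensor([100,  90])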
+
+
+def conformer_wav2vec2_pretrain_model(
+ extractor_input_dim: int,
+ extractor_output_dim: int,
+ extractor_stride: int,
+ encoder_embed_dim: int,
+ encoder_projection_dropout: float,
+ encoder_num_layers: int,
+ encoder_num_heads: int,
+ encoder_ff_interm_features: int,
+ encoder_depthwise_conv_kernel_size: int,
+ encoder_dropout: float,
+ encoder_convolution_first: bool,
+ encoder_use_group_norm: bool,
+ mask_prob: float,
+ mask_selection: str,
+ mask_other: float,
+ mask_length: int,
+ no_mask_overlap: bool,
+ mask_min_space: int,
+ mask_channel_prob: float,
+ mask_channel_selection: str,
+ mask_channel_other: float,
+ mask_channel_length: int,
+ no_mask_channel_overlap: bool,
+ mask_channel_min_space: int,
+ num_negatives: int,
+ cross_sample_negatives: int,
+) -> ConformerWav2Vec2PretrainModel:
+ """Build a custom Conformer Wav2Vec2 Model for pre-training
+
+ Args:
+ extractor_input_dim (int): Input dimension of the features.
+ extractor_output_dim (int): Output dimension after feature extraction.
+ extractor_stride (int):
+ Stride used in time reduction layer of feature extraction.
+ encoder_embed_dim (int):
+ The dimension of the embedding in the feature projection.
+ encoder_projection_dropout (float):
+ The dropout probability applied after the input feature is projected to
+ ``embed_dim``
+ encoder_num_layers (int):
+ Number of Conformer layers in the encoder.
+ encoder_num_heads (int):
+ Number of heads in each Conformer layer.
+ encoder_ff_interm_features (int):
+ Hidden layer dimension of the feedforward network in each Conformer layer.
+ encoder_depthwise_conv_kernel_size (int or List[int]):
+ List of kernel sizes corresponding to each of the Conformer layers.
+ If int is provided, all layers will have the same kernel size.
+ encoder_dropout (float):
+ Dropout probability in each Conformer layer.
+ encoder_convolution_first (bool):
+ Whether to apply the convolution module ahead of the attention module
+ in each Conformer layer.
+ encoder_use_group_norm (bool):
+ Whether to use ``GroupNorm`` rather than ``BatchNorm1d`` in the convolution
+ module in each Conformer layer.
+ mask_prob (float):
+ Probability for each token to be chosen as start of the span to be masked.
+ mask_selection (str):
+ How to choose the mask length. Options: [``static``, ``uniform``, ``normal``, ``poisson``].
+ mask_other (float):
+ Secondary mask argument (used for more complex distributions).
+ mask_length (int):
+ The lengths of the mask.
+ no_mask_overlap (bool):
+ Whether to allow masks to overlap.
+ mask_min_space (int):
+ Minimum space between spans (if no overlap is enabled).
+ mask_channel_prob (float):
+ The probability of replacing a feature with 0.
+ mask_channel_selection (str):
+ How to choose the mask length for channel masking.
+ Options: [``static``, ``uniform``, ``normal``, ``poisson``].
+ mask_channel_other (float):
+ Secondary mask argument for channel masking (used for more complex distributions).
+ mask_channel_length (int):
+ The lengths of the mask for channel masking.
+ no_mask_channel_overlap (bool):
+ Whether to allow channel masks to overlap.
+ mask_channel_min_space (int):
+ Minimum space between spans for channel masking (if no overlap is enabled).
+ num_negatives (int):
+ Number of negatives to sample.
+ cross_sample_negatives (int):
+ Number of cross sampled negatives.
+
+ Returns:
+ ConformerWav2Vec2PretrainModel:
+ The resulting model.
+ """
+ wav2vec2 = conformer_wav2vec2_model(
+ extractor_input_dim,
+ extractor_output_dim,
+ extractor_stride,
+ encoder_embed_dim,
+ encoder_projection_dropout,
+ encoder_num_layers,
+ encoder_num_heads,
+ encoder_ff_interm_features,
+ encoder_depthwise_conv_kernel_size,
+ encoder_dropout,
+ encoder_convolution_first,
+ encoder_use_group_norm,
+ )
+
+ mask_generator = components.MaskGenerator(
+ extractor_output_dim,
+ mask_prob,
+ mask_selection,
+ mask_other,
+ mask_length,
+ no_mask_overlap,
+ mask_min_space,
+ mask_channel_prob,
+ mask_channel_selection,
+ mask_channel_other,
+ mask_channel_length,
+ no_mask_channel_overlap,
+ mask_channel_min_space,
+ )
+
+ negative_sampler = _get_conformer_negativer_sampler(
+ extractor_output_dim,
+ encoder_embed_dim,
+ num_negatives,
+ cross_sample_negatives,
+ )
+
+ return ConformerWav2Vec2PretrainModel(
+ wav2vec2=wav2vec2,
+ mask_generator=mask_generator,
+ negative_sampler=negative_sampler,
+ )
+
+
+def conformer_wav2vec2_pretrain_base(
+ extractor_input_dim: int = 64,
+ extractor_output_dim: int = 256,
+ encoder_projection_dropout: float = 0.0,
+ mask_prob: float = 0.3,
+ mask_length: int = 3,
+ num_negatives: int = 100,
+ cross_sample_negatives: int = 0,
+) -> ConformerWav2Vec2PretrainModel:
+ """Build Conformer Wav2Vec2 Model for pre-training with "small" architecture from
+ *Conformer-Based Self-Supervised Learning for Non-Speech Audio Tasks* :cite:`9746490`
+
+ Args:
+ extractor_input_dim (int, optional): Input dimension of the features. (Default: 64)
+ extractor_output_dim (int, optional): Output dimension after feature extraction. (Default: 256)
+ encoder_projection_dropout (float, optional):
+ The dropout probability applied after the input feature is projected to
+ ``embed_dim``. (Default: 0.0)
+ mask_prob (float, optional):
+ Probability for each token to be chosen as start of the span to be masked. (Default: 0.3)
+ mask_length (int, optional):
+ The lengths of the mask. (Default: 3)
+ num_negatives (int, optional):
+ Number of sampled negatives. (Default: 100)
+ cross_sample_negatives (int, optional):
+ Number of cross sampled negatives. (Default: 0)
+
+ Returns:
+ ConformerWav2Vec2PretrainModel:
+ The resulting model.
+ """
+ return conformer_wav2vec2_pretrain_model(
+ extractor_input_dim=extractor_input_dim,
+ extractor_output_dim=extractor_output_dim,
+ extractor_stride=4,
+ encoder_embed_dim=256,
+ encoder_projection_dropout=encoder_projection_dropout,
+ encoder_num_layers=12,
+ encoder_num_heads=8,
+ encoder_ff_interm_features=1024,
+ encoder_depthwise_conv_kernel_size=[31] + [15] * 11,
+ encoder_dropout=0.1,
+ encoder_convolution_first=True,
+ encoder_use_group_norm=True,
+ mask_prob=mask_prob,
+ mask_selection="static",
+ mask_other=0.0,
+ mask_length=mask_length,
+ no_mask_overlap=False,
+ mask_min_space=0,
+ mask_channel_prob=0,
+ mask_channel_selection="static",
+ mask_channel_other=0,
+ mask_channel_length=10,
+ no_mask_channel_overlap=False,
+ mask_channel_min_space=1,
+ num_negatives=num_negatives,
+ cross_sample_negatives=cross_sample_negatives,
+ )
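A hedged sketch of the pre-training forward pass; the six outputs mirror the tuple documented in ``ConformerWav2Vec2PretrainModel.forward`` above, and the inputs are random toy data.

import torch

model = conformer_wav2vec2_pretrain_base()
features = torch.randn(2, 400, 64)  # (batch, frame, dim) log-mel features
lengths = torch.tensor([400, 360])
x, out_lengths, mask_idxs, targets, negs, neg_idxs = model(features, lengths)
print(x.shape)          # torch.Size([2, 100, 256]): encoder output over the reduced frames
print(mask_idxs.shape)  # torch.Size([2, 100]): boolean mask marking the masked frames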
+
+
+def conformer_wav2vec2_pretrain_large(
+ extractor_input_dim: int = 64,
+ extractor_output_dim: int = 256,
+ encoder_projection_dropout: float = 0.0,
+ mask_prob: float = 0.3,
+ mask_length: int = 3,
+ num_negatives: int = 100,
+ cross_sample_negatives: int = 0,
+) -> ConformerWav2Vec2PretrainModel:
+ """Build Conformer Wav2Vec2 Model for pre-training with "large" architecture from
+ *Conformer-Based Self-Supervised Learning for Non-Speech Audio Tasks* :cite:`9746490`
+
+ Args:
+ extractor_input_dim (int, optional): Input dimension of the features. (Default: 64)
+ extractor_output_dim (int, optional): Output dimension after feature extraction. (Default: 256)
+ encoder_projection_dropout (float, optional):
+ The dropout probability applied after the input feature is projected to
+ ``embed_dim``. (Default: 0.0)
+ mask_prob (float, optional):
+ Probability for each token to be chosen as start of the span to be masked. (Default: 0.3)
+ mask_length (int, optional):
+ The lengths of the mask. (Default: 3)
+ num_negatives (int, optional):
+ Number of sampled negatives. (Default: 100)
+ cross_sample_negatives (int, optional):
+ Number of cross sampled negatives. (Default: 0)
+
+ Returns:
+ ConformerWav2Vec2PretrainModel:
+ The resulting model.
+ """
+ return conformer_wav2vec2_pretrain_model(
+ extractor_input_dim=extractor_input_dim,
+ extractor_output_dim=extractor_output_dim,
+ extractor_stride=4,
+ encoder_embed_dim=768,
+ encoder_projection_dropout=encoder_projection_dropout,
+ encoder_num_layers=12,
+ encoder_num_heads=12,
+ encoder_ff_interm_features=1024,
+ encoder_depthwise_conv_kernel_size=[31] + [15] * 11,
+ encoder_dropout=0.1,
+ encoder_convolution_first=True,
+ encoder_use_group_norm=True,
+ mask_prob=mask_prob,
+ mask_selection="static",
+ mask_other=0.0,
+ mask_length=mask_length,
+ no_mask_overlap=False,
+ mask_min_space=0,
+ mask_channel_prob=0,
+ mask_channel_selection="static",
+ mask_channel_other=0,
+ mask_channel_length=10,
+ no_mask_channel_overlap=False,
+ mask_channel_min_space=1,
+ num_negatives=num_negatives,
+ cross_sample_negatives=cross_sample_negatives,
+ )
diff --git a/MLPY/Lib/site-packages/torchaudio/prototype/models/_emformer_hubert.py b/MLPY/Lib/site-packages/torchaudio/prototype/models/_emformer_hubert.py
new file mode 100644
index 0000000000000000000000000000000000000000..bdf13761bcbe40f16cf04207a14abd606edd64ef
--- /dev/null
+++ b/MLPY/Lib/site-packages/torchaudio/prototype/models/_emformer_hubert.py
@@ -0,0 +1,333 @@
+from typing import List, Optional, Tuple
+
+import torch
+from torchaudio.models import Wav2Vec2Model
+from torchaudio.models.emformer import Emformer
+from torchaudio.models.rnnt import _TimeReduction
+
+
+class FeatureEncoder(torch.nn.Module):
+ """Extract features from log-mel spectrogram input. Consists of linear layer and time reduction layer.
+
+ Args:
+ input_dim (int): The feature dimension of log-mel spectrogram feature.
+ output_dim (int): The feature dimension after linear layer.
+ use_bias (bool): If ``True``, enable bias parameter in the linear layer.
+ stride (int): Number of frames to merge for the output frame.
+ """
+
+ def __init__(self, input_dim: int, output_dim: int, use_bias: bool, stride: int):
+ super().__init__()
+ self.linear = torch.nn.Linear(input_dim, output_dim, bias=use_bias)
+ self.time_reduction = _TimeReduction(stride)
+
+ def forward(
+ self, input: torch.Tensor, lengths: Optional[torch.Tensor]
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+ """
+ Args:
+ input (torch.Tensor): The log-mel spectrogram input.
+ Tensor with dimensions `(batch, time, input_dim)`.
+ lengths (torch.Tensor or None): Valid length of each input sample.
+ Tensor with dimension `(batch, )`.
+
+ Returns:
+ (torch.Tensor, torch.Tensor or None):
+ torch.Tensor
+ Returned feature Tensor after linear layer and time reduction layer.
+ Tensor with dimensions `(batch, time // stride, output_dim)`.
+ torch.Tensor or None
+ The reduced lengths Tensor.
+ """
+ output = self.linear(input)
+ if lengths is None:
+ B, T, _ = input.shape
+ dummy_lengths = torch.full((B,), T)
+ output, _ = self.time_reduction(output, dummy_lengths)
+ else:
+ output, lengths = self.time_reduction(output, lengths)
+ return output, lengths
+
+
+class EmformerEncoder(torch.nn.Module):
+ """Emformer Encoder class for HuBERT pre-training. Consists of emformer module,
+ linear layer and layer normalization layer.
+
+ Args:
+ emformer (torch.nn.Module):
+ :py:class:`torchaudio.models.Emformer` module that consists of a list of emformer layers.
+ output_linear (torch.nn.Module):
+ Linear layer after emformer module.
+ layer_norm (torch.nn.Module):
+ Apply layer normalization to the output.
+ """
+
+ def __init__(
+ self,
+ emformer: torch.nn.Module,
+ output_linear: torch.nn.Module,
+ layer_norm: torch.nn.Module,
+ ):
+ super().__init__()
+ self.emformer = emformer
+ self.output_linear = output_linear
+ self.layer_norm = layer_norm
+
+ def forward(
+ self,
+ input: torch.Tensor,
+ lengths: Optional[torch.Tensor],
+ ) -> torch.Tensor:
+ """
+ Args:
+ input (torch.Tensor): The input feature for emformer encoder.
+ Tensor with dimensions `(batch, time, feature_dim)`.
+ lengths (torch.Tensor or None): Valid length of each input sample.
+ Tensor with dimension `(batch, )`.
+
+ Returns:
+ torch.Tensor: The feature Tensor after emformer encoder.
+ """
+ if lengths is None:
+ B, T, _ = input.shape
+ dummy_lengths = torch.full((B,), T)
+ output, _ = self.emformer(input, dummy_lengths)
+ else:
+ output, lengths = self.emformer(input, lengths)
+ output = self.output_linear(output)
+ output = self.layer_norm(output)
+ return output
+
+ def extract_features(
+ self,
+ input: torch.Tensor,
+ lengths: Optional[torch.Tensor],
+ num_layers: Optional[int] = None,
+ ) -> List[torch.Tensor]:
+ """Extract output Tensors of the emformer layers.
+
+ Args:
+ input (torch.Tensor): The input feature for emformer encoder.
+ Tensor with dimensions `(batch, time, feature_dim)`.
+ lengths (torch.Tensor or None): Valid length of each input sample.
+ Tensor with dimension `(batch, )`.
+ num_layers (int or None, optional): If not ``None``, returns the first
+ `num_layers` layers of Tensors as the output, otherwise returns the
+ Tensors from all emformer layers.
+
+ Returns:
+ List[torch.Tensor]:
+ Output Tensors of selected emformer layers.
+ """
+ if num_layers is not None:
+ if not 0 < num_layers <= len(self.emformer.emformer_layers):
+ raise ValueError(f"`num_layers` must be between [1, {len(self.emformer.emformer_layers)}]")
+
+ ret: List[torch.Tensor] = []
+
+ input = input.permute(1, 0, 2)
+ right_context = self.emformer._gen_right_context(input)
+ utterance = input[: input.size(0) - self.emformer.right_context_length]
+ attention_mask = self.emformer._gen_attention_mask(utterance)
+ mems = (
+ self.emformer.memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)[:-1]
+ if self.emformer.use_mem
+ else torch.empty(0).to(dtype=input.dtype, device=input.device)
+ )
+ output = utterance
+ if lengths is None:
+ # `input` was permuted to `(T, B, D)` above, so the batch size is the second dimension.
+ T, B, _ = input.shape
+ lengths = torch.full((B,), T)
+ for layer in self.emformer.emformer_layers:
+ output, right_context, mems = layer(output, lengths, right_context, mems, attention_mask)
+ ret.append(output.permute(1, 0, 2))
+ if num_layers is not None and len(ret) >= num_layers:
+ return ret
+ return ret
+
+
+def _get_emformer_feature_extractor(input_dim: int, output_dim: int, use_bias: bool, stride: int) -> FeatureEncoder:
+ """Construct FeatureEncoder for emformer model.
+
+ Args:
+ input_dim (int): The feature dimension of log-mel spectrogram feature.
+ output_dim (int): The feature dimension after linear layer.
+ use_bias (bool): If ``True``, enable bias parameter in the linear layer.
+ stride (int): Number of frames to merge for the output frame.
+
+ Returns:
+ FeatureEncoder: The resulting FeatureEncoder module.
+ """
+ return FeatureEncoder(input_dim, output_dim, use_bias, stride)
+
+
+def _get_emformer_encoder(
+ input_dim: int,
+ output_dim: int,
+ num_heads: int,
+ ffn_dim: int,
+ num_layers: int,
+ segment_length: int,
+ left_context_length: int,
+ right_context_length: int,
+ dropout: float,
+ activation: str,
+ max_memory_size: int,
+ weight_init_scale_strategy: Optional[str],
+ tanh_on_mem: bool,
+) -> EmformerEncoder:
+ """Construct EmformerEncoder for emformer model.
+
+ Args:
+ input_dim (int): The feature dimension of input Tensor.
+ output_dim (int): The feature dimension after EmformerEncoder.
+ num_heads (int): Number of attention heads in each Emformer layer.
+ ffn_dim (int): Hidden layer dimension of feedforward network.
+ num_layers (int): Number of Emformer layers to instantiate.
+ segment_length (int): Length of each input segment.
+ left_context_length (int): Length of left context.
+ right_context_length (int): Length of right context.
+ dropout (float): Dropout probability.
+ activation (str): Activation function to use in each Emformer layer's
+ feedforward network. Must be one of ("relu", "gelu", "silu").
+ max_memory_size (int): Maximum number of memory elements to use.
+ weight_init_scale_strategy (str or None): Per-layer weight initialization scaling
+ strategy. Must be one of ("depthwise", "constant", ``None``).
+ tanh_on_mem (bool): If ``True``, applies tanh to memory elements.
+
+ Returns:
+ EmformerEncoder: The resulting EmformerEncoder module.
+ """
+ emformer = Emformer(
+ input_dim=input_dim,
+ num_heads=num_heads,
+ ffn_dim=ffn_dim,
+ num_layers=num_layers,
+ segment_length=segment_length,
+ left_context_length=left_context_length,
+ right_context_length=right_context_length,
+ dropout=dropout,
+ activation=activation,
+ max_memory_size=max_memory_size,
+ weight_init_scale_strategy=weight_init_scale_strategy,
+ tanh_on_mem=tanh_on_mem,
+ )
+ output_linear = torch.nn.Linear(input_dim, output_dim)
+ layer_norm = torch.nn.LayerNorm(output_dim)
+ return EmformerEncoder(emformer, output_linear, layer_norm)
+
+
+def emformer_hubert_model(
+ extractor_input_dim: int,
+ extractor_output_dim: int,
+ extractor_use_bias: bool,
+ extractor_stride: int,
+ encoder_input_dim: int,
+ encoder_output_dim: int,
+ encoder_num_heads: int,
+ encoder_ffn_dim: int,
+ encoder_num_layers: int,
+ encoder_segment_length: int,
+ encoder_left_context_length: int,
+ encoder_right_context_length: int,
+ encoder_dropout: float,
+ encoder_activation: str,
+ encoder_max_memory_size: int,
+ encoder_weight_init_scale_strategy: Optional[str],
+ encoder_tanh_on_mem: bool,
+ aux_num_out: Optional[int],
+) -> Wav2Vec2Model:
+ """Build a custom Emformer HuBERT model.
+
+ Args:
+ extractor_input_dim (int): The input dimension for feature extractor.
+ extractor_output_dim (int): The output dimension after feature extractor.
+ extractor_use_bias (bool): If ``True``, enable bias parameter in the linear layer of feature extractor.
+ extractor_stride (int): Number of frames to merge for the output frame in feature extractor.
+ encoder_input_dim (int): The input dimension for Emformer layer.
+ encoder_output_dim (int): The output dimension after EmformerEncoder.
+ encoder_num_heads (int): Number of attention heads in each Emformer layer.
+ encoder_ffn_dim (int): Hidden layer dimension of feedforward network in Emformer.
+ encoder_num_layers (int): Number of Emformer layers to instantiate.
+ encoder_segment_length (int): Length of each input segment.
+ encoder_left_context_length (int): Length of left context.
+ encoder_right_context_length (int): Length of right context.
+ encoder_dropout (float): Dropout probability.
+ encoder_activation (str): Activation function to use in each Emformer layer's
+ feedforward network. Must be one of ("relu", "gelu", "silu").
+ encoder_max_memory_size (int): Maximum number of memory elements to use.
+ encoder_weight_init_scale_strategy (str or None): Per-layer weight initialization scaling
+ strategy. Must be one of ("depthwise", "constant", ``None``).
+ encoder_tanh_on_mem (bool): If ``True``, applies tanh to memory elements.
+ aux_num_out (int or None):
+ When provided, attach an extra linear layer on top of encoder, which can be
+ used for fine-tuning.
+
+ Returns:
+ Wav2Vec2Model:
+ The resulting :py:class:`torchaudio.models.Wav2Vec2Model` model
+ with a :py:class:`torchaudio.models.Emformer` encoder.
+ """
+ feature_extractor = _get_emformer_feature_extractor(
+ extractor_input_dim, extractor_output_dim, extractor_use_bias, extractor_stride
+ )
+ emformer = _get_emformer_encoder(
+ encoder_input_dim,
+ encoder_output_dim,
+ encoder_num_heads,
+ encoder_ffn_dim,
+ encoder_num_layers,
+ encoder_segment_length,
+ encoder_left_context_length,
+ encoder_right_context_length,
+ encoder_dropout,
+ encoder_activation,
+ encoder_max_memory_size,
+ encoder_weight_init_scale_strategy,
+ encoder_tanh_on_mem,
+ )
+ aux = None
+ if aux_num_out is not None:
+ aux = torch.nn.Linear(in_features=encoder_output_dim, out_features=aux_num_out)
+ return Wav2Vec2Model(feature_extractor, emformer, aux)
+
+
+def emformer_hubert_base(
+ extractor_input_dim: int = 80,
+ extractor_output_dim: int = 128,
+ encoder_dropout: float = 0.1,
+ aux_num_out: Optional[int] = None,
+) -> Wav2Vec2Model:
+ """Build Emformer HuBERT Model with 20 Emformer layers.
+
+ Args:
+ extractor_input_dim (int, optional): The input dimension for feature extractor. (Default: 80)
+ extractor_output_dim (int, optional): The output dimension after feature extractor. (Default: 128)
+ encoder_dropout (float, optional): Dropout probability in Emformer. (Default: 0.1)
+ aux_num_out (int or None, optional): Output dimension of aux layer for fine-tuning. (Default: ``None``)
+
+ Returns:
+ Wav2Vec2Model:
+ The resulting :py:class:`torchaudio.models.Wav2Vec2Model` model
+ with a :py:class:`torchaudio.models.Emformer` encoder.
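+
+ Example:
+ A minimal sketch of the assumed usage; the log-mel feature tensor is random
+ data used only to illustrate the expected `(batch, num_frames, feature)` shapes.
+
+ >>> model = emformer_hubert_base()
+ >>> features = torch.rand(2, 400, 80)  # (batch, num_frames, extractor_input_dim)
+ >>> lengths = torch.tensor([400, 360])
+ >>> outputs, output_lengths = model(features, lengths)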
+ """
+ return emformer_hubert_model(
+ extractor_input_dim=extractor_input_dim,
+ extractor_output_dim=extractor_output_dim,
+ extractor_use_bias=False,
+ extractor_stride=4,
+ encoder_input_dim=512,
+ encoder_output_dim=1024,
+ encoder_num_heads=8,
+ encoder_ffn_dim=2048,
+ encoder_num_layers=20,
+ encoder_segment_length=4,
+ encoder_left_context_length=30,
+ encoder_right_context_length=1,
+ encoder_dropout=encoder_dropout,
+ encoder_activation="gelu",
+ encoder_max_memory_size=0,
+ encoder_weight_init_scale_strategy="depthwise",
+ encoder_tanh_on_mem=True,
+ aux_num_out=aux_num_out,
+ )
diff --git a/MLPY/Lib/site-packages/torchaudio/prototype/models/conv_emformer.py b/MLPY/Lib/site-packages/torchaudio/prototype/models/conv_emformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..75a1e474c909d984e03edd44409a8a161747a637
--- /dev/null
+++ b/MLPY/Lib/site-packages/torchaudio/prototype/models/conv_emformer.py
@@ -0,0 +1,525 @@
+import math
+from typing import List, Optional, Tuple
+
+import torch
+from torchaudio.models.emformer import _EmformerAttention, _EmformerImpl, _get_weight_init_gains
+
+
+def _get_activation_module(activation: str) -> torch.nn.Module:
+ if activation == "relu":
+ return torch.nn.ReLU()
+ elif activation == "gelu":
+ return torch.nn.GELU()
+ elif activation == "silu":
+ return torch.nn.SiLU()
+ else:
+ raise ValueError(f"Unsupported activation {activation}")
+
+
+class _ResidualContainer(torch.nn.Module):
+ def __init__(self, module: torch.nn.Module, output_weight: float):
+ super().__init__()
+ self.module = module
+ self.output_weight = output_weight
+
+ def forward(self, input: torch.Tensor):
+ output = self.module(input)
+ return output * self.output_weight + input
+
+
+class _ConvolutionModule(torch.nn.Module):
+ def __init__(
+ self,
+ input_dim: int,
+ segment_length: int,
+ right_context_length: int,
+ kernel_size: int,
+ activation: str = "silu",
+ dropout: float = 0.0,
+ ):
+ super().__init__()
+ self.input_dim = input_dim
+ self.segment_length = segment_length
+ self.right_context_length = right_context_length
+ self.state_size = kernel_size - 1
+
+ self.pre_conv = torch.nn.Sequential(
+ torch.nn.LayerNorm(input_dim), torch.nn.Linear(input_dim, 2 * input_dim, bias=True), torch.nn.GLU()
+ )
+ self.conv = torch.nn.Conv1d(
+ in_channels=input_dim,
+ out_channels=input_dim,
+ kernel_size=kernel_size,
+ stride=1,
+ padding=0,
+ groups=input_dim,
+ )
+ self.post_conv = torch.nn.Sequential(
+ torch.nn.LayerNorm(input_dim),
+ _get_activation_module(activation),
+ torch.nn.Linear(input_dim, input_dim, bias=True),
+ torch.nn.Dropout(p=dropout),
+ )
+
+ def _split_right_context(self, utterance: torch.Tensor, right_context: torch.Tensor) -> torch.Tensor:
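+ # For each right-context block, prepend the last (kernel_size - 1) frames of the
+ # utterance segment it follows, so the depthwise convolution sees the same left
+ # context it would see when run over the full sequence.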
+ T, B, D = right_context.size()
+ if T % self.right_context_length != 0:
+ raise ValueError("Tensor length should be divisible by its right context length")
+ num_segments = T // self.right_context_length
+ # (num_segments, right context length, B, D)
+ right_context_segments = right_context.reshape(num_segments, self.right_context_length, B, D)
+ right_context_segments = right_context_segments.permute(0, 2, 1, 3).reshape(
+ num_segments * B, self.right_context_length, D
+ )
+
+ pad_segments = [] # [(kernel_size - 1, B, D), ...]
+ for seg_idx in range(num_segments):
+ end_idx = min(self.state_size + (seg_idx + 1) * self.segment_length, utterance.size(0))
+ start_idx = end_idx - self.state_size
+ pad_segments.append(utterance[start_idx:end_idx, :, :])
+
+ pad_segments = torch.cat(pad_segments, dim=1).permute(1, 0, 2) # (num_segments * B, kernel_size - 1, D)
+ return torch.cat([pad_segments, right_context_segments], dim=1).permute(0, 2, 1)
+
+ def _merge_right_context(self, right_context: torch.Tensor, B: int) -> torch.Tensor:
+ # (num_segments * B, D, right_context_length)
+ right_context = right_context.reshape(-1, B, self.input_dim, self.right_context_length)
+ right_context = right_context.permute(0, 3, 1, 2)
+ return right_context.reshape(-1, B, self.input_dim) # (right_context_length * num_segments, B, D)
+
+ def forward(
+ self, utterance: torch.Tensor, right_context: torch.Tensor, state: Optional[torch.Tensor]
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ input = torch.cat((right_context, utterance)) # input: (T, B, D)
+ x = self.pre_conv(input)
+ x_right_context, x_utterance = x[: right_context.size(0), :, :], x[right_context.size(0) :, :, :]
+ x_utterance = x_utterance.permute(1, 2, 0) # (B, D, T_utterance)
+
+ if state is None:
+ state = torch.zeros(
+ input.size(1),
+ input.size(2),
+ self.state_size,
+ device=input.device,
+ dtype=input.dtype,
+ ) # (B, D, T)
+ state_x_utterance = torch.cat([state, x_utterance], dim=2)
+
+ conv_utterance = self.conv(state_x_utterance) # (B, D, T_utterance)
+ conv_utterance = conv_utterance.permute(2, 0, 1)
+
+ if self.right_context_length > 0:
+ # (B * num_segments, D, right_context_length + kernel_size - 1)
+ right_context_block = self._split_right_context(state_x_utterance.permute(2, 0, 1), x_right_context)
+ conv_right_context_block = self.conv(right_context_block) # (B * num_segments, D, right_context_length)
+ # (T_right_context, B, D)
+ conv_right_context = self._merge_right_context(conv_right_context_block, input.size(1))
+ y = torch.cat([conv_right_context, conv_utterance], dim=0)
+ else:
+ y = conv_utterance
+
+ output = self.post_conv(y) + input
+ new_state = state_x_utterance[:, :, -self.state_size :]
+ return output[right_context.size(0) :], output[: right_context.size(0)], new_state
+
+ def infer(
+ self, utterance: torch.Tensor, right_context: torch.Tensor, state: Optional[torch.Tensor]
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ input = torch.cat((utterance, right_context))
+ x = self.pre_conv(input) # (T, B, D)
+ x = x.permute(1, 2, 0) # (B, D, T)
+
+ if state is None:
+ state = torch.zeros(
+ input.size(1),
+ input.size(2),
+ self.state_size,
+ device=input.device,
+ dtype=input.dtype,
+ ) # (B, D, T)
+ state_x = torch.cat([state, x], dim=2)
+ conv_out = self.conv(state_x)
+ conv_out = conv_out.permute(2, 0, 1) # T, B, D
+ output = self.post_conv(conv_out) + input
+ new_state = state_x[:, :, -self.state_size - right_context.size(0) : -right_context.size(0)]
+ return output[: utterance.size(0)], output[utterance.size(0) :], new_state
+
+
+class _ConvEmformerLayer(torch.nn.Module):
+ r"""Convolution-augmented Emformer layer that constitutes ConvEmformer.
+
+ Args:
+ input_dim (int): input dimension.
+ num_heads (int): number of attention heads.
+ ffn_dim: (int): hidden layer dimension of feedforward network.
+ segment_length (int): length of each input segment.
+ kernel_size (int): size of kernel to use in convolution module.
+ dropout (float, optional): dropout probability. (Default: 0.0)
+ ffn_activation (str, optional): activation function to use in feedforward network.
+ Must be one of ("relu", "gelu", "silu"). (Default: "relu")
+ left_context_length (int, optional): length of left context. (Default: 0)
+ right_context_length (int, optional): length of right context. (Default: 0)
+ max_memory_size (int, optional): maximum number of memory elements to use. (Default: 0)
+ weight_init_gain (float or None, optional): scale factor to apply when initializing
+ attention module parameters. (Default: ``None``)
+ tanh_on_mem (bool, optional): if ``True``, applies tanh to memory elements. (Default: ``False``)
+ negative_inf (float, optional): value to use for negative infinity in attention weights. (Default: -1e8)
+ conv_activation (str, optional): activation function to use in convolution module.
+ Must be one of ("relu", "gelu", "silu"). (Default: "silu")
+ """
+
+ def __init__(
+ self,
+ input_dim: int,
+ num_heads: int,
+ ffn_dim: int,
+ segment_length: int,
+ kernel_size: int,
+ dropout: float = 0.0,
+ ffn_activation: str = "relu",
+ left_context_length: int = 0,
+ right_context_length: int = 0,
+ max_memory_size: int = 0,
+ weight_init_gain: Optional[float] = None,
+ tanh_on_mem: bool = False,
+ negative_inf: float = -1e8,
+ conv_activation: str = "silu",
+ ):
+ super().__init__()
+ # TODO: implement talking heads attention.
+ self.attention = _EmformerAttention(
+ input_dim=input_dim,
+ num_heads=num_heads,
+ dropout=dropout,
+ weight_init_gain=weight_init_gain,
+ tanh_on_mem=tanh_on_mem,
+ negative_inf=negative_inf,
+ )
+ self.dropout = torch.nn.Dropout(dropout)
+ self.memory_op = torch.nn.AvgPool1d(kernel_size=segment_length, stride=segment_length, ceil_mode=True)
+
+ activation_module = _get_activation_module(ffn_activation)
+ self.ffn0 = _ResidualContainer(
+ torch.nn.Sequential(
+ torch.nn.LayerNorm(input_dim),
+ torch.nn.Linear(input_dim, ffn_dim),
+ activation_module,
+ torch.nn.Dropout(dropout),
+ torch.nn.Linear(ffn_dim, input_dim),
+ torch.nn.Dropout(dropout),
+ ),
+ 0.5,
+ )
+ self.ffn1 = _ResidualContainer(
+ torch.nn.Sequential(
+ torch.nn.LayerNorm(input_dim),
+ torch.nn.Linear(input_dim, ffn_dim),
+ activation_module,
+ torch.nn.Dropout(dropout),
+ torch.nn.Linear(ffn_dim, input_dim),
+ torch.nn.Dropout(dropout),
+ ),
+ 0.5,
+ )
+ self.layer_norm_input = torch.nn.LayerNorm(input_dim)
+ self.layer_norm_output = torch.nn.LayerNorm(input_dim)
+
+ self.conv = _ConvolutionModule(
+ input_dim=input_dim,
+ kernel_size=kernel_size,
+ activation=conv_activation,
+ dropout=dropout,
+ segment_length=segment_length,
+ right_context_length=right_context_length,
+ )
+
+ self.left_context_length = left_context_length
+ self.segment_length = segment_length
+ self.max_memory_size = max_memory_size
+ self.input_dim = input_dim
+ self.kernel_size = kernel_size
+ self.use_mem = max_memory_size > 0
+
+ def _init_state(self, batch_size: int, device: Optional[torch.device]) -> List[torch.Tensor]:
+ empty_memory = torch.zeros(self.max_memory_size, batch_size, self.input_dim, device=device)
+ left_context_key = torch.zeros(self.left_context_length, batch_size, self.input_dim, device=device)
+ left_context_val = torch.zeros(self.left_context_length, batch_size, self.input_dim, device=device)
+ past_length = torch.zeros(1, batch_size, dtype=torch.int32, device=device)
+ conv_cache = torch.zeros(
+ batch_size,
+ self.input_dim,
+ self.kernel_size - 1,
+ device=device,
+ )
+ return [empty_memory, left_context_key, left_context_val, past_length, conv_cache]
+
+ def _unpack_state(self, state: List[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
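+ # Trim the fixed-size memory / left-context caches down to the portion that has
+ # been filled so far, as tracked by the running ``past_length`` counter.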
+ past_length = state[3][0][0].item()
+ past_left_context_length = min(self.left_context_length, past_length)
+ past_mem_length = min(self.max_memory_size, math.ceil(past_length / self.segment_length))
+ pre_mems = state[0][self.max_memory_size - past_mem_length :]
+ lc_key = state[1][self.left_context_length - past_left_context_length :]
+ lc_val = state[2][self.left_context_length - past_left_context_length :]
+ conv_cache = state[4]
+ return pre_mems, lc_key, lc_val, conv_cache
+
+ def _pack_state(
+ self,
+ next_k: torch.Tensor,
+ next_v: torch.Tensor,
+ update_length: int,
+ mems: torch.Tensor,
+ conv_cache: torch.Tensor,
+ state: List[torch.Tensor],
+ ) -> List[torch.Tensor]:
+ new_k = torch.cat([state[1], next_k])
+ new_v = torch.cat([state[2], next_v])
+ state[0] = torch.cat([state[0], mems])[-self.max_memory_size :]
+ state[1] = new_k[new_k.shape[0] - self.left_context_length :]
+ state[2] = new_v[new_v.shape[0] - self.left_context_length :]
+ state[3] = state[3] + update_length
+ state[4] = conv_cache
+ return state
+
+ def _apply_pre_attention(
+ self, utterance: torch.Tensor, right_context: torch.Tensor, summary: torch.Tensor
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ x = torch.cat([right_context, utterance, summary])
+ ffn0_out = self.ffn0(x)
+ layer_norm_input_out = self.layer_norm_input(ffn0_out)
+ layer_norm_input_right_context, layer_norm_input_utterance, layer_norm_input_summary = (
+ layer_norm_input_out[: right_context.size(0)],
+ layer_norm_input_out[right_context.size(0) : right_context.size(0) + utterance.size(0)],
+ layer_norm_input_out[right_context.size(0) + utterance.size(0) :],
+ )
+ return ffn0_out, layer_norm_input_right_context, layer_norm_input_utterance, layer_norm_input_summary
+
+ def _apply_post_attention(
+ self,
+ rc_output: torch.Tensor,
+ ffn0_out: torch.Tensor,
+ conv_cache: Optional[torch.Tensor],
+ rc_length: int,
+ utterance_length: int,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ result = self.dropout(rc_output) + ffn0_out[: rc_length + utterance_length]
+ conv_utterance, conv_right_context, conv_cache = self.conv(result[rc_length:], result[:rc_length], conv_cache)
+ result = torch.cat([conv_right_context, conv_utterance])
+ result = self.ffn1(result)
+ result = self.layer_norm_output(result)
+ output_utterance, output_right_context = result[rc_length:], result[:rc_length]
+ return output_utterance, output_right_context, conv_cache
+
+ def forward(
+ self,
+ utterance: torch.Tensor,
+ lengths: torch.Tensor,
+ right_context: torch.Tensor,
+ mems: torch.Tensor,
+ attention_mask: torch.Tensor,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ r"""Forward pass for training.
+
+ B: batch size;
+ D: feature dimension of each frame;
+ T: number of utterance frames;
+ R: number of right context frames;
+ M: number of memory elements.
+
+ Args:
+ utterance (torch.Tensor): utterance frames, with shape `(T, B, D)`.
+ lengths (torch.Tensor): with shape `(B,)` and i-th element representing
+ number of valid frames for i-th batch element in ``utterance``.
+ right_context (torch.Tensor): right context frames, with shape `(R, B, D)`.
+ mems (torch.Tensor): memory elements, with shape `(M, B, D)`.
+ attention_mask (torch.Tensor): attention mask for underlying attention module.
+
+ Returns:
+ (Tensor, Tensor, Tensor):
+ Tensor
+ encoded utterance frames, with shape `(T, B, D)`.
+ Tensor
+ updated right context frames, with shape `(R, B, D)`.
+ Tensor
+ updated memory elements, with shape `(M, B, D)`.
+ """
+ if self.use_mem:
+ summary = self.memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)
+ else:
+ summary = torch.empty(0).to(dtype=utterance.dtype, device=utterance.device)
+
+ (
+ ffn0_out,
+ layer_norm_input_right_context,
+ layer_norm_input_utterance,
+ layer_norm_input_summary,
+ ) = self._apply_pre_attention(utterance, right_context, summary)
+
+ rc_output, output_mems = self.attention(
+ utterance=layer_norm_input_utterance,
+ lengths=lengths,
+ right_context=layer_norm_input_right_context,
+ summary=layer_norm_input_summary,
+ mems=mems,
+ attention_mask=attention_mask,
+ )
+
+ output_utterance, output_right_context, _ = self._apply_post_attention(
+ rc_output, ffn0_out, None, right_context.size(0), utterance.size(0)
+ )
+
+ return output_utterance, output_right_context, output_mems
+
+ @torch.jit.export
+ def infer(
+ self,
+ utterance: torch.Tensor,
+ lengths: torch.Tensor,
+ right_context: torch.Tensor,
+ state: Optional[List[torch.Tensor]],
+ mems: torch.Tensor,
+ ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor], torch.Tensor]:
+ r"""Forward pass for inference.
+
+ B: batch size;
+ D: feature dimension of each frame;
+ T: number of utterance frames;
+ R: number of right context frames;
+ M: number of memory elements.
+
+ Args:
+ utterance (torch.Tensor): utterance frames, with shape `(T, B, D)`.
+ lengths (torch.Tensor): with shape `(B,)` and i-th element representing
+ number of valid frames for i-th batch element in ``utterance``.
+ right_context (torch.Tensor): right context frames, with shape `(R, B, D)`.
+ state (List[torch.Tensor] or None): list of tensors representing layer internal state
+ generated in preceding invocation of ``infer``.
+ mems (torch.Tensor): memory elements, with shape `(M, B, D)`.
+
+ Returns:
+ (Tensor, Tensor, List[torch.Tensor], Tensor):
+ Tensor
+ encoded utterance frames, with shape `(T, B, D)`.
+ Tensor
+ updated right context frames, with shape `(R, B, D)`.
+ List[Tensor]
+ list of tensors representing layer internal state
+ generated in current invocation of ``infer``.
+ Tensor
+ updated memory elements, with shape `(M, B, D)`.
+ """
+ if self.use_mem:
+ summary = self.memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)[:1]
+ else:
+ summary = torch.empty(0).to(dtype=utterance.dtype, device=utterance.device)
+
+ (
+ ffn0_out,
+ layer_norm_input_right_context,
+ layer_norm_input_utterance,
+ layer_norm_input_summary,
+ ) = self._apply_pre_attention(utterance, right_context, summary)
+
+ if state is None:
+ state = self._init_state(layer_norm_input_utterance.size(1), device=layer_norm_input_utterance.device)
+ pre_mems, lc_key, lc_val, conv_cache = self._unpack_state(state)
+
+ rc_output, next_m, next_k, next_v = self.attention.infer(
+ utterance=layer_norm_input_utterance,
+ lengths=lengths,
+ right_context=layer_norm_input_right_context,
+ summary=layer_norm_input_summary,
+ mems=pre_mems,
+ left_context_key=lc_key,
+ left_context_val=lc_val,
+ )
+
+ output_utterance, output_right_context, conv_cache = self._apply_post_attention(
+ rc_output, ffn0_out, conv_cache, right_context.size(0), utterance.size(0)
+ )
+ output_state = self._pack_state(next_k, next_v, utterance.size(0), mems, conv_cache, state)
+ return output_utterance, output_right_context, output_state, next_m
+
+
+class ConvEmformer(_EmformerImpl):
+ r"""Implements the convolution-augmented streaming transformer architecture introduced in
+ *Streaming Transformer Transducer based Speech Recognition Using Non-Causal Convolution*
+ :cite:`9747706`.
+
+ Args:
+ input_dim (int): input dimension.
+ num_heads (int): number of attention heads in each ConvEmformer layer.
+ ffn_dim (int): hidden layer dimension of each ConvEmformer layer's feedforward network.
+ num_layers (int): number of ConvEmformer layers to instantiate.
+ segment_length (int): length of each input segment.
+ kernel_size (int): size of kernel to use in convolution modules.
+ dropout (float, optional): dropout probability. (Default: 0.0)
+ ffn_activation (str, optional): activation function to use in feedforward networks.
+ Must be one of ("relu", "gelu", "silu"). (Default: "relu")
+ left_context_length (int, optional): length of left context. (Default: 0)
+ right_context_length (int, optional): length of right context. (Default: 0)
+ max_memory_size (int, optional): maximum number of memory elements to use. (Default: 0)
+ weight_init_scale_strategy (str or None, optional): per-layer weight initialization scaling
+ strategy. Must be one of ("depthwise", "constant", ``None``). (Default: "depthwise")
+ tanh_on_mem (bool, optional): if ``True``, applies tanh to memory elements. (Default: ``False``)
+ negative_inf (float, optional): value to use for negative infinity in attention weights. (Default: -1e8)
+ conv_activation (str, optional): activation function to use in convolution modules.
+ Must be one of ("relu", "gelu", "silu"). (Default: "silu")
+
+ Examples:
+ >>> conv_emformer = ConvEmformer(80, 4, 1024, 12, 16, 8, right_context_length=4)
+ >>> input = torch.rand(10, 200, 80)
+ >>> lengths = torch.randint(1, 200, (10,))
+ >>> output, lengths = conv_emformer(input, lengths)
+ >>> input = torch.rand(4, 20, 80)
+ >>> lengths = torch.ones(4) * 20
+ >>> output, lengths, states = conv_emformer.infer(input, lengths, None)
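+
+ The following continues the example above and sketches the assumed streaming
+ pattern: ``infer`` is fed chunks of ``segment_length + right_context_length``
+ (= 16 + 4) frames, and the returned ``states`` are passed back on the next call.
+
+ >>> states = None
+ >>> for chunk in torch.rand(5, 4, 20, 80):
+ ... output, lengths, states = conv_emformer.infer(chunk, torch.full((4,), 20), states)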
+ """
+
+ def __init__(
+ self,
+ input_dim: int,
+ num_heads: int,
+ ffn_dim: int,
+ num_layers: int,
+ segment_length: int,
+ kernel_size: int,
+ dropout: float = 0.0,
+ ffn_activation: str = "relu",
+ left_context_length: int = 0,
+ right_context_length: int = 0,
+ max_memory_size: int = 0,
+ weight_init_scale_strategy: Optional[str] = "depthwise",
+ tanh_on_mem: bool = False,
+ negative_inf: float = -1e8,
+ conv_activation: str = "silu",
+ ):
+ weight_init_gains = _get_weight_init_gains(weight_init_scale_strategy, num_layers)
+ emformer_layers = torch.nn.ModuleList(
+ [
+ _ConvEmformerLayer(
+ input_dim,
+ num_heads,
+ ffn_dim,
+ segment_length,
+ kernel_size,
+ dropout=dropout,
+ ffn_activation=ffn_activation,
+ left_context_length=left_context_length,
+ right_context_length=right_context_length,
+ max_memory_size=max_memory_size,
+ weight_init_gain=weight_init_gains[layer_idx],
+ tanh_on_mem=tanh_on_mem,
+ negative_inf=negative_inf,
+ conv_activation=conv_activation,
+ )
+ for layer_idx in range(num_layers)
+ ]
+ )
+ super().__init__(
+ emformer_layers,
+ segment_length,
+ left_context_length=left_context_length,
+ right_context_length=right_context_length,
+ max_memory_size=max_memory_size,
+ )
diff --git a/MLPY/Lib/site-packages/torchaudio/prototype/models/hifi_gan.py b/MLPY/Lib/site-packages/torchaudio/prototype/models/hifi_gan.py
new file mode 100644
index 0000000000000000000000000000000000000000..1db30eaec0345deba321b17bbeea331a793963e3
--- /dev/null
+++ b/MLPY/Lib/site-packages/torchaudio/prototype/models/hifi_gan.py
@@ -0,0 +1,336 @@
+"""
+MIT License
+
+Copyright (c) 2020 Jungil Kong
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
+"""
+
+from typing import Tuple
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn import Conv1d, ConvTranspose1d
+
+
+class HiFiGANVocoder(torch.nn.Module):
+ """Generator part of *HiFi GAN* :cite:`NEURIPS2020_c5d73680`.
+ Source: https://github.com/jik876/hifi-gan/blob/4769534d45265d52a904b850da5a622601885777/models.py#L75
+
+ Note:
+ To build the model, please use one of the factory functions: :py:func:`hifigan_vocoder`,
+ :py:func:`hifigan_vocoder_v1`, :py:func:`hifigan_vocoder_v2`, :py:func:`hifigan_vocoder_v3`.
+
+ Args:
+ in_channels (int): Number of channels in the input features.
+ upsample_rates (tuple of ``int``): Factors by which each upsampling layer increases the time dimension.
+ upsample_initial_channel (int): Number of channels produced by the initial convolution;
+ each subsequent upsampling layer halves this channel count.
+ upsample_kernel_sizes (tuple of ``int``): Kernel size for each upsampling layer.
+ resblock_kernel_sizes (tuple of ``int``): Kernel size for each residual block.
+ resblock_dilation_sizes (tuple of tuples of ``int``): Dilation sizes for each 1D convolutional layer in each
+ residual block. For resblock type 1 inner tuples should have length 3, because there are 3
+ convolutions in each layer. For resblock type 2 they should have length 2.
+ resblock_type (int, 1 or 2): Determines whether ``ResBlock1`` or ``ResBlock2`` will be used.
+ lrelu_slope (float): Slope of leaky ReLUs in activations.
+ """
+
+ def __init__(
+ self,
+ in_channels: int,
+ upsample_rates: Tuple[int, ...],
+ upsample_initial_channel: int,
+ upsample_kernel_sizes: Tuple[int, ...],
+ resblock_kernel_sizes: Tuple[int, ...],
+ resblock_dilation_sizes: Tuple[Tuple[int, ...], ...],
+ resblock_type: int,
+ lrelu_slope: float,
+ ):
+ super(HiFiGANVocoder, self).__init__()
+ self.num_kernels = len(resblock_kernel_sizes)
+ self.num_upsamples = len(upsample_rates)
+ self.conv_pre = Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3)
+ resblock = ResBlock1 if resblock_type == 1 else ResBlock2
+
+ self.ups = nn.ModuleList()
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
+ self.ups.append(
+ ConvTranspose1d(
+ upsample_initial_channel // (2**i),
+ upsample_initial_channel // (2 ** (i + 1)),
+ k,
+ u,
+ padding=(k - u) // 2,
+ )
+ )
+
+ self.resblocks = nn.ModuleList()
+ for i in range(len(self.ups)):
+ ch = upsample_initial_channel // (2 ** (i + 1))
+ for (k, d) in zip(resblock_kernel_sizes, resblock_dilation_sizes):
+ self.resblocks.append(resblock(ch, k, d, lrelu_slope))
+
+ self.conv_post = Conv1d(ch, 1, 7, 1, padding=3)
+ self.lrelu_slope = lrelu_slope
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+ x (Tensor): Feature input tensor of shape `(batch_size, num_channels, time_length)`.
+
+ Returns:
+ Tensor of shape `(batch_size, 1, time_length * upsample_rate)`, where `upsample_rate` is the product
+ of upsample rates for all layers.
+ """
+ x = self.conv_pre(x)
+ for i, upsampling_layer in enumerate(self.ups):
+ x = F.leaky_relu(x, self.lrelu_slope)
+ x = upsampling_layer(x)
+ xs = torch.zeros_like(x)
+ for j in range(self.num_kernels):
+ res_block: ResBlockInterface = self.resblocks[i * self.num_kernels + j]
+ xs += res_block.forward(x)
+ x = xs / self.num_kernels
+
+ x = F.leaky_relu(x)
+ x = self.conv_post(x)
+ x = torch.tanh(x)
+
+ return x
+
+
+@torch.jit.interface
+class ResBlockInterface(torch.nn.Module):
+ """Interface for ResBlock - necessary to make type annotations in ``HiFiGANVocoder.forward`` compatible
+ with TorchScript
+ """
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ pass
+
+
+class ResBlock1(torch.nn.Module):
+ """Residual block of type 1 for HiFiGAN Vocoder :cite:`NEURIPS2020_c5d73680`.
+ Args:
+ channels (int): Number of channels in the input features.
+ kernel_size (int, optional): Kernel size for 1D convolutions. (Default: ``3``)
+ dilation (tuple of 3 ``int``, optional): Dilations for each 1D convolution. (Default: ``(1, 3, 5)``)
+ lrelu_slope (float): Slope of leaky ReLUs in activations.
+ """
+
+ def __init__(
+ self, channels: int, kernel_size: int = 3, dilation: Tuple[int, int, int] = (1, 3, 5), lrelu_slope: float = 0.1
+ ):
+ super(ResBlock1, self).__init__()
+ self.convs1 = nn.ModuleList(
+ [
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=dilation[0],
+ padding=get_padding(kernel_size, dilation[0]),
+ ),
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=dilation[1],
+ padding=get_padding(kernel_size, dilation[1]),
+ ),
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=dilation[2],
+ padding=get_padding(kernel_size, dilation[2]),
+ ),
+ ]
+ )
+
+ self.convs2 = nn.ModuleList(
+ [
+ Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1)),
+ Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1)),
+ Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1)),
+ ]
+ )
+ self.lrelu_slope = lrelu_slope
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+ x (Tensor): input of shape ``(batch_size, channels, time_length)``.
+ Returns:
+ Tensor of the same shape as input.
+ """
+ for conv1, conv2 in zip(self.convs1, self.convs2):
+ xt = F.leaky_relu(x, self.lrelu_slope)
+ xt = conv1(xt)
+ xt = F.leaky_relu(xt, self.lrelu_slope)
+ xt = conv2(xt)
+ x = xt + x
+ return x
+
+
+class ResBlock2(torch.nn.Module):
+ """Residual block of type 2 for HiFiGAN Vocoder :cite:`NEURIPS2020_c5d73680`.
+ Args:
+ channels (int): Number of channels in the input features.
+ kernel_size (int, optional): Kernel size for 1D convolutions. (Default: ``3``)
+ dilation (tuple of 2 ``int``, optional): Dilations for each 1D convolution. (Default: ``(1, 3)``)
+ lrelu_slope (float): Slope of leaky ReLUs in activations.
+ """
+
+ def __init__(
+ self, channels: int, kernel_size: int = 3, dilation: Tuple[int, int] = (1, 3), lrelu_slope: float = 0.1
+ ):
+ super(ResBlock2, self).__init__()
+ self.convs = nn.ModuleList(
+ [
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=dilation[0],
+ padding=get_padding(kernel_size, dilation[0]),
+ ),
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=dilation[1],
+ padding=get_padding(kernel_size, dilation[1]),
+ ),
+ ]
+ )
+ self.lrelu_slope = lrelu_slope
+
+ def forward(self, x: torch.Tensor):
+ """
+ Args:
+ x (Tensor): input of shape ``(batch_size, channels, time_length)``.
+ Returns:
+ Tensor of the same shape as input.
+ """
+ for c in self.convs:
+ xt = F.leaky_relu(x, self.lrelu_slope)
+ xt = c(xt)
+ x = xt + x
+ return x
+
+
+def get_padding(kernel_size, dilation=1):
+ """Find padding for which 1D convolution preserves the input shape."""
+ return int((kernel_size * dilation - dilation) / 2)
+
+
+def hifigan_vocoder(
+ in_channels: int,
+ upsample_rates: Tuple[int, ...],
+ upsample_initial_channel: int,
+ upsample_kernel_sizes: Tuple[int, ...],
+ resblock_kernel_sizes: Tuple[int, ...],
+ resblock_dilation_sizes: Tuple[Tuple[int, ...], ...],
+ resblock_type: int,
+ lrelu_slope: float,
+) -> HiFiGANVocoder:
+ r"""Builds HiFi GAN Vocoder :cite:`NEURIPS2020_c5d73680`.
+
+ Args:
+ in_channels (int): See :py:class:`HiFiGANVocoder`.
+ upsample_rates (tuple of ``int``): See :py:class:`HiFiGANVocoder`.
+ upsample_initial_channel (int): See :py:class:`HiFiGANVocoder`.
+ upsample_kernel_sizes (tuple of ``int``): See :py:class:`HiFiGANVocoder`.
+ resblock_kernel_sizes (tuple of ``int``): See :py:class:`HiFiGANVocoder`.
+ resblock_dilation_sizes (tuple of tuples of ``int``): See :py:class:`HiFiGANVocoder`.
+ resblock_type (int, 1 or 2): See :py:class:`HiFiGANVocoder`.
+ lrelu_slope (float): See :py:class:`HiFiGANVocoder`.
+ Returns:
+ HiFiGANVocoder: generated model.
+ """
+
+ return HiFiGANVocoder(
+ upsample_rates=upsample_rates,
+ resblock_kernel_sizes=resblock_kernel_sizes,
+ resblock_dilation_sizes=resblock_dilation_sizes,
+ resblock_type=resblock_type,
+ upsample_initial_channel=upsample_initial_channel,
+ upsample_kernel_sizes=upsample_kernel_sizes,
+ in_channels=in_channels,
+ lrelu_slope=lrelu_slope,
+ )
+
+
+def hifigan_vocoder_v1() -> HiFiGANVocoder:
+ r"""Builds HiFiGAN Vocoder with V1 architecture :cite:`NEURIPS2020_c5d73680`.
+
+ Returns:
+ HiFiGANVocoder: generated model.
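+
+ Example:
+ A minimal sketch; the mel spectrogram is random data used only to show the
+ expected `(batch, in_channels, time)` input and the overall upsampling factor
+ of 8 * 8 * 2 * 2 = 256.
+
+ >>> vocoder = hifigan_vocoder_v1()
+ >>> mel = torch.rand(1, 80, 100)
+ >>> waveform = vocoder(mel)  # (1, 1, 25600)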
+ """
+ return hifigan_vocoder(
+ upsample_rates=(8, 8, 2, 2),
+ upsample_kernel_sizes=(16, 16, 4, 4),
+ upsample_initial_channel=512,
+ resblock_kernel_sizes=(3, 7, 11),
+ resblock_dilation_sizes=((1, 3, 5), (1, 3, 5), (1, 3, 5)),
+ resblock_type=1,
+ in_channels=80,
+ lrelu_slope=0.1,
+ )
+
+
+def hifigan_vocoder_v2() -> HiFiGANVocoder:
+ r"""Builds HiFiGAN Vocoder with V2 architecture :cite:`NEURIPS2020_c5d73680`.
+
+ Returns:
+ HiFiGANVocoder: generated model.
+ """
+ return hifigan_vocoder(
+ upsample_rates=(8, 8, 2, 2),
+ upsample_kernel_sizes=(16, 16, 4, 4),
+ upsample_initial_channel=128,
+ resblock_kernel_sizes=(3, 7, 11),
+ resblock_dilation_sizes=((1, 3, 5), (1, 3, 5), (1, 3, 5)),
+ resblock_type=1,
+ in_channels=80,
+ lrelu_slope=0.1,
+ )
+
+
+def hifigan_vocoder_v3() -> HiFiGANVocoder:
+ r"""Builds HiFiGAN Vocoder with V3 architecture :cite:`NEURIPS2020_c5d73680`.
+
+ Returns:
+ HiFiGANVocoder: generated model.
+ """
+ return hifigan_vocoder(
+ upsample_rates=(8, 8, 4),
+ upsample_kernel_sizes=(16, 16, 8),
+ upsample_initial_channel=256,
+ resblock_kernel_sizes=(3, 5, 7),
+ resblock_dilation_sizes=((1, 2), (2, 6), (3, 12)),
+ resblock_type=2,
+ in_channels=80,
+ lrelu_slope=0.1,
+ )
diff --git a/MLPY/Lib/site-packages/torchaudio/prototype/models/rnnt.py b/MLPY/Lib/site-packages/torchaudio/prototype/models/rnnt.py
new file mode 100644
index 0000000000000000000000000000000000000000..18a620f76052fb38024641d62598df062e1a94ab
--- /dev/null
+++ b/MLPY/Lib/site-packages/torchaudio/prototype/models/rnnt.py
@@ -0,0 +1,711 @@
+import math
+from typing import Dict, List, Optional, Tuple
+
+import torch
+from torchaudio.models import Conformer, RNNT
+from torchaudio.models.rnnt import _Joiner, _Predictor, _TimeReduction, _Transcriber
+
+
+TrieNode = Tuple[Dict[int, "TrieNode"], int, Optional[Tuple[int, int]]]
+
+
+class _ConformerEncoder(torch.nn.Module, _Transcriber):
+ def __init__(
+ self,
+ *,
+ input_dim: int,
+ output_dim: int,
+ time_reduction_stride: int,
+ conformer_input_dim: int,
+ conformer_ffn_dim: int,
+ conformer_num_layers: int,
+ conformer_num_heads: int,
+ conformer_depthwise_conv_kernel_size: int,
+ conformer_dropout: float,
+ ) -> None:
+ super().__init__()
+ self.time_reduction = _TimeReduction(time_reduction_stride)
+ self.input_linear = torch.nn.Linear(input_dim * time_reduction_stride, conformer_input_dim)
+ self.conformer = Conformer(
+ num_layers=conformer_num_layers,
+ input_dim=conformer_input_dim,
+ ffn_dim=conformer_ffn_dim,
+ num_heads=conformer_num_heads,
+ depthwise_conv_kernel_size=conformer_depthwise_conv_kernel_size,
+ dropout=conformer_dropout,
+ use_group_norm=True,
+ convolution_first=True,
+ )
+ self.output_linear = torch.nn.Linear(conformer_input_dim, output_dim)
+ self.layer_norm = torch.nn.LayerNorm(output_dim)
+
+ def forward(self, input: torch.Tensor, lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ time_reduction_out, time_reduction_lengths = self.time_reduction(input, lengths)
+ input_linear_out = self.input_linear(time_reduction_out)
+ x, lengths = self.conformer(input_linear_out, time_reduction_lengths)
+ output_linear_out = self.output_linear(x)
+ layer_norm_out = self.layer_norm(output_linear_out)
+ return layer_norm_out, lengths
+
+ def infer(
+ self,
+ input: torch.Tensor,
+ lengths: torch.Tensor,
+ states: Optional[List[List[torch.Tensor]]],
+ ) -> Tuple[torch.Tensor, torch.Tensor, List[List[torch.Tensor]]]:
+ raise RuntimeError("Conformer does not support streaming inference.")
+
+
+class _JoinerBiasing(torch.nn.Module):
+ r"""Recurrent neural network transducer (RNN-T) joint network.
+
+ Args:
+ input_dim (int): source and target input dimension.
+ output_dim (int): output dimension.
+ activation (str, optional): activation function to use in the joiner.
+ Must be one of ("relu", "tanh"). (Default: "relu")
+ biasing (bool): if ``True``, perform biasing. (Default: ``False``)
+ deepbiasing (bool): if ``True``, perform deep biasing. (Default: ``False``)
+ attndim (int): dimension of the biasing vector ``hptr``. (Default: 1)
+
+ """
+
+ def __init__(
+ self,
+ input_dim: int,
+ output_dim: int,
+ activation: str = "relu",
+ biasing: bool = False,
+ deepbiasing: bool = False,
+ attndim: int = 1,
+ ) -> None:
+ super().__init__()
+ self.linear = torch.nn.Linear(input_dim, output_dim, bias=True)
+ self.biasing = biasing
+ self.deepbiasing = deepbiasing
+ if self.biasing and self.deepbiasing:
+ self.biasinglinear = torch.nn.Linear(attndim, input_dim, bias=True)
+ self.attndim = attndim
+ if activation == "relu":
+ self.activation = torch.nn.ReLU()
+ elif activation == "tanh":
+ self.activation = torch.nn.Tanh()
+ else:
+ raise ValueError(f"Unsupported activation {activation}")
+
+ def forward(
+ self,
+ source_encodings: torch.Tensor,
+ source_lengths: torch.Tensor,
+ target_encodings: torch.Tensor,
+ target_lengths: torch.Tensor,
+ hptr: Optional[torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ r"""Forward pass for training.
+
+ B: batch size;
+ T: maximum source sequence length in batch;
+ U: maximum target sequence length in batch;
+ D: dimension of each source and target sequence encoding.
+
+ Args:
+ source_encodings (torch.Tensor): source encoding sequences, with
+ shape `(B, T, D)`.
+ source_lengths (torch.Tensor): with shape `(B,)` and i-th element representing
+ valid sequence length of i-th batch element in ``source_encodings``.
+ target_encodings (torch.Tensor): target encoding sequences, with shape `(B, U, D)`.
+ target_lengths (torch.Tensor): with shape `(B,)` and i-th element representing
+ valid sequence length of i-th batch element in ``target_encodings``.
+ hptr (torch.Tensor or None, optional): deep biasing vector with shape `(B, T, U, A)`,
+ where A is the TCPGen attention dimension. (Default: ``None``)
+
+ Returns:
+ (torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor):
+ torch.Tensor
+ joint network output, with shape `(B, T, U, output_dim)`.
+ torch.Tensor
+ output source lengths, with shape `(B,)` and i-th element representing
+ number of valid elements along dim 1 for i-th batch element in joint network output.
+ torch.Tensor
+ output target lengths, with shape `(B,)` and i-th element representing
+ number of valid elements along dim 2 for i-th batch element in joint network output.
+ torch.Tensor
+ joint network second last layer output (i.e. before self.linear), with shape `(B, T, U, D)`.
+ """
+ joint_encodings = source_encodings.unsqueeze(2).contiguous() + target_encodings.unsqueeze(1).contiguous()
+ if self.biasing and self.deepbiasing and hptr is not None:
+ hptr = self.biasinglinear(hptr)
+ joint_encodings += hptr
+ elif self.biasing and self.deepbiasing:
+ # Hack here for unused parameters
+ joint_encodings += self.biasinglinear(joint_encodings.new_zeros(1, self.attndim)).mean() * 0
+ activation_out = self.activation(joint_encodings)
+ output = self.linear(activation_out)
+ return output, source_lengths, target_lengths, activation_out
+
+
+class RNNTBiasing(RNNT):
+ r"""torchaudio.models.RNNT()
+
+ Recurrent neural network transducer (RNN-T) model.
+
+ Note:
+ To build the model, please use one of the factory functions.
+
+ Args:
+ transcriber (torch.nn.Module): transcription network.
+ predictor (torch.nn.Module): prediction network.
+ joiner (torch.nn.Module): joint network.
+ attndim (int): TCPGen attention dimension.
+ biasing (bool): if ``True``, use biasing; otherwise use the standard RNN-T.
+ deepbiasing (bool): if ``True``, use deep biasing by extracting the biasing vector.
+ embdim (int): dimension of symbol embeddings.
+ jointdim (int): dimension of the joint network.
+ charlist (List[str]): list of word-piece tokens in the same order as the output layer.
+ encoutdim (int): dimension of the encoder output vectors.
+ dropout_tcpgen (float): dropout rate for TCPGen.
+ tcpsche (int): the epoch at which TCPGen starts to train.
+ DBaverage (bool): if ``True``, use DBRNNT for biasing instead of TCPGen.
+ """
+
+ def __init__(
+ self,
+ transcriber: _Transcriber,
+ predictor: _Predictor,
+ joiner: _Joiner,
+ attndim: int,
+ biasing: bool,
+ deepbiasing: bool,
+ embdim: int,
+ jointdim: int,
+ charlist: List[str],
+ encoutdim: int,
+ dropout_tcpgen: float,
+ tcpsche: int,
+ DBaverage: bool,
+ ) -> None:
+ super().__init__(transcriber, predictor, joiner)
+ self.attndim = attndim
+ self.deepbiasing = deepbiasing
+ self.jointdim = jointdim
+ self.embdim = embdim
+ self.encoutdim = encoutdim
+ self.char_list = charlist or []
+ self.blank_idx = self.char_list.index("")
+ self.nchars = len(self.char_list)
+ self.DBaverage = DBaverage
+ self.biasing = biasing
+ if self.biasing:
+ if self.deepbiasing and self.DBaverage:
+ # Deep biasing without TCPGen
+ self.biasingemb = torch.nn.Linear(self.nchars, self.attndim, bias=False)
+ else:
+ # TCPGen parameters
+ self.ooKBemb = torch.nn.Embedding(1, self.embdim)
+ self.Qproj_char = torch.nn.Linear(self.embdim, self.attndim)
+ self.Qproj_acoustic = torch.nn.Linear(self.encoutdim, self.attndim)
+ self.Kproj = torch.nn.Linear(self.embdim, self.attndim)
+ self.pointer_gate = torch.nn.Linear(self.attndim + self.jointdim, 1)
+ self.dropout_tcpgen = torch.nn.Dropout(dropout_tcpgen)
+ self.tcpsche = tcpsche
+
+ def forward(
+ self,
+ sources: torch.Tensor,
+ source_lengths: torch.Tensor,
+ targets: torch.Tensor,
+ target_lengths: torch.Tensor,
+ tries: TrieNode,
+ current_epoch: int,
+ predictor_state: Optional[List[List[torch.Tensor]]] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[List[torch.Tensor]], torch.Tensor, torch.Tensor]:
+ r"""Forward pass for training.
+
+ B: batch size;
+ T: maximum source sequence length in batch;
+ U: maximum target sequence length in batch;
+ D: feature dimension of each source sequence element.
+
+ Args:
+ sources (torch.Tensor): source frame sequences right-padded with right context, with
+ shape `(B, T, D)`.
+ source_lengths (torch.Tensor): with shape `(B,)` and i-th element representing
+ number of valid frames for i-th batch element in ``sources``.
+ targets (torch.Tensor): target sequences, with shape `(B, U)` and each element
+ mapping to a target symbol.
+ target_lengths (torch.Tensor): with shape `(B,)` and i-th element representing
+ number of valid frames for i-th batch element in ``targets``.
+ tries (TrieNode): word-piece prefix tree (trie) representing the biasing list to be searched.
+ current_epoch (int): the current epoch number, used to determine whether TCPGen
+ should be trained at this epoch.
+ predictor_state (List[List[torch.Tensor]] or None, optional): list of lists of tensors
+ representing prediction network internal state generated in preceding invocation
+ of ``forward``. (Default: ``None``)
+
+ Returns:
+ (torch.Tensor, torch.Tensor, torch.Tensor, List[List[torch.Tensor]], torch.Tensor, torch.Tensor):
+ torch.Tensor
+ joint network output, with shape
+ `(B, max output source length, max output target length, output_dim (number of target symbols))`.
+ torch.Tensor
+ output source lengths, with shape `(B,)` and i-th element representing
+ number of valid elements along dim 1 for i-th batch element in joint network output.
+ torch.Tensor
+ output target lengths, with shape `(B,)` and i-th element representing
+ number of valid elements along dim 2 for i-th batch element in joint network output.
+ List[List[torch.Tensor]]
+ output states; list of lists of tensors
+ representing prediction network internal state generated in current invocation
+ of ``forward``.
+ torch.Tensor
+ TCPGen distribution, with shape
+ `(B, max output source length, max output target length, output_dim (number of target symbols))`.
+ torch.Tensor
+ Generation probability (or copy probability), with shape
+ `(B, max output source length, max output target length, 1)`.
+ """
+ source_encodings, source_lengths = self.transcriber(
+ input=sources,
+ lengths=source_lengths,
+ )
+ target_encodings, target_lengths, predictor_state = self.predictor(
+ input=targets,
+ lengths=target_lengths,
+ state=predictor_state,
+ )
+ # Forward TCPGen
+ hptr = None
+ tcpgen_dist, p_gen = None, None
+ if self.biasing and current_epoch >= self.tcpsche and tries != []:
+ ptrdist_mask, p_gen_mask = self.get_tcpgen_step_masks(targets, tries)
+ hptr, tcpgen_dist = self.forward_tcpgen(targets, ptrdist_mask, source_encodings)
+ hptr = self.dropout_tcpgen(hptr)
+ elif self.biasing:
+ # Hack here to bypass unused parameters
+ if self.DBaverage and self.deepbiasing:
+ dummy = self.biasingemb(source_encodings.new_zeros(1, len(self.char_list))).mean()
+ else:
+ dummy = source_encodings.new_zeros(1, self.embdim)
+ dummy = self.Qproj_char(dummy).mean()
+ dummy += self.Qproj_acoustic(source_encodings.new_zeros(1, source_encodings.size(-1))).mean()
+ dummy += self.Kproj(source_encodings.new_zeros(1, self.embdim)).mean()
+ dummy += self.pointer_gate(source_encodings.new_zeros(1, self.attndim + self.jointdim)).mean()
+ dummy += self.ooKBemb.weight.mean()
+ dummy = dummy * 0
+ source_encodings += dummy
+
+ output, source_lengths, target_lengths, jointer_activation = self.joiner(
+ source_encodings=source_encodings,
+ source_lengths=source_lengths,
+ target_encodings=target_encodings,
+ target_lengths=target_lengths,
+ hptr=hptr,
+ )
+
+ # Calculate Generation Probability
+ if self.biasing and hptr is not None and tcpgen_dist is not None:
+ p_gen = torch.sigmoid(self.pointer_gate(torch.cat((jointer_activation, hptr), dim=-1)))
+ # avoid collapsing to ooKB token in the first few updates
+ # if current_epoch == self.tcpsche:
+ # p_gen = p_gen * 0.1
+ p_gen = p_gen.masked_fill(p_gen_mask.bool().unsqueeze(1).unsqueeze(-1), 0)
+
+ return (output, source_lengths, target_lengths, predictor_state, tcpgen_dist, p_gen)
+
+ def get_tcpgen_distribution(self, query, ptrdist_mask):
+ # Make use of the predictor embedding matrix
+ keyvalues = torch.cat([self.predictor.embedding.weight.data, self.ooKBemb.weight], dim=0)
+ keyvalues = self.dropout_tcpgen(self.Kproj(keyvalues))
+ # B * T * U * attndim, nbpe * attndim -> B * T * U * nbpe
+ tcpgendist = torch.einsum("ntuj,ij->ntui", query, keyvalues)
+ tcpgendist = tcpgendist / math.sqrt(query.size(-1))
+ ptrdist_mask = ptrdist_mask.unsqueeze(1).repeat(1, tcpgendist.size(1), 1, 1)
+ tcpgendist.masked_fill_(ptrdist_mask.bool(), -1e9)
+ tcpgendist = torch.nn.functional.softmax(tcpgendist, dim=-1)
+ # B * T * U * nbpe, nbpe * attndim -> B * T * U * attndim
+ hptr = torch.einsum("ntui,ij->ntuj", tcpgendist[:, :, :, :-1], keyvalues[:-1, :])
+ return hptr, tcpgendist
+
+ def forward_tcpgen(self, targets, ptrdist_mask, source_encodings):
+ tcpgen_dist = None
+ if self.DBaverage and self.deepbiasing:
+ hptr = self.biasingemb(1 - ptrdist_mask[:, :, :-1].float()).unsqueeze(1)
+ else:
+ query_char = self.predictor.embedding(targets)
+ query_char = self.Qproj_char(query_char).unsqueeze(1) # B * 1 * U * attndim
+ query_acoustic = self.Qproj_acoustic(source_encodings).unsqueeze(2) # B * T * 1 * attndim
+ query = query_char + query_acoustic # B * T * U * attndim
+ hptr, tcpgen_dist = self.get_tcpgen_distribution(query, ptrdist_mask)
+ return hptr, tcpgen_dist
+
+ def get_tcpgen_step_masks(self, yseqs, resettrie):
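+ # Walk the biasing trie along each partial hypothesis: children of the current
+ # trie node are left unmasked (0) in ``batch_masks`` for the next step, and
+ # ``p_gen_masks`` marks steps at which the hypothesis has fallen off the trie.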
+ seqlen = len(yseqs[0])
+ batch_masks = yseqs.new_ones(len(yseqs), seqlen, len(self.char_list) + 1)
+ p_gen_masks = []
+ for i, yseq in enumerate(yseqs):
+ new_tree = resettrie
+ p_gen_mask = []
+ for j, vy in enumerate(yseq):
+ vy = vy.item()
+ new_tree = new_tree[0]
+ if vy in [self.blank_idx]:
+ new_tree = resettrie
+ p_gen_mask.append(0)
+ elif self.char_list[vy].endswith("▁"):
+ if vy in new_tree and new_tree[vy][0] != {}:
+ new_tree = new_tree[vy]
+ else:
+ new_tree = resettrie
+ p_gen_mask.append(0)
+ elif vy not in new_tree:
+ new_tree = [{}]
+ p_gen_mask.append(1)
+ else:
+ new_tree = new_tree[vy]
+ p_gen_mask.append(0)
+ batch_masks[i, j, list(new_tree[0].keys())] = 0
+ # In the original paper, the ooKB node was not masked.
+ # In this implementation, leaving the ooKB node unmasked makes its probability
+ # collapse to 1.0 within the first few updates; the cause has not been identified.
+ # batch_masks[i, j, -1] = 0
+ p_gen_masks.append(p_gen_mask + [1] * (seqlen - len(p_gen_mask)))
+ p_gen_masks = torch.Tensor(p_gen_masks).to(yseqs.device).byte()
+ return batch_masks, p_gen_masks
+
+ def get_tcpgen_step_masks_prefix(self, yseqs, resettrie):
+ # Implemented for prefix-based wordpieces, not tested yet
+ seqlen = len(yseqs[0])
+ batch_masks = yseqs.new_ones(len(yseqs), seqlen, len(self.char_list) + 1)
+ p_gen_masks = []
+ for i, yseq in enumerate(yseqs):
+ p_gen_mask = []
+ new_tree = resettrie
+ for j, vy in enumerate(yseq):
+ vy = vy.item()
+ new_tree = new_tree[0]
+ if vy in [self.blank_idx]:
+ new_tree = resettrie
+ batch_masks[i, j, list(new_tree[0].keys())] = 0
+ elif self.char_list[vy].startswith("▁"):
+ new_tree = resettrie
+ if vy not in new_tree[0]:
+ batch_masks[i, j, list(new_tree[0].keys())] = 0
+ else:
+ new_tree = new_tree[0][vy]
+ batch_masks[i, j, list(new_tree[0].keys())] = 0
+ if new_tree[1] != -1:
+ batch_masks[i, j, list(resettrie[0].keys())] = 0
+ else:
+ if vy not in new_tree:
+ new_tree = resettrie
+ batch_masks[i, j, list(new_tree[0].keys())] = 0
+ else:
+ new_tree = new_tree[vy]
+ batch_masks[i, j, list(new_tree[0].keys())] = 0
+ if new_tree[1] != -1:
+ batch_masks[i, j, list(resettrie[0].keys())] = 0
+ p_gen_mask.append(0)
+ # batch_masks[i, j, -1] = 0
+ p_gen_masks.append(p_gen_mask + [1] * (seqlen - len(p_gen_mask)))
+ p_gen_masks = torch.Tensor(p_gen_masks).to(yseqs.device).byte()
+
+ return batch_masks, p_gen_masks
+
+ def get_tcpgen_step(self, vy, trie, resettrie):
+ new_tree = trie[0]
+ if vy in [self.blank_idx]:
+ new_tree = resettrie
+ elif self.char_list[vy].endswith("▁"):
+ if vy in new_tree and new_tree[vy][0] != {}:
+ new_tree = new_tree[vy]
+ else:
+ new_tree = resettrie
+ elif vy not in new_tree:
+ new_tree = [{}]
+ else:
+ new_tree = new_tree[vy]
+ return new_tree
+
+ def join(
+ self,
+ source_encodings: torch.Tensor,
+ source_lengths: torch.Tensor,
+ target_encodings: torch.Tensor,
+ target_lengths: torch.Tensor,
+ hptr: Optional[torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ r"""Applies joint network to source and target encodings.
+
+ B: batch size;
+ T: maximum source sequence length in batch;
+ U: maximum target sequence length in batch;
+ D: dimension of each source and target sequence encoding.
+ A: TCPGen attention dimension
+
+ Args:
+ source_encodings (torch.Tensor): source encoding sequences, with
+ shape `(B, T, D)`.
+ source_lengths (torch.Tensor): with shape `(B,)` and i-th element representing
+ valid sequence length of i-th batch element in ``source_encodings``.
+ target_encodings (torch.Tensor): target encoding sequences, with shape `(B, U, D)`.
+ target_lengths (torch.Tensor): with shape `(B,)` and i-th element representing
+ valid sequence length of i-th batch element in ``target_encodings``.
+ hptr (torch.Tensor): deep biasing vector with shape `(B, T, U, A)`.
+
+ Returns:
+ (torch.Tensor, torch.Tensor, torch.Tensor):
+ torch.Tensor
+ joint network output, with shape `(B, T, U, output_dim)`.
+ torch.Tensor
+ output source lengths, with shape `(B,)` and i-th element representing
+ number of valid elements along dim 1 for i-th batch element in joint network output.
+ torch.Tensor
+ joint network second last layer output, with shape `(B, T, U, D)`.
+ """
+ output, source_lengths, target_lengths, jointer_activation = self.joiner(
+ source_encodings=source_encodings,
+ source_lengths=source_lengths,
+ target_encodings=target_encodings,
+ target_lengths=target_lengths,
+ hptr=hptr,
+ )
+ return output, source_lengths, jointer_activation
+
+
+def conformer_rnnt_model(
+ *,
+ input_dim: int,
+ encoding_dim: int,
+ time_reduction_stride: int,
+ conformer_input_dim: int,
+ conformer_ffn_dim: int,
+ conformer_num_layers: int,
+ conformer_num_heads: int,
+ conformer_depthwise_conv_kernel_size: int,
+ conformer_dropout: float,
+ num_symbols: int,
+ symbol_embedding_dim: int,
+ num_lstm_layers: int,
+ lstm_hidden_dim: int,
+ lstm_layer_norm: bool,
+ lstm_layer_norm_epsilon: float,
+ lstm_dropout: float,
+ joiner_activation: str,
+) -> RNNT:
+ r"""Builds Conformer-based recurrent neural network transducer (RNN-T) model.
+
+ Args:
+ input_dim (int): dimension of input sequence frames passed to transcription network.
+ encoding_dim (int): dimension of transcription- and prediction-network-generated encodings
+ passed to joint network.
+ time_reduction_stride (int): factor by which to reduce length of input sequence.
+ conformer_input_dim (int): dimension of Conformer input.
+ conformer_ffn_dim (int): hidden layer dimension of each Conformer layer's feedforward network.
+ conformer_num_layers (int): number of Conformer layers to instantiate.
+ conformer_num_heads (int): number of attention heads in each Conformer layer.
+ conformer_depthwise_conv_kernel_size (int): kernel size of each Conformer layer's depthwise convolution layer.
+ conformer_dropout (float): Conformer dropout probability.
+ num_symbols (int): cardinality of set of target tokens.
+ symbol_embedding_dim (int): dimension of each target token embedding.
+ num_lstm_layers (int): number of LSTM layers to instantiate.
+ lstm_hidden_dim (int): output dimension of each LSTM layer.
+ lstm_layer_norm (bool): if ``True``, enables layer normalization for LSTM layers.
+ lstm_layer_norm_epsilon (float): value of epsilon to use in LSTM layer normalization layers.
+ lstm_dropout (float): LSTM dropout probability.
+ joiner_activation (str): activation function to use in the joiner.
+ Must be one of ("relu", "tanh"). (Default: "relu")
+
+ Returns:
+ RNNT:
+ Conformer RNN-T model.
+ """
+ encoder = _ConformerEncoder(
+ input_dim=input_dim,
+ output_dim=encoding_dim,
+ time_reduction_stride=time_reduction_stride,
+ conformer_input_dim=conformer_input_dim,
+ conformer_ffn_dim=conformer_ffn_dim,
+ conformer_num_layers=conformer_num_layers,
+ conformer_num_heads=conformer_num_heads,
+ conformer_depthwise_conv_kernel_size=conformer_depthwise_conv_kernel_size,
+ conformer_dropout=conformer_dropout,
+ )
+ predictor = _Predictor(
+ num_symbols=num_symbols,
+ output_dim=encoding_dim,
+ symbol_embedding_dim=symbol_embedding_dim,
+ num_lstm_layers=num_lstm_layers,
+ lstm_hidden_dim=lstm_hidden_dim,
+ lstm_layer_norm=lstm_layer_norm,
+ lstm_layer_norm_epsilon=lstm_layer_norm_epsilon,
+ lstm_dropout=lstm_dropout,
+ )
+ joiner = _Joiner(encoding_dim, num_symbols, activation=joiner_activation)
+ return RNNT(encoder, predictor, joiner)
+
+
+def conformer_rnnt_base() -> RNNT:
+ r"""Builds basic version of Conformer RNN-T model.
+
+ Returns:
+ RNNT:
+ Conformer RNN-T model.
+ """
+ return conformer_rnnt_model(
+ input_dim=80,
+ encoding_dim=1024,
+ time_reduction_stride=4,
+ conformer_input_dim=256,
+ conformer_ffn_dim=1024,
+ conformer_num_layers=16,
+ conformer_num_heads=4,
+ conformer_depthwise_conv_kernel_size=31,
+ conformer_dropout=0.1,
+ num_symbols=1024,
+ symbol_embedding_dim=256,
+ num_lstm_layers=2,
+ lstm_hidden_dim=512,
+ lstm_layer_norm=True,
+ lstm_layer_norm_epsilon=1e-5,
+ lstm_dropout=0.3,
+ joiner_activation="tanh",
+ )
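+
+
+# Usage sketch (illustrative comments only, not part of the module): exercising the base
+# factory above with random tensors.  Shapes follow the docstrings in this file; the 80-dim
+# log-mel features and the token ids are placeholders.
+#
+#   model = conformer_rnnt_base()
+#   sources = torch.rand(2, 400, 80)                         # (B, T, input_dim)
+#   source_lengths = torch.tensor([400, 360])                # valid frames per utterance
+#   targets = torch.randint(1, 1024, (2, 10), dtype=torch.int32)
+#   target_lengths = torch.tensor([10, 8], dtype=torch.int32)
+#   output, out_src_lengths, out_tgt_lengths, _ = model(sources, source_lengths, targets, target_lengths)
+#   # output: joint network logits, roughly (B, T // time_reduction_stride, U, num_symbols)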
+
+
+def conformer_rnnt_biasing(
+ *,
+ input_dim: int,
+ encoding_dim: int,
+ time_reduction_stride: int,
+ conformer_input_dim: int,
+ conformer_ffn_dim: int,
+ conformer_num_layers: int,
+ conformer_num_heads: int,
+ conformer_depthwise_conv_kernel_size: int,
+ conformer_dropout: float,
+ num_symbols: int,
+ symbol_embedding_dim: int,
+ num_lstm_layers: int,
+ lstm_hidden_dim: int,
+    lstm_layer_norm: bool,
+    lstm_layer_norm_epsilon: float,
+    lstm_dropout: float,
+ joiner_activation: str,
+ attndim: int,
+ biasing: bool,
+ charlist: List[str],
+ deepbiasing: bool,
+ tcpsche: int,
+ DBaverage: bool,
+) -> RNNTBiasing:
+ r"""Builds Conformer-based recurrent neural network transducer (RNN-T) model.
+
+ Args:
+ input_dim (int): dimension of input sequence frames passed to transcription network.
+ encoding_dim (int): dimension of transcription- and prediction-network-generated encodings
+ passed to joint network.
+ time_reduction_stride (int): factor by which to reduce length of input sequence.
+ conformer_input_dim (int): dimension of Conformer input.
+ conformer_ffn_dim (int): hidden layer dimension of each Conformer layer's feedforward network.
+ conformer_num_layers (int): number of Conformer layers to instantiate.
+ conformer_num_heads (int): number of attention heads in each Conformer layer.
+ conformer_depthwise_conv_kernel_size (int): kernel size of each Conformer layer's depthwise convolution layer.
+ conformer_dropout (float): Conformer dropout probability.
+ num_symbols (int): cardinality of set of target tokens.
+ symbol_embedding_dim (int): dimension of each target token embedding.
+ num_lstm_layers (int): number of LSTM layers to instantiate.
+ lstm_hidden_dim (int): output dimension of each LSTM layer.
+ lstm_layer_norm (bool): if ``True``, enables layer normalization for LSTM layers.
+ lstm_layer_norm_epsilon (float): value of epsilon to use in LSTM layer normalization layers.
+ lstm_dropout (float): LSTM dropout probability.
+ joiner_activation (str): activation function to use in the joiner.
+            Must be one of ("relu", "tanh").
+ attndim (int): TCPGen attention dimension
+ biasing (bool): If true, use biasing, otherwise use standard RNN-T
+ charlist (list): The list of word piece tokens in the same order as the output layer
+ deepbiasing (bool): If true, use deep biasing by extracting the biasing vector
+ tcpsche (int): The epoch at which TCPGen starts to train
+ DBaverage (bool): If true, instead of TCPGen, use DBRNNT for biasing
+
+ Returns:
+ RNNT:
+ Conformer RNN-T model with TCPGen-based biasing support.
+ """
+ encoder = _ConformerEncoder(
+ input_dim=input_dim,
+ output_dim=encoding_dim,
+ time_reduction_stride=time_reduction_stride,
+ conformer_input_dim=conformer_input_dim,
+ conformer_ffn_dim=conformer_ffn_dim,
+ conformer_num_layers=conformer_num_layers,
+ conformer_num_heads=conformer_num_heads,
+ conformer_depthwise_conv_kernel_size=conformer_depthwise_conv_kernel_size,
+ conformer_dropout=conformer_dropout,
+ )
+ predictor = _Predictor(
+ num_symbols=num_symbols,
+ output_dim=encoding_dim,
+ symbol_embedding_dim=symbol_embedding_dim,
+ num_lstm_layers=num_lstm_layers,
+ lstm_hidden_dim=lstm_hidden_dim,
+ lstm_layer_norm=lstm_layer_norm,
+ lstm_layer_norm_epsilon=lstm_layer_norm_epsilon,
+ lstm_dropout=lstm_dropout,
+ )
+ joiner = _JoinerBiasing(
+ encoding_dim,
+ num_symbols,
+ activation=joiner_activation,
+ deepbiasing=deepbiasing,
+ attndim=attndim,
+ biasing=biasing,
+ )
+ return RNNTBiasing(
+ encoder,
+ predictor,
+ joiner,
+ attndim,
+ biasing,
+ deepbiasing,
+ symbol_embedding_dim,
+ encoding_dim,
+ charlist,
+ encoding_dim,
+ conformer_dropout,
+ tcpsche,
+ DBaverage,
+ )
+
+
+def conformer_rnnt_biasing_base(charlist=None, biasing=True) -> RNNT:
+ r"""Builds basic version of Conformer RNN-T model with TCPGen.
+
+ Returns:
+ RNNT:
+ Conformer RNN-T model with TCPGen-based biasing support.
+ """
+ return conformer_rnnt_biasing(
+ input_dim=80,
+ encoding_dim=576,
+ time_reduction_stride=4,
+ conformer_input_dim=144,
+ conformer_ffn_dim=576,
+ conformer_num_layers=16,
+ conformer_num_heads=4,
+ conformer_depthwise_conv_kernel_size=31,
+ conformer_dropout=0.1,
+ num_symbols=601,
+ symbol_embedding_dim=256,
+ num_lstm_layers=1,
+ lstm_hidden_dim=320,
+ lstm_layer_norm=True,
+ lstm_layer_norm_epsilon=1e-5,
+ lstm_dropout=0.3,
+ joiner_activation="tanh",
+ attndim=256,
+ biasing=biasing,
+ charlist=charlist,
+ deepbiasing=True,
+ tcpsche=30,
+ DBaverage=False,
+ )
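+
+
+# Illustrative sketch (placeholders, not part of the module): the biasing variant additionally
+# needs the wordpiece inventory of the output layer and, at decode time, a TCPGen prefix tree.
+# ``load_wordpiece_list`` is a hypothetical helper returning the token list sized to match the
+# 601-symbol output layer of the base configuration above.
+#
+#   spm_char_list = load_wordpiece_list()   # hypothetical; List[str] in output-layer order
+#   model = conformer_rnnt_biasing_base(charlist=spm_char_list, biasing=True)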
diff --git a/MLPY/Lib/site-packages/torchaudio/prototype/models/rnnt_decoder.py b/MLPY/Lib/site-packages/torchaudio/prototype/models/rnnt_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..129a1df27bcb8692556f3c00a58c1855c0da8ded
--- /dev/null
+++ b/MLPY/Lib/site-packages/torchaudio/prototype/models/rnnt_decoder.py
@@ -0,0 +1,399 @@
+from typing import Callable, Dict, List, Optional, Tuple
+
+import torch
+from torchaudio.models import RNNT
+from torchaudio.prototype.models.rnnt import TrieNode
+
+__all__ = ["Hypothesis", "RNNTBeamSearchBiasing"]
+
+
+Hypothesis = Tuple[List[int], torch.Tensor, List[List[torch.Tensor]], float, list]
+Hypothesis.__doc__ = """Hypothesis generated by the RNN-T beam search decoder with biasing support,
+    represented as a tuple of (tokens, prediction network output, prediction network state, score,
+    TCPGen prefix-tree node).
+    """
+
+
+def _get_hypo_tokens(hypo: Hypothesis) -> List[int]:
+ return hypo[0]
+
+
+def _get_hypo_predictor_out(hypo: Hypothesis) -> torch.Tensor:
+ return hypo[1]
+
+
+def _get_hypo_state(hypo: Hypothesis) -> List[List[torch.Tensor]]:
+ return hypo[2]
+
+
+def _get_hypo_score(hypo: Hypothesis) -> float:
+ return hypo[3]
+
+
+def _get_hypo_trie(hypo: Hypothesis) -> TrieNode:
+ return hypo[4]
+
+
+def _set_hypo_trie(hypo: Hypothesis, trie: TrieNode) -> None:
+ hypo[4] = trie
+
+
+def _get_hypo_key(hypo: Hypothesis) -> str:
+ return str(hypo[0])
+
+
+def _batch_state(hypos: List[Hypothesis]) -> List[List[torch.Tensor]]:
+ states: List[List[torch.Tensor]] = []
+ for i in range(len(_get_hypo_state(hypos[0]))):
+ batched_state_components: List[torch.Tensor] = []
+ for j in range(len(_get_hypo_state(hypos[0])[i])):
+ batched_state_components.append(torch.cat([_get_hypo_state(hypo)[i][j] for hypo in hypos]))
+ states.append(batched_state_components)
+ return states
+
+
+def _slice_state(states: List[List[torch.Tensor]], idx: int, device: torch.device) -> List[List[torch.Tensor]]:
+ idx_tensor = torch.tensor([idx], device=device)
+ return [[state.index_select(0, idx_tensor) for state in state_tuple] for state_tuple in states]
+
+
+def _default_hypo_sort_key(hypo: Hypothesis) -> float:
+ return _get_hypo_score(hypo) / (len(_get_hypo_tokens(hypo)) + 1)
+
+
+def _compute_updated_scores(
+ hypos: List[Hypothesis],
+ next_token_probs: torch.Tensor,
+ beam_width: int,
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ hypo_scores = torch.tensor([_get_hypo_score(h) for h in hypos]).unsqueeze(1)
+ nonblank_scores = hypo_scores + next_token_probs[:, :-1] # [beam_width, num_tokens - 1]
+ nonblank_nbest_scores, nonblank_nbest_idx = nonblank_scores.reshape(-1).topk(beam_width)
+ nonblank_nbest_hypo_idx = nonblank_nbest_idx.div(nonblank_scores.shape[1], rounding_mode="trunc")
+ nonblank_nbest_token = nonblank_nbest_idx % nonblank_scores.shape[1]
+ return nonblank_nbest_scores, nonblank_nbest_hypo_idx, nonblank_nbest_token
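+
+
+# Worked example (comments only): with 2 hypotheses and 4 tokens (the last being blank),
+# ``nonblank_scores`` has shape (2, 3); top-k over the flattened tensor is mapped back to
+# (hypothesis, token) pairs by integer division and modulo, exactly as above.
+#
+#   scores = torch.tensor([[0.0, -1.0, -2.0],
+#                          [-0.5, -0.1, -3.0]])     # (beam_width=2, num_tokens - 1=3)
+#   _, idx = scores.reshape(-1).topk(2)             # idx = tensor([0, 4])
+#   idx.div(3, rounding_mode="trunc")               # hypothesis indices: tensor([0, 1])
+#   idx % 3                                         # token indices:      tensor([0, 1])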
+
+
+def _remove_hypo(hypo: Hypothesis, hypo_list: List[Hypothesis]) -> None:
+ for i, elem in enumerate(hypo_list):
+ if _get_hypo_key(hypo) == _get_hypo_key(elem):
+ del hypo_list[i]
+ break
+
+
+class RNNTBeamSearchBiasing(torch.nn.Module):
+ r"""Beam search decoder for RNN-T model with biasing support.
+
+ Args:
+ model (RNNT): RNN-T model to use.
+ blank (int): index of blank token in vocabulary.
+ temperature (float, optional): temperature to apply to joint network output.
+ Larger values yield more uniform samples. (Default: 1.0)
+ hypo_sort_key (Callable[[Hypothesis], float] or None, optional): callable that computes a score
+ for a given hypothesis to rank hypotheses by. If ``None``, defaults to callable that returns
+ hypothesis score normalized by token sequence length. (Default: None)
+ step_max_tokens (int, optional): maximum number of tokens to emit per input time step. (Default: 100)
+        trie (TrieNode, optional): prefix tree of biasing words for TCPGen. (Default: ``None``)
+        biasing (bool, optional): if ``True``, apply TCPGen biasing; otherwise perform standard RNN-T decoding.
+            (Default: ``False``)
+ """
+
+ def __init__(
+ self,
+ model: RNNT,
+ blank: int,
+ temperature: float = 1.0,
+ hypo_sort_key: Optional[Callable[[Hypothesis], float]] = None,
+ step_max_tokens: int = 100,
+ trie: TrieNode = None,
+ biasing: bool = False,
+ ) -> None:
+ super().__init__()
+ self.model = model
+ self.blank = blank
+ self.temperature = temperature
+ self.resettrie = trie or []
+ self.dobiasing = biasing
+
+ if hypo_sort_key is None:
+ self.hypo_sort_key = _default_hypo_sort_key
+ else:
+ self.hypo_sort_key = hypo_sort_key
+
+ self.step_max_tokens = step_max_tokens
+
+ def _init_b_hypos(self, hypo: Optional[Hypothesis], device: torch.device) -> List[Hypothesis]:
+ if hypo is not None:
+ token = _get_hypo_tokens(hypo)[-1]
+ state = _get_hypo_state(hypo)
+ else:
+ token = self.blank
+ state = None
+
+ one_tensor = torch.tensor([1], device=device)
+ pred_out, _, pred_state = self.model.predict(torch.tensor([[token]], device=device), one_tensor, state)
+ init_hypo = ([token], pred_out[0].detach(), pred_state, 0.0, self.resettrie)
+ return [init_hypo]
+
+ def _get_trie_mask(self, trie):
+ step_mask = torch.ones(len(self.model.char_list) + 1)
+ step_mask[list(trie[0].keys())] = 0
+ # step_mask[-1] = 0
+ return step_mask
+
+ def _get_generation_prob(self, trie):
+ if len(trie[0].keys()) == 0:
+ return True
+ else:
+ return False
+
+ def _gen_next_token_probs(
+ self, enc_out: torch.Tensor, hypos: List[Hypothesis], device: torch.device
+ ) -> torch.Tensor:
+ one_tensor = torch.tensor([1], device=device)
+ predictor_out = torch.stack([_get_hypo_predictor_out(h) for h in hypos], dim=0)
+ if self.dobiasing:
+ # Get valid subset of wordpieces
+ trie_masks = torch.stack([self._get_trie_mask(_get_hypo_trie(h)) for h in hypos], dim=0)
+ trie_masks = trie_masks.to(enc_out.device).unsqueeze(1) # beam_width, 1, nchars
+            # Determine whether there are any remaining paths on the trie
+ genprob_masks = torch.tensor([self._get_generation_prob(_get_hypo_trie(h)) for h in hypos]) # beam_width
+ genprob_masks = genprob_masks.to(enc_out.device)
+ # Forward TCPGen component
+ last_tokens = torch.tensor([_get_hypo_tokens(h)[-1] for h in hypos]).unsqueeze(-1).to(enc_out.device)
+ hptr, tcpgen_dist = self.model.forward_tcpgen(last_tokens, trie_masks, enc_out)
+ else:
+ hptr = None
+ # hptr sent to joiner, if deepbiasing is True joiner will use it
+ joined_out, _, joined_activation = self.model.join(
+ enc_out,
+ one_tensor,
+ predictor_out,
+ torch.tensor([1] * len(hypos), device=device),
+ hptr=hptr,
+ ) # [beam_width, 1, 1, num_tokens]
+ if self.dobiasing:
+ p_gen = torch.sigmoid(self.model.pointer_gate(torch.cat((joined_activation, hptr), dim=-1)))
+ p_gen = p_gen.masked_fill(genprob_masks.view(p_gen.size(0), 1, 1, 1), 0)
+ model_tu = torch.softmax(joined_out / self.temperature, dim=3)
+ # assuming last token is blank
+ p_not_null = 1.0 - model_tu[:, :, :, -1:]
+ ptr_dist_fact = torch.cat([tcpgen_dist[:, :, :, :-2], tcpgen_dist[:, :, :, -1:]], dim=-1) * p_not_null
+ ptr_gen_complement = tcpgen_dist[:, :, :, -1:] * p_gen
+ p_partial = ptr_dist_fact[:, :, :, :-1] * p_gen + model_tu[:, :, :, :-1] * (1 - p_gen + ptr_gen_complement)
+ p_final = torch.cat([p_partial, model_tu[:, :, :, -1:]], dim=-1)
+ joined_out = torch.log(p_final)
+ else:
+ joined_out = torch.nn.functional.log_softmax(joined_out / self.temperature, dim=3)
+ return joined_out[:, 0, 0]
+
+ def _gen_b_hypos(
+ self,
+ b_hypos: List[Hypothesis],
+ a_hypos: List[Hypothesis],
+ next_token_probs: torch.Tensor,
+ key_to_b_hypo: Dict[str, Hypothesis],
+ ) -> List[Hypothesis]:
+ for i in range(len(a_hypos)):
+ h_a = a_hypos[i]
+ append_blank_score = _get_hypo_score(h_a) + next_token_probs[i, -1]
+ if _get_hypo_key(h_a) in key_to_b_hypo:
+ h_b = key_to_b_hypo[_get_hypo_key(h_a)]
+ _remove_hypo(h_b, b_hypos)
+ score = float(torch.tensor(_get_hypo_score(h_b)).logaddexp(append_blank_score))
+ else:
+ score = float(append_blank_score)
+ h_b = (
+ _get_hypo_tokens(h_a),
+ _get_hypo_predictor_out(h_a),
+ _get_hypo_state(h_a),
+ score,
+ _get_hypo_trie(h_a),
+ )
+ b_hypos.append(h_b)
+ key_to_b_hypo[_get_hypo_key(h_b)] = h_b
+ _, sorted_idx = torch.tensor([_get_hypo_score(hypo) for hypo in b_hypos]).sort()
+ return [b_hypos[idx] for idx in sorted_idx]
+
+ def _gen_a_hypos(
+ self,
+ a_hypos: List[Hypothesis],
+ b_hypos: List[Hypothesis],
+ next_token_probs: torch.Tensor,
+ t: int,
+ beam_width: int,
+ device: torch.device,
+ ) -> List[Hypothesis]:
+ (
+ nonblank_nbest_scores,
+ nonblank_nbest_hypo_idx,
+ nonblank_nbest_token,
+ ) = _compute_updated_scores(a_hypos, next_token_probs, beam_width)
+
+ if len(b_hypos) < beam_width:
+ b_nbest_score = -float("inf")
+ else:
+ b_nbest_score = _get_hypo_score(b_hypos[-beam_width])
+
+ base_hypos: List[Hypothesis] = []
+ new_tokens: List[int] = []
+ new_scores: List[float] = []
+ for i in range(beam_width):
+ score = float(nonblank_nbest_scores[i])
+ if score > b_nbest_score:
+ a_hypo_idx = int(nonblank_nbest_hypo_idx[i])
+ base_hypos.append(a_hypos[a_hypo_idx])
+ new_tokens.append(int(nonblank_nbest_token[i]))
+ new_scores.append(score)
+
+ if base_hypos:
+ new_hypos = self._gen_new_hypos(base_hypos, new_tokens, new_scores, t, device)
+ else:
+ new_hypos: List[Hypothesis] = []
+
+ return new_hypos
+
+ def _gen_new_hypos(
+ self,
+ base_hypos: List[Hypothesis],
+ tokens: List[int],
+ scores: List[float],
+ t: int,
+ device: torch.device,
+ ) -> List[Hypothesis]:
+ tgt_tokens = torch.tensor([[token] for token in tokens], device=device)
+ states = _batch_state(base_hypos)
+ pred_out, _, pred_states = self.model.predict(
+ tgt_tokens,
+ torch.tensor([1] * len(base_hypos), device=device),
+ states,
+ )
+ new_hypos: List[Hypothesis] = []
+ for i, h_a in enumerate(base_hypos):
+ new_tokens = _get_hypo_tokens(h_a) + [tokens[i]]
+ if self.dobiasing:
+ new_trie = self.model.get_tcpgen_step(tokens[i], _get_hypo_trie(h_a), self.resettrie)
+ else:
+ new_trie = self.resettrie
+ new_hypos.append(
+ (new_tokens, pred_out[i].detach(), _slice_state(pred_states, i, device), scores[i], new_trie)
+ )
+ return new_hypos
+
+ def _search(
+ self,
+ enc_out: torch.Tensor,
+ hypo: Optional[Hypothesis],
+ beam_width: int,
+ ) -> List[Hypothesis]:
+ n_time_steps = enc_out.shape[1]
+ device = enc_out.device
+
+ a_hypos: List[Hypothesis] = []
+ b_hypos = self._init_b_hypos(hypo, device)
+ for t in range(n_time_steps):
+ a_hypos = b_hypos
+ b_hypos = torch.jit.annotate(List[Hypothesis], [])
+ key_to_b_hypo: Dict[str, Hypothesis] = {}
+ symbols_current_t = 0
+
+ while a_hypos:
+ next_token_probs = self._gen_next_token_probs(enc_out[:, t : t + 1], a_hypos, device)
+ next_token_probs = next_token_probs.cpu()
+ b_hypos = self._gen_b_hypos(b_hypos, a_hypos, next_token_probs, key_to_b_hypo)
+
+ if symbols_current_t == self.step_max_tokens:
+ break
+
+ a_hypos = self._gen_a_hypos(
+ a_hypos,
+ b_hypos,
+ next_token_probs,
+ t,
+ beam_width,
+ device,
+ )
+ if a_hypos:
+ symbols_current_t += 1
+
+ _, sorted_idx = torch.tensor([self.hypo_sort_key(hypo) for hypo in b_hypos]).topk(beam_width)
+ b_hypos = [b_hypos[idx] for idx in sorted_idx]
+
+ return b_hypos
+
+ def forward(
+ self,
+ input: torch.Tensor,
+ length: torch.Tensor,
+ beam_width: int,
+ ) -> List[Hypothesis]:
+ r"""Performs beam search for the given input sequence.
+
+ T: number of frames;
+ D: feature dimension of each frame.
+
+ Args:
+ input (torch.Tensor): sequence of input frames, with shape (T, D) or (1, T, D).
+ length (torch.Tensor): number of valid frames in input
+ sequence, with shape () or (1,).
+ beam_width (int): beam size to use during search.
+
+ Returns:
+ List[Hypothesis]: top-``beam_width`` hypotheses found by beam search.
+ """
+ if input.dim() != 2 and not (input.dim() == 3 and input.shape[0] == 1):
+ raise ValueError("input must be of shape (T, D) or (1, T, D)")
+ if input.dim() == 2:
+ input = input.unsqueeze(0)
+
+ if length.shape != () and length.shape != (1,):
+ raise ValueError("length must be of shape () or (1,)")
+        if length.dim() == 0:
+            length = length.unsqueeze(0)
+
+ enc_out, _ = self.model.transcribe(input, length)
+ return self._search(enc_out, None, beam_width)
+
+ @torch.jit.export
+ def infer(
+ self,
+ input: torch.Tensor,
+ length: torch.Tensor,
+ beam_width: int,
+ state: Optional[List[List[torch.Tensor]]] = None,
+ hypothesis: Optional[Hypothesis] = None,
+ ) -> Tuple[List[Hypothesis], List[List[torch.Tensor]]]:
+ r"""Performs beam search for the given input sequence in streaming mode.
+
+ T: number of frames;
+ D: feature dimension of each frame.
+
+ Args:
+ input (torch.Tensor): sequence of input frames, with shape (T, D) or (1, T, D).
+ length (torch.Tensor): number of valid frames in input
+ sequence, with shape () or (1,).
+ beam_width (int): beam size to use during search.
+ state (List[List[torch.Tensor]] or None, optional): list of lists of tensors
+ representing transcription network internal state generated in preceding
+ invocation. (Default: ``None``)
+ hypothesis (Hypothesis or None): hypothesis from preceding invocation to seed
+ search with. (Default: ``None``)
+
+ Returns:
+ (List[Hypothesis], List[List[torch.Tensor]]):
+ List[Hypothesis]
+ top-``beam_width`` hypotheses found by beam search.
+ List[List[torch.Tensor]]
+ list of lists of tensors representing transcription network
+ internal state generated in current invocation.
+ """
+ if input.dim() != 2 and not (input.dim() == 3 and input.shape[0] == 1):
+ raise ValueError("input must be of shape (T, D) or (1, T, D)")
+ if input.dim() == 2:
+ input = input.unsqueeze(0)
+
+ if length.shape != () and length.shape != (1,):
+ raise ValueError("length must be of shape () or (1,)")
+ if length.dim() == 0:
+ length = length.unsqueeze(0)
+
+ enc_out, _, state = self.model.transcribe_streaming(input, length, state)
+ return self._search(enc_out, hypothesis, beam_width), state
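+
+
+# Usage sketch (assumptions flagged, not part of the module): decoding one utterance with
+# TCPGen biasing.  ``model`` is an RNNTBiasing instance (e.g. from conformer_rnnt_biasing_base),
+# ``trie`` is a prefix tree built elsewhere from the biasing word list, and the blank index is
+# assumed to be the last output symbol; all three are placeholders.
+#
+#   decoder = RNNTBeamSearchBiasing(model, blank=600, trie=trie, biasing=True)
+#   features = torch.rand(120, 80)                       # (T, D) acoustic features
+#   hypos = decoder(features, torch.tensor([120]), 10)   # beam_width = 10
+#   best_tokens = _get_hypo_tokens(hypos[0])             # token ids of the best hypothesis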
diff --git a/MLPY/Lib/site-packages/torchaudio/prototype/pipelines/__init__.py b/MLPY/Lib/site-packages/torchaudio/prototype/pipelines/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..83da7aa43c6e387adb0cf2281cb2da70409145e4
--- /dev/null
+++ b/MLPY/Lib/site-packages/torchaudio/prototype/pipelines/__init__.py
@@ -0,0 +1,12 @@
+from ._vggish import VGGISH, VGGishBundle
+from .hifigan_pipeline import HIFIGAN_VOCODER_V3_LJSPEECH, HiFiGANVocoderBundle
+from .rnnt_pipeline import EMFORMER_RNNT_BASE_MUSTC, EMFORMER_RNNT_BASE_TEDLIUM3
+
+__all__ = [
+ "EMFORMER_RNNT_BASE_MUSTC",
+ "EMFORMER_RNNT_BASE_TEDLIUM3",
+ "HIFIGAN_VOCODER_V3_LJSPEECH",
+ "HiFiGANVocoderBundle",
+ "VGGISH",
+ "VGGishBundle",
+]
diff --git a/MLPY/Lib/site-packages/torchaudio/prototype/pipelines/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torchaudio/prototype/pipelines/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..71178f893b52665d8f4602358696757fb7a3f519
Binary files /dev/null and b/MLPY/Lib/site-packages/torchaudio/prototype/pipelines/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torchaudio/prototype/pipelines/__pycache__/hifigan_pipeline.cpython-39.pyc b/MLPY/Lib/site-packages/torchaudio/prototype/pipelines/__pycache__/hifigan_pipeline.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ae9ffe5b47e8d58187cf4cabf629382763a893f3
Binary files /dev/null and b/MLPY/Lib/site-packages/torchaudio/prototype/pipelines/__pycache__/hifigan_pipeline.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torchaudio/prototype/pipelines/__pycache__/rnnt_pipeline.cpython-39.pyc b/MLPY/Lib/site-packages/torchaudio/prototype/pipelines/__pycache__/rnnt_pipeline.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f32849e25266306c46595a478f0d891016dfeabd
Binary files /dev/null and b/MLPY/Lib/site-packages/torchaudio/prototype/pipelines/__pycache__/rnnt_pipeline.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torchaudio/prototype/pipelines/_vggish/__init__.py b/MLPY/Lib/site-packages/torchaudio/prototype/pipelines/_vggish/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..abec68e4d4d45bcd6a74820413bf5dc6b56869f4
--- /dev/null
+++ b/MLPY/Lib/site-packages/torchaudio/prototype/pipelines/_vggish/__init__.py
@@ -0,0 +1,3 @@
+from ._vggish_pipeline import VGGISH, VGGishBundle
+
+__all__ = ["VGGISH", "VGGishBundle"]
diff --git a/MLPY/Lib/site-packages/torchaudio/prototype/pipelines/_vggish/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torchaudio/prototype/pipelines/_vggish/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b909674c24c1af73ac2bcccf77f510d314811627
Binary files /dev/null and b/MLPY/Lib/site-packages/torchaudio/prototype/pipelines/_vggish/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torchaudio/prototype/pipelines/_vggish/__pycache__/_vggish_impl.cpython-39.pyc b/MLPY/Lib/site-packages/torchaudio/prototype/pipelines/_vggish/__pycache__/_vggish_impl.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..41aea41d2640616348d4c38fc0e689123d6c219a
Binary files /dev/null and b/MLPY/Lib/site-packages/torchaudio/prototype/pipelines/_vggish/__pycache__/_vggish_impl.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torchaudio/prototype/pipelines/_vggish/__pycache__/_vggish_pipeline.cpython-39.pyc b/MLPY/Lib/site-packages/torchaudio/prototype/pipelines/_vggish/__pycache__/_vggish_pipeline.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..881819e93c4da2e1663d6390cc5033200e086efe
Binary files /dev/null and b/MLPY/Lib/site-packages/torchaudio/prototype/pipelines/_vggish/__pycache__/_vggish_pipeline.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torchaudio/prototype/pipelines/_vggish/_vggish_impl.py b/MLPY/Lib/site-packages/torchaudio/prototype/pipelines/_vggish/_vggish_impl.py
new file mode 100644
index 0000000000000000000000000000000000000000..a32613720cf7f78a81d9d185a75c2e873975e6cb
--- /dev/null
+++ b/MLPY/Lib/site-packages/torchaudio/prototype/pipelines/_vggish/_vggish_impl.py
@@ -0,0 +1,233 @@
+# Derived from torchvggish (https://github.com/harritaylor/torchvggish).
+# Copyright 2017 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import math
+
+import torch
+
+
+_MEL_BREAK_FREQUENCY_HERTZ = 700.0
+_MEL_HIGH_FREQUENCY_Q = 1127.0
+
+
+_SAMPLE_RATE = 16000
+_STFT_WINDOW_LENGTH_SECONDS = 0.025
+_STFT_HOP_LENGTH_SECONDS = 0.010
+_MEL_MIN_HZ = 125
+_MEL_MAX_HZ = 7500
+_NUM_BANDS = 64
+_LOG_OFFSET = 0.01
+_EXAMPLE_WINDOW_SECONDS = 0.96 # Each example contains 96 10ms frames
+_EXAMPLE_HOP_SECONDS = 0.96 # with zero overlap.
+
+
+def _build_features_network():
+ layers = []
+
+ for input_dim, output_dim in [(1, 64), (64, 128)]:
+ layers += [
+ torch.nn.Conv2d(input_dim, output_dim, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
+ torch.nn.ReLU(inplace=True),
+ torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False),
+ ]
+
+ for input_dim, output_dim in [(128, 256), (256, 512)]:
+ layers += [
+ torch.nn.Conv2d(input_dim, output_dim, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
+ torch.nn.ReLU(inplace=True),
+ torch.nn.Conv2d(
+ output_dim,
+ output_dim,
+ kernel_size=(3, 3),
+ stride=(1, 1),
+ padding=(1, 1),
+ ),
+ torch.nn.ReLU(inplace=True),
+ torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False),
+ ]
+
+ return torch.nn.Sequential(*layers)
+
+
+def _build_embedding_network():
+ return torch.nn.Sequential(
+ torch.nn.Linear(512 * 4 * 6, 4096),
+ torch.nn.ReLU(True),
+ torch.nn.Linear(4096, 4096),
+ torch.nn.ReLU(True),
+ torch.nn.Linear(4096, 128),
+ torch.nn.ReLU(True),
+ )
+
+
+def _frame(data, window_length, hop_length):
+ num_samples = data.shape[0]
+ num_frames = 1 + int(math.floor((num_samples - window_length) / hop_length))
+ shape = (num_frames, window_length) + data.shape[1:]
+ strides = (data.stride()[0] * hop_length,) + data.stride()
+ return torch.as_strided(data, shape, strides)
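+
+
+# Worked example (comments only): framing a 10-sample signal with a 4-sample window and a hop
+# of 2 yields 1 + (10 - 4) // 2 = 4 overlapping frames, each a strided view of the input.
+#
+#   _frame(torch.arange(10.0), window_length=4, hop_length=2)
+#   # tensor([[0., 1., 2., 3.],
+#   #         [2., 3., 4., 5.],
+#   #         [4., 5., 6., 7.],
+#   #         [6., 7., 8., 9.]])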
+
+
+def _stft_magnitude(signal, fft_length, hop_length=None, window_length=None):
+ frames = _frame(signal, window_length, hop_length)
+ window = torch.hann_window(window_length, periodic=True).to(signal.device)
+ windowed_frames = frames * window
+ return torch.abs(torch.fft.rfft(windowed_frames, int(fft_length)))
+
+
+def _hertz_to_mel(frequencies_hertz):
+ return _MEL_HIGH_FREQUENCY_Q * torch.log(1.0 + (frequencies_hertz / _MEL_BREAK_FREQUENCY_HERTZ))
+
+
+def _spectrogram_to_mel_matrix(
+ num_mel_bins=20,
+ num_spectrogram_bins=129,
+ audio_sample_rate=8000,
+ lower_edge_hertz=125.0,
+ upper_edge_hertz=3800.0,
+):
+ nyquist_hertz = audio_sample_rate / 2.0
+ if lower_edge_hertz < 0.0:
+ raise ValueError("lower_edge_hertz %.1f must be >= 0" % lower_edge_hertz)
+ if lower_edge_hertz >= upper_edge_hertz:
+ raise ValueError("lower_edge_hertz %.1f >= upper_edge_hertz %.1f" % (lower_edge_hertz, upper_edge_hertz))
+
+ if upper_edge_hertz > nyquist_hertz:
+ raise ValueError("upper_edge_hertz %.1f is greater than Nyquist %.1f" % (upper_edge_hertz, nyquist_hertz))
+ spectrogram_bins_hertz = torch.linspace(0.0, nyquist_hertz, num_spectrogram_bins)
+
+ spectrogram_bins_mel = _hertz_to_mel(spectrogram_bins_hertz)
+ # The i'th mel band (starting from i=1) has center frequency
+ # band_edges_mel[i], lower edge band_edges_mel[i-1], and higher edge
+ # band_edges_mel[i+1]. Thus, we need num_mel_bins + 2 values in
+ # the band_edges_mel arrays.
+ band_edges_mel = torch.linspace(
+ _hertz_to_mel(torch.tensor(lower_edge_hertz)),
+ _hertz_to_mel(torch.tensor(upper_edge_hertz)),
+ num_mel_bins + 2,
+ )
+ # Matrix to post-multiply feature arrays whose rows are num_spectrogram_bins
+ # of spectrogram values.
+ mel_weights_matrix = torch.empty((num_spectrogram_bins, num_mel_bins))
+ for i in range(num_mel_bins):
+ lower_edge_mel, center_mel, upper_edge_mel = band_edges_mel[i : i + 3]
+ # Calculate lower and upper slopes for every spectrogram bin.
+ # Line segments are linear in the *mel* domain, not hertz.
+ lower_slope = (spectrogram_bins_mel - lower_edge_mel) / (center_mel - lower_edge_mel)
+ upper_slope = (upper_edge_mel - spectrogram_bins_mel) / (upper_edge_mel - center_mel)
+
+ # .. then intersect them with each other and zero.
+ mel_weights_matrix[:, i] = torch.maximum(torch.tensor(0.0), torch.minimum(lower_slope, upper_slope))
+
+ # HTK excludes the spectrogram DC bin; make sure it always gets a zero
+ # coefficient.
+ mel_weights_matrix[0, :] = 0.0
+ return mel_weights_matrix
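+
+
+# Shape sketch (comments only): with the settings used by ``_waveform_to_examples`` below
+# (512-point FFT giving 257 spectrogram bins, 64 mel bands between 125 Hz and 7500 Hz at 16 kHz),
+# the returned matrix maps 257 spectrogram bins onto 64 mel bands via a matmul on the right.
+#
+#   m = _spectrogram_to_mel_matrix(num_mel_bins=64, num_spectrogram_bins=257,
+#                                  audio_sample_rate=16000,
+#                                  lower_edge_hertz=125.0, upper_edge_hertz=7500.0)
+#   m.shape   # torch.Size([257, 64])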
+
+
+def _log_mel_spectrogram(
+ data,
+ audio_sample_rate=8000,
+ log_offset=0.0,
+ window_length_secs=0.025,
+ hop_length_secs=0.010,
+ **kwargs,
+):
+ window_length_samples = int(round(audio_sample_rate * window_length_secs))
+ hop_length_samples = int(round(audio_sample_rate * hop_length_secs))
+ fft_length = 2 ** int(math.ceil(math.log(window_length_samples) / math.log(2.0)))
+
+ spectrogram = _stft_magnitude(
+ data,
+ fft_length=fft_length,
+ hop_length=hop_length_samples,
+ window_length=window_length_samples,
+ )
+ mel_spectrogram = torch.matmul(
+ spectrogram,
+ _spectrogram_to_mel_matrix(
+ num_spectrogram_bins=spectrogram.shape[1],
+ audio_sample_rate=audio_sample_rate,
+ **kwargs,
+ ).to(spectrogram),
+ )
+ return torch.log(mel_spectrogram + log_offset)
+
+
+def _waveform_to_examples(data):
+ # Compute log mel spectrogram features, with shape (n_frame, n_mel)
+ log_mel = _log_mel_spectrogram(
+ data,
+ audio_sample_rate=_SAMPLE_RATE,
+ log_offset=_LOG_OFFSET,
+ window_length_secs=_STFT_WINDOW_LENGTH_SECONDS,
+ hop_length_secs=_STFT_HOP_LENGTH_SECONDS,
+ num_mel_bins=_NUM_BANDS,
+ lower_edge_hertz=_MEL_MIN_HZ,
+ upper_edge_hertz=_MEL_MAX_HZ,
+ )
+
+ # Frame features into examples, with shape (n_example, n_frame, n_mel)
+ features_sample_rate = 1.0 / _STFT_HOP_LENGTH_SECONDS
+ example_window_length = int(round(_EXAMPLE_WINDOW_SECONDS * features_sample_rate))
+
+ example_hop_length = int(round(_EXAMPLE_HOP_SECONDS * features_sample_rate))
+ log_mel_examples = _frame(log_mel, window_length=example_window_length, hop_length=example_hop_length)
+
+ # (n_example, 1, n_frame, n_mel)
+ return log_mel_examples.unsqueeze(1)
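+
+
+# Worked example (comments only): a 2-second, 16 kHz waveform (32000 samples) yields
+# 1 + (32000 - 400) // 160 = 198 log-mel frames of 64 bands, which are then grouped into
+# non-overlapping 0.96-second examples of 96 frames each:
+#
+#   _waveform_to_examples(torch.rand(32000)).shape   # torch.Size([2, 1, 96, 64])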
+
+
+class VGGish(torch.nn.Module):
+ """Implementation of VGGish model :cite:`45611`."""
+
+ def __init__(self):
+ super().__init__()
+
+ self.features_network = _build_features_network()
+ self.embedding_network = _build_embedding_network()
+
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+ input (torch.Tensor): batch of spectrograms, with shape `(n_example, 1, n_frame, 64)`.
+
+ Returns:
+ torch.Tensor: model output, with shape `(n_example, 128)`.
+ """
+ x = self.features_network(input)
+
+ x = x.permute(0, 2, 3, 1)
+ x = x.reshape(x.size(0), -1)
+
+ return self.embedding_network(x)
+
+
+class VGGishInputProcessor:
+ """Converts raw waveforms to batches of examples to use as inputs to VGGish."""
+
+ def __call__(self, input: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+            input (torch.Tensor): waveform, with shape `(T,)`, expected to be sampled at 16 kHz.
+
+ Returns:
+ torch.Tensor: batch of examples to pass to VGGish, with shape `(n_example, 1, n_frame, 64)`.
+ """
+ if len(input.shape) != 1:
+ raise ValueError("input waveform must have dimension of 1.")
+ return _waveform_to_examples(input)
diff --git a/MLPY/Lib/site-packages/torchaudio/prototype/pipelines/_vggish/_vggish_pipeline.py b/MLPY/Lib/site-packages/torchaudio/prototype/pipelines/_vggish/_vggish_pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..f67fe8ca169dcf19389391cac877056f606a6f8a
--- /dev/null
+++ b/MLPY/Lib/site-packages/torchaudio/prototype/pipelines/_vggish/_vggish_pipeline.py
@@ -0,0 +1,82 @@
+from dataclasses import dataclass
+from typing import Callable, Dict
+
+import torch
+import torchaudio
+
+from ._vggish_impl import _SAMPLE_RATE, VGGish as _VGGish, VGGishInputProcessor as _VGGishInputProcessor
+
+
+def _get_state_dict():
+ path = torchaudio.utils.download_asset("models/vggish.pt")
+ return torch.load(path)
+
+
+@dataclass
+class VGGishBundle:
+ """VGGish :cite:`45611` inference pipeline ported from
+    `torchvggish <https://github.com/harritaylor/torchvggish>`__
+    and `tensorflow-models <https://github.com/tensorflow/models/tree/master/research/audioset>`__.
+
+ Example:
+ >>> import torchaudio
+ >>> from torchaudio.prototype.pipelines import VGGISH
+ >>>
+ >>> input_sr = VGGISH.sample_rate
+ >>> input_proc = VGGISH.get_input_processor()
+ >>> model = VGGISH.get_model()
+ >>>
+ >>> waveform, sr = torchaudio.load(
+ >>> "Chopin_Ballade_-1_In_G_Minor,_Op._23.mp3",
+ >>> )
+ >>> waveform = waveform.squeeze(0)
+ >>> waveform = torchaudio.functional.resample(waveform, sr, input_sr)
+ >>> mono_output = model(input_proc(waveform))
+ """
+
+ class VGGish(_VGGish):
+ __doc__ = _VGGish.__doc__
+
+ class VGGishInputProcessor(_VGGishInputProcessor):
+ __doc__ = _VGGishInputProcessor.__doc__
+
+ _state_dict_func: Callable[[], Dict]
+
+ @property
+ def sample_rate(self) -> int:
+ """Sample rate of input waveform expected by input processor and model.
+
+ :type: int
+ """
+ return _SAMPLE_RATE
+
+ def get_model(self) -> VGGish:
+ """Constructs pre-trained VGGish model. Downloads and caches weights as necessary.
+
+ Returns:
+ VGGish: VGGish model with pre-trained weights loaded.
+ """
+ model = self.VGGish()
+ state_dict = self._state_dict_func()
+ model.load_state_dict(state_dict)
+ model.eval()
+ return model
+
+ def get_input_processor(self) -> VGGishInputProcessor:
+ """Constructs input processor for VGGish.
+
+ Returns:
+ VGGishInputProcessor: input processor for VGGish.
+ """
+ return self.VGGishInputProcessor()
+
+
+VGGISH = VGGishBundle(_get_state_dict)
+VGGISH.__doc__ = """Pre-trained VGGish :cite:`45611` inference pipeline ported from
+    `torchvggish <https://github.com/harritaylor/torchvggish>`__
+    and `tensorflow-models <https://github.com/tensorflow/models/tree/master/research/audioset>`__.
+
+    Per the `documentation <https://github.com/tensorflow/models/tree/master/research/audioset/vggish>`__
+ for the original model, the model is "trained on a large YouTube dataset (a preliminary version of
+ what later became YouTube-8M)".
+ """
diff --git a/MLPY/Lib/site-packages/torchaudio/prototype/pipelines/hifigan_pipeline.py b/MLPY/Lib/site-packages/torchaudio/prototype/pipelines/hifigan_pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..8c5a14e0731302de5bb716c902dcc9325aa42271
--- /dev/null
+++ b/MLPY/Lib/site-packages/torchaudio/prototype/pipelines/hifigan_pipeline.py
@@ -0,0 +1,228 @@
+from dataclasses import dataclass
+from typing import Any, Dict, Optional
+
+import torch
+import torch.nn.functional as F
+from torch.nn import Module
+from torchaudio._internal import load_state_dict_from_url
+
+from torchaudio.prototype.models.hifi_gan import hifigan_vocoder, HiFiGANVocoder
+from torchaudio.transforms import MelSpectrogram
+
+
+@dataclass
+class HiFiGANVocoderBundle:
+ """Data class that bundles associated information to use pretrained
+ :py:class:`~torchaudio.prototype.models.HiFiGANVocoder`.
+
+ This class provides interfaces for instantiating the pretrained model along with
+ the information necessary to retrieve pretrained weights and additional data
+ to be used with the model.
+
+ Torchaudio library instantiates objects of this class, each of which represents
+ a different pretrained model. Client code should access pretrained models via these
+ instances.
+
+    This bundle can convert mel spectrograms to waveforms and vice versa. A typical use case would be a flow like
+ `text -> mel spectrogram -> waveform`, where one can use an external component, e.g. Tacotron2,
+ to generate mel spectrogram from text. Please see below for the code example.
+
+ Example: Transform synthetic mel spectrogram to audio.
+ >>> import torch
+ >>> import torchaudio
+ >>> # Since HiFiGAN bundle is in prototypes, it needs to be exported explicitly
+ >>> from torchaudio.prototype.pipelines import HIFIGAN_VOCODER_V3_LJSPEECH as bundle
+ >>>
+ >>> # Load the HiFiGAN bundle
+ >>> vocoder = bundle.get_vocoder()
+ Downloading: "https://download.pytorch.org/torchaudio/models/hifigan_vocoder_v3_ljspeech.pth"
+ 100%|████████████| 5.59M/5.59M [00:00<00:00, 18.7MB/s]
+ >>>
+ >>> # Generate synthetic mel spectrogram
+ >>> specgram = torch.sin(0.5 * torch.arange(start=0, end=100)).expand(bundle._vocoder_params["in_channels"], 100)
+ >>>
+ >>> # Transform mel spectrogram into audio
+ >>> waveform = vocoder(specgram)
+ >>> torchaudio.save('sample.wav', waveform, bundle.sample_rate)
+
+ Example: Usage together with Tacotron2, text to audio.
+ >>> import torch
+ >>> import torchaudio
+ >>> # Since HiFiGAN bundle is in prototypes, it needs to be exported explicitly
+ >>> from torchaudio.prototype.pipelines import HIFIGAN_VOCODER_V3_LJSPEECH as bundle_hifigan
+ >>>
+ >>> # Load Tacotron2 bundle
+        >>> bundle_tacotron2 = torchaudio.pipelines.TACOTRON2_WAVERNN_CHAR_LJSPEECH
+        >>> processor = bundle_tacotron2.get_text_processor()
+        >>> tacotron2 = bundle_tacotron2.get_tacotron2()
+ >>>
+ >>> # Use Tacotron2 to convert text to mel spectrogram
+ >>> text = "A quick brown fox jumped over a lazy dog"
+ >>> input, lengths = processor(text)
+ >>> specgram, lengths, _ = tacotron2.infer(input, lengths)
+ >>>
+ >>> # Load HiFiGAN bundle
+ >>> vocoder = bundle_hifigan.get_vocoder()
+ Downloading: "https://download.pytorch.org/torchaudio/models/hifigan_vocoder_v3_ljspeech.pth"
+ 100%|████████████| 5.59M/5.59M [00:03<00:00, 1.55MB/s]
+ >>>
+ >>> # Use HiFiGAN to convert mel spectrogram to audio
+ >>> waveform = vocoder(specgram).squeeze(0)
+ >>> torchaudio.save('sample.wav', waveform, bundle_hifigan.sample_rate)
+ """ # noqa: E501
+
+ _path: str
+ _vocoder_params: Dict[str, Any] # Vocoder parameters
+ _mel_params: Dict[str, Any] # Mel transformation parameters
+ _sample_rate: float
+
+ def _get_state_dict(self, dl_kwargs):
+ url = f"https://download.pytorch.org/torchaudio/models/{self._path}"
+ dl_kwargs = {} if dl_kwargs is None else dl_kwargs
+ state_dict = load_state_dict_from_url(url, **dl_kwargs)
+ return state_dict
+
+ def get_vocoder(self, *, dl_kwargs=None) -> HiFiGANVocoder:
+ """Construct the HiFiGAN Generator model, which can be used a vocoder, and load the pretrained weight.
+
+ The weight file is downloaded from the internet and cached with
+ :func:`torch.hub.load_state_dict_from_url`
+
+ Args:
+ dl_kwargs (dictionary of keyword arguments): Passed to :func:`torch.hub.load_state_dict_from_url`.
+
+ Returns:
+ Variation of :py:class:`~torchaudio.prototype.models.HiFiGANVocoder`.
+ """
+ model = hifigan_vocoder(**self._vocoder_params)
+ model.load_state_dict(self._get_state_dict(dl_kwargs))
+ model.eval()
+ return model
+
+ def get_mel_transform(self) -> Module:
+ """Construct an object which transforms waveforms into mel spectrograms."""
+ return _HiFiGANMelSpectrogram(
+ n_mels=self._vocoder_params["in_channels"],
+ sample_rate=self._sample_rate,
+ **self._mel_params,
+ )
+
+ @property
+ def sample_rate(self):
+ """Sample rate of the audio that the model is trained on.
+
+ :type: float
+ """
+ return self._sample_rate
+
+
+class _HiFiGANMelSpectrogram(torch.nn.Module):
+ """
+ Generate mel spectrogram in a way equivalent to the original HiFiGAN implementation:
+ https://github.com/jik876/hifi-gan/blob/4769534d45265d52a904b850da5a622601885777/meldataset.py#L49-L72
+
+    This class wraps around :py:class:`torchaudio.transforms.MelSpectrogram`, but performs extra steps to achieve
+ equivalence with the HiFiGAN implementation.
+
+ Args:
+ hop_size (int): Length of hop between STFT windows.
+ n_fft (int): Size of FFT, creates ``n_fft // 2 + 1`` bins.
+ win_length (int): Window size.
+ f_min (float or None): Minimum frequency.
+ f_max (float or None): Maximum frequency.
+ sample_rate (int): Sample rate of audio signal.
+ n_mels (int): Number of mel filterbanks.
+ """
+
+ def __init__(
+ self,
+ hop_size: int,
+ n_fft: int,
+ win_length: int,
+ f_min: Optional[float],
+ f_max: Optional[float],
+ sample_rate: float,
+ n_mels: int,
+ ):
+ super(_HiFiGANMelSpectrogram, self).__init__()
+ self.mel_transform = MelSpectrogram(
+ sample_rate=sample_rate,
+ n_fft=n_fft,
+ win_length=win_length,
+ hop_length=hop_size,
+ f_min=f_min,
+ f_max=f_max,
+ n_mels=n_mels,
+ normalized=False,
+ pad=0,
+ mel_scale="slaney",
+ norm="slaney",
+ center=False,
+ )
+ self.sample_rate = sample_rate
+ self.hop_size = hop_size
+ self.n_fft = n_fft
+ self.win_length = win_length
+ self.f_min = f_min
+ self.f_max = f_max
+ self.n_mels = n_mels
+ self.pad_size = int((n_fft - hop_size) / 2)
+
+ def forward(self, waveform: torch.Tensor) -> torch.Tensor:
+ """Generate mel spectrogram from a waveform. Should have same sample rate as ``self.sample_rate``.
+
+ Args:
+ waveform (Tensor): waveform of shape ``(batch_size, time_length)``.
+ Returns:
+ Tensor of shape ``(batch_size, n_mel, time_length)``
+ """
+ ref_waveform = F.pad(waveform.unsqueeze(1), (self.pad_size, self.pad_size), mode="reflect")
+ ref_waveform = ref_waveform.squeeze(1)
+
+ spectr = (self.mel_transform.spectrogram(ref_waveform) + 1e-9) ** 0.5
+ mel_spectrogram = self.mel_transform.mel_scale(spectr)
+ mel_spectrogram = torch.log(torch.clamp(mel_spectrogram, min=1e-5))
+ return mel_spectrogram
+
+
+HIFIGAN_VOCODER_V3_LJSPEECH = HiFiGANVocoderBundle(
+ "hifigan_vocoder_v3_ljspeech.pth",
+ _vocoder_params={
+ "upsample_rates": (8, 8, 4),
+ "upsample_kernel_sizes": (16, 16, 8),
+ "upsample_initial_channel": 256,
+ "resblock_kernel_sizes": (3, 5, 7),
+ "resblock_dilation_sizes": ((1, 2), (2, 6), (3, 12)),
+ "resblock_type": 2,
+ "in_channels": 80,
+ "lrelu_slope": 0.1,
+ },
+ _mel_params={
+ "hop_size": 256,
+ "n_fft": 1024,
+ "win_length": 1024,
+ "f_min": 0,
+ "f_max": 8000,
+ },
+ _sample_rate=22050,
+)
+HIFIGAN_VOCODER_V3_LJSPEECH.__doc__ = """HiFiGAN Vocoder pipeline, trained on *The LJ Speech Dataset*
+ :cite:`ljspeech17`.
+
+    This pipeline can be used with an external component which generates mel spectrograms from text, for example,
+ Tacotron2 - see examples in :py:class:`HiFiGANVocoderBundle`.
+ Although this works with the existing Tacotron2 bundles, for the best results one needs to retrain Tacotron2
+ using the same data preprocessing pipeline which was used for training HiFiGAN. In particular, the original
+ HiFiGAN implementation uses a custom method of generating mel spectrograms from waveforms, different from
+ :py:class:`torchaudio.transforms.MelSpectrogram`. We reimplemented this transform as
+ :py:meth:`HiFiGANVocoderBundle.get_mel_transform`, making sure it is equivalent to the original HiFiGAN code `here
+ `_.
+
+ The underlying vocoder is constructed by
+ :py:func:`torchaudio.prototype.models.hifigan_vocoder`. The weights are converted from the ones published
+ with the original paper :cite:`NEURIPS2020_c5d73680` under `MIT License
+    <https://github.com/jik876/hifi-gan/blob/master/LICENSE>`__. See links to
+    pre-trained models on `GitHub <https://github.com/jik876/hifi-gan#pretrained-model>`__.
+
+ Please refer to :py:class:`HiFiGANVocoderBundle` for usage instructions.
+ """
diff --git a/MLPY/Lib/site-packages/torchaudio/prototype/pipelines/rnnt_pipeline.py b/MLPY/Lib/site-packages/torchaudio/prototype/pipelines/rnnt_pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..20783ecdab5980252ac0f9490877b2de2e4f53a9
--- /dev/null
+++ b/MLPY/Lib/site-packages/torchaudio/prototype/pipelines/rnnt_pipeline.py
@@ -0,0 +1,58 @@
+from functools import partial
+
+from torchaudio.models import emformer_rnnt_base
+from torchaudio.pipelines import RNNTBundle
+
+
+EMFORMER_RNNT_BASE_MUSTC = RNNTBundle(
+ _rnnt_path="models/emformer_rnnt_base_mustc.pt",
+ _rnnt_factory_func=partial(emformer_rnnt_base, num_symbols=501),
+ _global_stats_path="pipeline-assets/global_stats_rnnt_mustc.json",
+ _sp_model_path="pipeline-assets/spm_bpe_500_mustc.model",
+ _right_padding=4,
+ _blank=500,
+ _sample_rate=16000,
+ _n_fft=400,
+ _n_mels=80,
+ _hop_length=160,
+ _segment_length=16,
+ _right_context_length=4,
+)
+EMFORMER_RNNT_BASE_MUSTC.__doc__ = """Pre-trained Emformer-RNNT-based ASR pipeline capable of performing both
+streaming and non-streaming inference.
+
+The underlying model is constructed by :py:func:`torchaudio.models.emformer_rnnt_base`
+and utilizes weights trained on *MuST-C release v2.0* :cite:`CATTONI2021101155` dataset
+using training script ``train.py``
+`here <https://github.com/pytorch/audio/tree/main/examples/asr/emformer_rnnt>`__
+with ``num_symbols=501``.
+
+Please refer to :py:class:`torchaudio.pipelines.RNNTBundle` for usage instructions.
+"""
+
+
+EMFORMER_RNNT_BASE_TEDLIUM3 = RNNTBundle(
+ _rnnt_path="models/emformer_rnnt_base_tedlium3.pt",
+ _rnnt_factory_func=partial(emformer_rnnt_base, num_symbols=501),
+ _global_stats_path="pipeline-assets/global_stats_rnnt_tedlium3.json",
+ _sp_model_path="pipeline-assets/spm_bpe_500_tedlium3.model",
+ _right_padding=4,
+ _blank=500,
+ _sample_rate=16000,
+ _n_fft=400,
+ _n_mels=80,
+ _hop_length=160,
+ _segment_length=16,
+ _right_context_length=4,
+)
+EMFORMER_RNNT_BASE_TEDLIUM3.__doc__ = """Pre-trained Emformer-RNNT-based ASR pipeline capable of performing both
+streaming and non-streaming inference.
+
+The underlying model is constructed by :py:func:`torchaudio.models.emformer_rnnt_base`
+and utilizes weights trained on *TED-LIUM Release 3* :cite:`rousseau2012tedlium` dataset
+using training script ``train.py``
+`here <https://github.com/pytorch/audio/tree/main/examples/asr/emformer_rnnt>`__
+with ``num_symbols=501``.
+
+Please refer to :py:class:`torchaudio.pipelines.RNNTBundle` for usage instructions.
+"""
diff --git a/MLPY/Lib/site-packages/torchaudio/prototype/transforms/__init__.py b/MLPY/Lib/site-packages/torchaudio/prototype/transforms/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..6242f3a4e7c0dec9a255ba97069de7ef52ddc957
--- /dev/null
+++ b/MLPY/Lib/site-packages/torchaudio/prototype/transforms/__init__.py
@@ -0,0 +1,9 @@
+from ._transforms import BarkScale, BarkSpectrogram, ChromaScale, ChromaSpectrogram, InverseBarkScale
+
+__all__ = [
+ "BarkScale",
+ "BarkSpectrogram",
+ "ChromaScale",
+ "ChromaSpectrogram",
+ "InverseBarkScale",
+]
diff --git a/MLPY/Lib/site-packages/torchaudio/prototype/transforms/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torchaudio/prototype/transforms/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..aadf37450e4425b8b552c931f32fd65abe57b853
Binary files /dev/null and b/MLPY/Lib/site-packages/torchaudio/prototype/transforms/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torchaudio/prototype/transforms/__pycache__/_transforms.cpython-39.pyc b/MLPY/Lib/site-packages/torchaudio/prototype/transforms/__pycache__/_transforms.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9e9acdbf37e41039fe1e61f34b5e026eacd736af
Binary files /dev/null and b/MLPY/Lib/site-packages/torchaudio/prototype/transforms/__pycache__/_transforms.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torchaudio/prototype/transforms/_transforms.py b/MLPY/Lib/site-packages/torchaudio/prototype/transforms/_transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..f0fa10824eb759f8b7e925455bdbfe2184ec7beb
--- /dev/null
+++ b/MLPY/Lib/site-packages/torchaudio/prototype/transforms/_transforms.py
@@ -0,0 +1,456 @@
+from typing import Callable, Optional
+
+import torch
+from torchaudio.prototype.functional import barkscale_fbanks, chroma_filterbank
+from torchaudio.transforms import Spectrogram
+
+
+class BarkScale(torch.nn.Module):
+ r"""Turn a normal STFT into a bark frequency STFT with triangular filter banks.
+
+ .. devices:: CPU CUDA
+
+ .. properties:: Autograd TorchScript
+
+ Args:
+ n_barks (int, optional): Number of bark filterbanks. (Default: ``128``)
+ sample_rate (int, optional): Sample rate of audio signal. (Default: ``16000``)
+ f_min (float, optional): Minimum frequency. (Default: ``0.``)
+ f_max (float or None, optional): Maximum frequency. (Default: ``sample_rate // 2``)
+ n_stft (int, optional): Number of bins in STFT. See ``n_fft`` in :class:`Spectrogram`. (Default: ``201``)
+ norm (str or None, optional): If ``"slaney"``, divide the triangular bark weights by the width of the bark band
+ (area normalization). (Default: ``None``)
+ bark_scale (str, optional): Scale to use: ``traunmuller``, ``schroeder`` or ``wang``. (Default: ``traunmuller``)
+
+ Example
+ >>> waveform, sample_rate = torchaudio.load("test.wav", normalize=True)
+ >>> spectrogram_transform = transforms.Spectrogram(n_fft=1024)
+ >>> spectrogram = spectrogram_transform(waveform)
+ >>> barkscale_transform = transforms.BarkScale(sample_rate=sample_rate, n_stft=1024 // 2 + 1)
+ >>> barkscale_spectrogram = barkscale_transform(spectrogram)
+
+ See also:
+ :py:func:`torchaudio.prototype.functional.barkscale_fbanks` - The function used to
+ generate the filter banks.
+ """
+ __constants__ = ["n_barks", "sample_rate", "f_min", "f_max"]
+
+ def __init__(
+ self,
+ n_barks: int = 128,
+ sample_rate: int = 16000,
+ f_min: float = 0.0,
+ f_max: Optional[float] = None,
+ n_stft: int = 201,
+ bark_scale: str = "traunmuller",
+ ) -> None:
+ super(BarkScale, self).__init__()
+ self.n_barks = n_barks
+ self.sample_rate = sample_rate
+ self.f_max = f_max if f_max is not None else float(sample_rate // 2)
+ self.f_min = f_min
+ self.bark_scale = bark_scale
+
+ if f_min > self.f_max:
+ raise ValueError("Require f_min: {} <= f_max: {}".format(f_min, self.f_max))
+
+ fb = barkscale_fbanks(n_stft, self.f_min, self.f_max, self.n_barks, self.sample_rate, self.bark_scale)
+ self.register_buffer("fb", fb)
+
+ def forward(self, specgram: torch.Tensor) -> torch.Tensor:
+ r"""
+ Args:
+ specgram (torch.Tensor): A spectrogram STFT of dimension (..., freq, time).
+
+ Returns:
+ torch.Tensor: Bark frequency spectrogram of size (..., ``n_barks``, time).
+ """
+
+        # (..., time, freq) dot (freq, n_barks) -> (..., n_barks, time)
+ bark_specgram = torch.matmul(specgram.transpose(-1, -2), self.fb).transpose(-1, -2)
+
+ return bark_specgram
+
+
+class InverseBarkScale(torch.nn.Module):
+ r"""Estimate a STFT in normal frequency domain from bark frequency domain.
+
+ .. devices:: CPU CUDA
+
+    It minimizes the euclidean norm between the input bark-spectrogram and the product between
+ the estimated spectrogram and the filter banks using SGD.
+
+ Args:
+ n_stft (int): Number of bins in STFT. See ``n_fft`` in :class:`Spectrogram`.
+ n_barks (int, optional): Number of bark filterbanks. (Default: ``128``)
+ sample_rate (int, optional): Sample rate of audio signal. (Default: ``16000``)
+ f_min (float, optional): Minimum frequency. (Default: ``0.``)
+ f_max (float or None, optional): Maximum frequency. (Default: ``sample_rate // 2``)
+ max_iter (int, optional): Maximum number of optimization iterations. (Default: ``100000``)
+ tolerance_loss (float, optional): Value of loss to stop optimization at. (Default: ``1e-5``)
+ tolerance_change (float, optional): Difference in losses to stop optimization at. (Default: ``1e-8``)
+ sgdargs (dict or None, optional): Arguments for the SGD optimizer. (Default: ``None``)
+ bark_scale (str, optional): Scale to use: ``traunmuller``, ``schroeder`` or ``wang``. (Default: ``traunmuller``)
+
+ Example
+ >>> waveform, sample_rate = torchaudio.load("test.wav", normalize=True)
+        >>> bark_spectrogram_transform = transforms.BarkSpectrogram(sample_rate, n_fft=1024)
+        >>> bark_spectrogram = bark_spectrogram_transform(waveform)
+        >>> inverse_barkscale_transform = transforms.InverseBarkScale(n_stft=1024 // 2 + 1)
+        >>> spectrogram = inverse_barkscale_transform(bark_spectrogram)
+ """
+ __constants__ = [
+ "n_stft",
+ "n_barks",
+ "sample_rate",
+ "f_min",
+ "f_max",
+ "max_iter",
+ "tolerance_loss",
+ "tolerance_change",
+ "sgdargs",
+ ]
+
+ def __init__(
+ self,
+ n_stft: int,
+ n_barks: int = 128,
+ sample_rate: int = 16000,
+ f_min: float = 0.0,
+ f_max: Optional[float] = None,
+ max_iter: int = 100000,
+ tolerance_loss: float = 1e-5,
+ tolerance_change: float = 1e-8,
+ sgdargs: Optional[dict] = None,
+ bark_scale: str = "traunmuller",
+ ) -> None:
+ super(InverseBarkScale, self).__init__()
+ self.n_barks = n_barks
+ self.sample_rate = sample_rate
+ self.f_max = f_max or float(sample_rate // 2)
+ self.f_min = f_min
+ self.max_iter = max_iter
+ self.tolerance_loss = tolerance_loss
+ self.tolerance_change = tolerance_change
+ self.sgdargs = sgdargs or {"lr": 0.1, "momentum": 0.9}
+
+ if f_min > self.f_max:
+ raise ValueError("Require f_min: {} <= f_max: {}".format(f_min, self.f_max))
+
+ fb = barkscale_fbanks(n_stft, self.f_min, self.f_max, self.n_barks, self.sample_rate, bark_scale)
+ self.register_buffer("fb", fb)
+
+ def forward(self, barkspec: torch.Tensor) -> torch.Tensor:
+ r"""
+ Args:
+ barkspec (torch.Tensor): A Bark frequency spectrogram of dimension (..., ``n_barks``, time)
+
+ Returns:
+ torch.Tensor: Linear scale spectrogram of size (..., freq, time)
+ """
+ # pack batch
+ shape = barkspec.size()
+ barkspec = barkspec.view(-1, shape[-2], shape[-1])
+
+ n_barks, time = shape[-2], shape[-1]
+        freq, _ = self.fb.size() # (freq, n_barks)
+ barkspec = barkspec.transpose(-1, -2)
+ if self.n_barks != n_barks:
+ raise ValueError("Expected an input with {} bark bins. Found: {}".format(self.n_barks, n_barks))
+
+ specgram = torch.rand(
+ barkspec.size()[0], time, freq, requires_grad=True, dtype=barkspec.dtype, device=barkspec.device
+ )
+
+ optim = torch.optim.SGD([specgram], **self.sgdargs)
+
+ loss = float("inf")
+ for _ in range(self.max_iter):
+ optim.zero_grad()
+ diff = barkspec - specgram.matmul(self.fb)
+ new_loss = diff.pow(2).sum(axis=-1).mean()
+ # take sum over bark-frequency then average over other dimensions
+            # so that the loss threshold is applied per unit time frame
+ new_loss.backward()
+ optim.step()
+ specgram.data = specgram.data.clamp(min=0)
+
+ new_loss = new_loss.item()
+ if new_loss < self.tolerance_loss or abs(loss - new_loss) < self.tolerance_change:
+ break
+ loss = new_loss
+
+ specgram.requires_grad_(False)
+ specgram = specgram.clamp(min=0).transpose(-1, -2)
+
+ # unpack batch
+ specgram = specgram.view(shape[:-2] + (freq, time))
+ return specgram
+
+
+class BarkSpectrogram(torch.nn.Module):
+ r"""Create BarkSpectrogram for a raw audio signal.
+
+ .. devices:: CPU CUDA
+
+ .. properties:: Autograd TorchScript
+
+    This is a composition of :py:func:`torchaudio.transforms.Spectrogram`
+    and :py:func:`torchaudio.prototype.transforms.BarkScale`.
+
+ Sources
+ * https://www.fon.hum.uva.nl/praat/manual/BarkSpectrogram.html
+    * Traunmüller, Hartmut. "Analytical Expressions for the Tonotopic Sensory Scale." Journal of the Acoustical
+      Society of America. Vol. 88, Issue 1, 1990, pp. 97–100.
+ * https://ccrma.stanford.edu/courses/120-fall-2003/lecture-5.html
+
+ Args:
+ sample_rate (int, optional): Sample rate of audio signal. (Default: ``16000``)
+ n_fft (int, optional): Size of FFT, creates ``n_fft // 2 + 1`` bins. (Default: ``400``)
+ win_length (int or None, optional): Window size. (Default: ``n_fft``)
+ hop_length (int or None, optional): Length of hop between STFT windows. (Default: ``win_length // 2``)
+ f_min (float, optional): Minimum frequency. (Default: ``0.``)
+ f_max (float or None, optional): Maximum frequency. (Default: ``None``)
+ pad (int, optional): Two sided padding of signal. (Default: ``0``)
+        n_barks (int, optional): Number of bark filterbanks. (Default: ``128``)
+ window_fn (Callable[..., torch.Tensor], optional): A function to create a window tensor
+ that is applied/multiplied to each frame/window. (Default: ``torch.hann_window``)
+ power (float, optional): Exponent for the magnitude spectrogram,
+ (must be > 0) e.g., 1 for energy, 2 for power, etc. (Default: ``2``)
+ normalized (bool, optional): Whether to normalize by magnitude after stft. (Default: ``False``)
+ wkwargs (Dict[..., ...] or None, optional): Arguments for window function. (Default: ``None``)
+ center (bool, optional): whether to pad :attr:`waveform` on both sides so
+ that the :math:`t`-th frame is centered at time :math:`t \times \text{hop\_length}`.
+ (Default: ``True``)
+ pad_mode (string, optional): controls the padding method used when
+ :attr:`center` is ``True``. (Default: ``"reflect"``)
+ bark_scale (str, optional): Scale to use: ``traunmuller``, ``schroeder`` or ``wang``. (Default: ``traunmuller``)
+
+ Example
+ >>> waveform, sample_rate = torchaudio.load("test.wav", normalize=True)
+ >>> transform = transforms.BarkSpectrogram(sample_rate)
+ >>> bark_specgram = transform(waveform) # (channel, n_barks, time)
+
+ See also:
+        :py:func:`torchaudio.prototype.functional.barkscale_fbanks` - The function used to
+ generate the filter banks.
+ """
+ __constants__ = ["sample_rate", "n_fft", "win_length", "hop_length", "pad", "n_barks", "f_min"]
+
+ def __init__(
+ self,
+ sample_rate: int = 16000,
+ n_fft: int = 400,
+ win_length: Optional[int] = None,
+ hop_length: Optional[int] = None,
+ f_min: float = 0.0,
+ f_max: Optional[float] = None,
+ pad: int = 0,
+ n_barks: int = 128,
+ window_fn: Callable[..., torch.Tensor] = torch.hann_window,
+ power: float = 2.0,
+ normalized: bool = False,
+ wkwargs: Optional[dict] = None,
+ center: bool = True,
+ pad_mode: str = "reflect",
+ bark_scale: str = "traunmuller",
+ ) -> None:
+ super(BarkSpectrogram, self).__init__()
+
+ self.sample_rate = sample_rate
+ self.n_fft = n_fft
+ self.win_length = win_length if win_length is not None else n_fft
+ self.hop_length = hop_length if hop_length is not None else self.win_length // 2
+ self.pad = pad
+ self.power = power
+ self.normalized = normalized
+ self.n_barks = n_barks # number of bark frequency bins
+ self.f_max = f_max
+ self.f_min = f_min
+ self.spectrogram = Spectrogram(
+ n_fft=self.n_fft,
+ win_length=self.win_length,
+ hop_length=self.hop_length,
+ pad=self.pad,
+ window_fn=window_fn,
+ power=self.power,
+ normalized=self.normalized,
+ wkwargs=wkwargs,
+ center=center,
+ pad_mode=pad_mode,
+ onesided=True,
+ )
+ self.bark_scale = BarkScale(
+ self.n_barks, self.sample_rate, self.f_min, self.f_max, self.n_fft // 2 + 1, bark_scale
+ )
+
+ def forward(self, waveform: torch.Tensor) -> torch.Tensor:
+ r"""
+ Args:
+ waveform (torch.Tensor): torch.Tensor of audio of dimension (..., time).
+
+ Returns:
+ torch.Tensor: Bark frequency spectrogram of size (..., ``n_barks``, time).
+ """
+ specgram = self.spectrogram(waveform)
+ bark_specgram = self.bark_scale(specgram)
+ return bark_specgram
+
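+ # Editor's sketch (not part of the upstream module): minimal BarkSpectrogram usage and the
+ # equivalent manual composition of Spectrogram and BarkScale. The parameter values
+ # (16 kHz input, n_barks=32) are illustrative assumptions.
+ #
+ #     >>> waveform = torch.randn(1, 16000)                        # hypothetical 1-second signal
+ #     >>> bark = BarkSpectrogram(sample_rate=16000, n_fft=400, n_barks=32)(waveform)
+ #     >>> # Equivalent composition, mirroring the internals above:
+ #     >>> spec = Spectrogram(n_fft=400)(waveform)                  # power spectrogram
+ #     >>> bark2 = BarkScale(32, 16000, 0.0, None, 400 // 2 + 1, "traunmuller")(spec)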
+
+class ChromaScale(torch.nn.Module):
+ r"""Converts spectrogram to chromagram.
+
+ .. devices:: CPU CUDA
+
+ .. properties:: Autograd
+
+ Args:
+ sample_rate (int): Sample rate of audio signal.
+ n_freqs (int): Number of frequency bins in STFT. See ``n_fft`` in :class:`Spectrogram`.
+ n_chroma (int, optional): Number of chroma. (Default: ``12``)
+ tuning (float, optional): Tuning deviation from A440 in fractions of a chroma bin. (Default: 0.0)
+ ctroct (float, optional): Center of Gaussian dominance window to weight filters by, in octaves. (Default: 5.0)
+ octwidth (float or None, optional): Width of Gaussian dominance window to weight filters by, in octaves.
+ If ``None``, then disable weighting altogether. (Default: 2.0)
+ norm (int, optional): order of norm to normalize filter bank by. (Default: 2)
+ base_c (bool, optional): If True, then start filter bank at C. Otherwise, start at A. (Default: True)
+
+ Example
+ >>> waveform, sample_rate = torchaudio.load("test.wav", normalize=True)
+ >>> spectrogram_transform = transforms.Spectrogram(n_fft=1024)
+ >>> spectrogram = spectrogram_transform(waveform)
+ >>> chroma_transform = transforms.ChromaScale(sample_rate=sample_rate, n_freqs=1024 // 2 + 1)
+ >>> chroma_spectrogram = chroma_transform(spectrogram)
+
+ See also:
+ :py:func:`torchaudio.prototype.functional.chroma_filterbank` — function used to
+ generate the filter bank.
+ """
+
+ def __init__(
+ self,
+ sample_rate: int,
+ n_freqs: int,
+ *,
+ n_chroma: int = 12,
+ tuning: float = 0.0,
+ ctroct: float = 5.0,
+ octwidth: Optional[float] = 2.0,
+ norm: int = 2,
+ base_c: bool = True,
+ ):
+ super().__init__()
+ fb = chroma_filterbank(
+ sample_rate, n_freqs, n_chroma, tuning=tuning, ctroct=ctroct, octwidth=octwidth, norm=norm, base_c=base_c
+ )
+ self.register_buffer("fb", fb)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ r"""
+ Args:
+ x (torch.Tensor): Spectrogram of dimension (..., ``n_freqs``, time).
+
+ Returns:
+ torch.Tensor: Chroma spectrogram of size (..., ``n_chroma``, time).
+ """
+ return torch.matmul(x.transpose(-1, -2), self.fb).transpose(-1, -2)
+
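+ # Editor's sketch (assumption, shapes inferred from the matmul in ``forward`` above): the
+ # filter bank ``fb`` has shape ``(n_freqs, n_chroma)``, so a spectrogram of shape
+ # ``(..., n_freqs, time)`` maps to a chromagram of shape ``(..., n_chroma, time)``.
+ #
+ #     >>> spec = Spectrogram(n_fft=1024)(torch.randn(1, 16000))    # (1, 513, time)
+ #     >>> chroma = ChromaScale(sample_rate=16000, n_freqs=1024 // 2 + 1)(spec)
+ #     >>> chroma.shape                                             # torch.Size([1, 12, time])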
+
+class ChromaSpectrogram(torch.nn.Module):
+ r"""Generates chromagram for audio signal.
+
+ .. devices:: CPU CUDA
+
+ .. properties:: Autograd
+
+ Composes :py:func:`torchaudio.transforms.Spectrogram`
+ and :py:func:`torchaudio.prototype.transforms.ChromaScale`.
+
+ Args:
+ sample_rate (int): Sample rate of audio signal.
+ n_fft (int): Size of FFT, creates ``n_fft // 2 + 1`` bins.
+ win_length (int or None, optional): Window size. (Default: ``n_fft``)
+ hop_length (int or None, optional): Length of hop between STFT windows. (Default: ``win_length // 2``)
+ pad (int, optional): Two sided padding of signal. (Default: ``0``)
+ window_fn (Callable[..., torch.Tensor], optional): A function to create a window tensor
+ that is applied/multiplied to each frame/window. (Default: ``torch.hann_window``)
+ power (float, optional): Exponent for the magnitude spectrogram,
+ (must be > 0) e.g., 1 for energy, 2 for power, etc. (Default: ``2``)
+ normalized (bool, optional): Whether to normalize by magnitude after stft. (Default: ``False``)
+ wkwargs (Dict[..., ...] or None, optional): Arguments for window function. (Default: ``None``)
+ center (bool, optional): whether to pad :attr:`waveform` on both sides so
+ that the :math:`t`-th frame is centered at time :math:`t \times \text{hop\_length}`.
+ (Default: ``True``)
+ pad_mode (string, optional): controls the padding method used when
+ :attr:`center` is ``True``. (Default: ``"reflect"``)
+ n_chroma (int, optional): Number of chroma. (Default: ``12``)
+ tuning (float, optional): Tuning deviation from A440 in fractions of a chroma bin. (Default: 0.0)
+ ctroct (float, optional): Center of Gaussian dominance window to weight filters by, in octaves. (Default: 5.0)
+ octwidth (float or None, optional): Width of Gaussian dominance window to weight filters by, in octaves.
+ If ``None``, then disable weighting altogether. (Default: 2.0)
+ norm (int, optional): order of norm to normalize filter bank by. (Default: 2)
+ base_c (bool, optional): If True, then start filter bank at C. Otherwise, start at A. (Default: True)
+
+ Example
+ >>> waveform, sample_rate = torchaudio.load("test.wav", normalize=True)
+ >>> transform = transforms.ChromaSpectrogram(sample_rate=sample_rate, n_fft=400)
+ >>> chromagram = transform(waveform) # (channel, n_chroma, time)
+ """
+
+ def __init__(
+ self,
+ sample_rate: int,
+ n_fft: int,
+ *,
+ win_length: Optional[int] = None,
+ hop_length: Optional[int] = None,
+ pad: int = 0,
+ window_fn: Callable[..., torch.Tensor] = torch.hann_window,
+ power: float = 2.0,
+ normalized: bool = False,
+ wkwargs: Optional[dict] = None,
+ center: bool = True,
+ pad_mode: str = "reflect",
+ n_chroma: int = 12,
+ tuning: float = 0.0,
+ ctroct: float = 5.0,
+ octwidth: Optional[float] = 2.0,
+ norm: int = 2,
+ base_c: bool = True,
+ ):
+ super().__init__()
+ self.spectrogram = Spectrogram(
+ n_fft=n_fft,
+ win_length=win_length,
+ hop_length=hop_length,
+ pad=pad,
+ window_fn=window_fn,
+ power=power,
+ normalized=normalized,
+ wkwargs=wkwargs,
+ center=center,
+ pad_mode=pad_mode,
+ onesided=True,
+ )
+ self.chroma_scale = ChromaScale(
+ sample_rate,
+ n_fft // 2 + 1,
+ n_chroma=n_chroma,
+ tuning=tuning,
+ base_c=base_c,
+ ctroct=ctroct,
+ octwidth=octwidth,
+ norm=norm,
+ )
+
+ def forward(self, waveform: torch.Tensor) -> torch.Tensor:
+ r"""
+ Args:
+ waveform (Tensor): Tensor of audio of dimension (..., time).
+
+ Returns:
+ Tensor: Chromagram of size (..., ``n_chroma``, time).
+ """
+ spectrogram = self.spectrogram(waveform)
+ chroma_spectrogram = self.chroma_scale(spectrogram)
+ return chroma_spectrogram
diff --git a/MLPY/Lib/site-packages/torchaudio/sox_effects/__init__.py b/MLPY/Lib/site-packages/torchaudio/sox_effects/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..1c65a49b277c25c36273b888c1ac2861cb3ce9a0
--- /dev/null
+++ b/MLPY/Lib/site-packages/torchaudio/sox_effects/__init__.py
@@ -0,0 +1,10 @@
+from .sox_effects import apply_effects_file, apply_effects_tensor, effect_names, init_sox_effects, shutdown_sox_effects
+
+
+__all__ = [
+ "init_sox_effects",
+ "shutdown_sox_effects",
+ "effect_names",
+ "apply_effects_tensor",
+ "apply_effects_file",
+]
diff --git a/MLPY/Lib/site-packages/torchaudio/sox_effects/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torchaudio/sox_effects/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e67975cd44829d7c43e4f8b635c2f2df1234da03
Binary files /dev/null and b/MLPY/Lib/site-packages/torchaudio/sox_effects/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torchaudio/sox_effects/__pycache__/sox_effects.cpython-39.pyc b/MLPY/Lib/site-packages/torchaudio/sox_effects/__pycache__/sox_effects.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a101c55d477fbb66b3dfbbc6f88ee382c3cd9a95
Binary files /dev/null and b/MLPY/Lib/site-packages/torchaudio/sox_effects/__pycache__/sox_effects.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torchaudio/sox_effects/sox_effects.py b/MLPY/Lib/site-packages/torchaudio/sox_effects/sox_effects.py
new file mode 100644
index 0000000000000000000000000000000000000000..b80e4bfad689925ba4a08503ebaeddcbe5cc6a5d
--- /dev/null
+++ b/MLPY/Lib/site-packages/torchaudio/sox_effects/sox_effects.py
@@ -0,0 +1,272 @@
+import os
+from typing import List, Optional, Tuple
+
+import torch
+import torchaudio
+from torchaudio._internal.module_utils import deprecated
+from torchaudio.utils.sox_utils import list_effects
+
+
+sox_ext = torchaudio._extension.lazy_import_sox_ext()
+
+
+@deprecated("Please remove the call. This function is called automatically.")
+def init_sox_effects():
+ """Initialize resources required to use sox effects.
+
+ Note:
+ You do not need to call this function manually. It is called automatically.
+
+ Once initialized, you do not need to call this function again across the multiple uses of
+ sox effects though it is safe to do so as long as :func:`shutdown_sox_effects` is not called yet.
+ Once :func:`shutdown_sox_effects` is called, you can no longer use SoX effects and initializing
+ again will result in an error.
+ """
+ pass
+
+
+@deprecated("Please remove the call. This function is called automatically.")
+def shutdown_sox_effects():
+ """Clean up resources required to use sox effects.
+
+ Note:
+ You do not need to call this function manually. It is called automatically.
+
+ It is safe to call this function multiple times.
+ Once :py:func:`shutdown_sox_effects` is called, you can no longer use SoX effects and
+ initializing again will result in an error.
+ """
+ pass
+
+
+def effect_names() -> List[str]:
+ """Gets list of valid sox effect names
+
+ Returns:
+ List[str]: list of available effect names.
+
+ Example
+ >>> torchaudio.sox_effects.effect_names()
+ ['allpass', 'band', 'bandpass', ... ]
+ """
+ return list(list_effects().keys())
+
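+ # Editor's sketch (assumption): ``effect_names`` can be used to guard against unavailable
+ # effects before building an effect chain; the effect and value below are illustrative.
+ #
+ #     >>> if "pitch" in effect_names():
+ #     ...     effects = [["pitch", "100"]]    # +100 cents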
+
+def apply_effects_tensor(
+ tensor: torch.Tensor,
+ sample_rate: int,
+ effects: List[List[str]],
+ channels_first: bool = True,
+) -> Tuple[torch.Tensor, int]:
+ """Apply sox effects to given Tensor
+
+ .. devices:: CPU
+
+ .. properties:: TorchScript
+
+ Note:
+ This function only works on CPU Tensors.
+ This function works in a way very similar to the ``sox`` command, but there are slight
+ differences. For example, the ``sox`` command adds certain effects automatically (such as the
+ ``rate`` effect after ``speed``, ``pitch`` and other effects), whereas this function
+ only applies the given effects. (Therefore, to actually apply the ``speed`` effect, you also
+ need to give the ``rate`` effect with the desired sampling rate.)
+
+ Args:
+ tensor (torch.Tensor): Input 2D CPU Tensor.
+ sample_rate (int): Sample rate
+ effects (List[List[str]]): List of effects.
+ channels_first (bool, optional): Indicates if the input Tensor's dimension is
+ `[channels, time]` or `[time, channels]`
+
+ Returns:
+ (Tensor, int): Resulting Tensor and sample rate.
+ The resulting Tensor has the same ``dtype`` as the input Tensor, and
+ the same channels order. The shape of the Tensor can be different based on the
+ effects applied. Sample rate can also be different based on the effects applied.
+
+ Example - Basic usage
+ >>>
+ >>> # Defines the effects to apply
+ >>> effects = [
+ ... ['gain', '-n'], # normalises to 0dB
+ ... ['pitch', '5'], # 5 cent pitch shift
+ ... ['rate', '8000'], # resample to 8000 Hz
+ ... ]
+ >>>
+ >>> # Generate pseudo wave:
+ >>> # normalized, channels first, 2ch, sampling rate 16000, 1 second
+ >>> sample_rate = 16000
+ >>> waveform = 2 * torch.rand([2, sample_rate * 1]) - 1
+ >>> waveform.shape
+ torch.Size([2, 16000])
+ >>> waveform
+ tensor([[ 0.3138, 0.7620, -0.9019, ..., -0.7495, -0.4935, 0.5442],
+ [-0.0832, 0.0061, 0.8233, ..., -0.5176, -0.9140, -0.2434]])
+ >>>
+ >>> # Apply effects
+ >>> waveform, sample_rate = apply_effects_tensor(
+ ... waveform, sample_rate, effects, channels_first=True)
+ >>>
+ >>> # Check the result
+ >>> # The new waveform is sampling rate 8000, 1 second.
+ >>> # normalization and channel order are preserved
+ >>> waveform.shape
+ torch.Size([2, 8000])
+ >>> waveform
+ tensor([[ 0.5054, -0.5518, -0.4800, ..., -0.0076, 0.0096, -0.0110],
+ [ 0.1331, 0.0436, -0.3783, ..., -0.0035, 0.0012, 0.0008]])
+ >>> sample_rate
+ 8000
+
+ Example - Torchscript-able transform
+ >>>
+ >>> # Use `apply_effects_tensor` in `torch.nn.Module` and dump it to file,
+ >>> # then run sox effect via Torchscript runtime.
+ >>>
+ >>> class SoxEffectTransform(torch.nn.Module):
+ ... effects: List[List[str]]
+ ...
+ ... def __init__(self, effects: List[List[str]]):
+ ... super().__init__()
+ ... self.effects = effects
+ ...
+ ... def forward(self, tensor: torch.Tensor, sample_rate: int):
+ ... return sox_effects.apply_effects_tensor(
+ ... tensor, sample_rate, self.effects)
+ ...
+ ...
+ >>> # Create transform object
+ >>> effects = [
+ ... ["lowpass", "-1", "300"], # apply single-pole lowpass filter
+ ... ["rate", "8000"], # change sample rate to 8000
+ ... ]
+ >>> transform = SoxEffectTransform(effects)
+ >>>
+ >>> # Dump it to file and load
+ >>> path = 'sox_effect.zip'
+ >>> torch.jit.script(transform).save(path)
+ >>> transform = torch.jit.load(path)
+ >>>
+ >>> # Run transform
+ >>> waveform, input_sample_rate = torchaudio.load("input.wav")
+ >>> waveform, sample_rate = transform(waveform, input_sample_rate)
+ >>> assert sample_rate == 8000
+ """
+ return sox_ext.apply_effects_tensor(tensor, sample_rate, effects, channels_first)
+
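+ # Editor's sketch (assumption): as the Note above says, ``speed`` must be paired with ``rate``
+ # to actually resample the signal; the effect values below are illustrative.
+ #
+ #     >>> waveform = 2 * torch.rand(2, 16000) - 1
+ #     >>> effects = [["speed", "1.1"], ["rate", "16000"]]
+ #     >>> out, sr = apply_effects_tensor(waveform, 16000, effects)
+ #     >>> sr
+ #     16000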
+
+def apply_effects_file(
+ path: str,
+ effects: List[List[str]],
+ normalize: bool = True,
+ channels_first: bool = True,
+ format: Optional[str] = None,
+) -> Tuple[torch.Tensor, int]:
+ """Apply sox effects to the audio file and load the resulting data as Tensor
+
+ .. devices:: CPU
+
+ .. properties:: TorchScript
+
+ Note:
+ This function works in a way very similar to the ``sox`` command, but there are slight
+ differences. For example, the ``sox`` command adds certain effects automatically (such as the
+ ``rate`` effect after ``speed``, ``pitch``, etc.), but this function only applies the given
+ effects. Therefore, to actually apply the ``speed`` effect, you also need to give the ``rate``
+ effect with the desired sampling rate, because internally the ``speed`` effect only alters the
+ sampling rate and leaves the samples untouched.
+
+ Args:
+ path (path-like object):
+ Source of audio data.
+ effects (List[List[str]]): List of effects.
+ normalize (bool, optional):
+ When ``True``, this function converts the native sample type to ``float32``.
+ Default: ``True``.
+
+ If the input file is an integer WAV file, giving ``False`` will change the resulting Tensor to
+ an integer type.
+ This argument has no effect for formats other than integer WAV.
+
+ channels_first (bool, optional): When True, the returned Tensor has dimension `[channel, time]`.
+ Otherwise, the returned Tensor's dimension is `[time, channel]`.
+ format (str or None, optional):
+ Override the format detection with the given format.
+ Providing the argument might help when libsox cannot infer the format
+ from the header or extension.
+
+ Returns:
+ (Tensor, int): Resulting Tensor and sample rate.
+ If ``normalize=True``, the resulting Tensor is always ``float32`` type.
+ If ``normalize=False`` and the input audio file is an integer WAV file, then the
+ resulting Tensor has corresponding integer type. (Note 24 bit integer type is not supported)
+ If ``channels_first=True``, the resulting Tensor has dimension `[channel, time]`,
+ otherwise `[time, channel]`.
+
+ Example - Basic usage
+ >>>
+ >>> # Defines the effects to apply
+ >>> effects = [
+ ... ['gain', '-n'], # normalises to 0dB
+ ... ['pitch', '5'], # 5 cent pitch shift
+ ... ['rate', '8000'], # resample to 8000 Hz
+ ... ]
+ >>>
+ >>> # Apply effects and load data with channels_first=True
+ >>> waveform, sample_rate = apply_effects_file("data.wav", effects, channels_first=True)
+ >>>
+ >>> # Check the result
+ >>> waveform.shape
+ torch.Size([2, 8000])
+ >>> waveform
+ tensor([[ 5.1151e-03, 1.8073e-02, 2.2188e-02, ..., 1.0431e-07,
+ -1.4761e-07, 1.8114e-07],
+ [-2.6924e-03, 2.1860e-03, 1.0650e-02, ..., 6.4122e-07,
+ -5.6159e-07, 4.8103e-07]])
+ >>> sample_rate
+ 8000
+
+ Example - Apply random speed perturbation to dataset
+ >>>
+ >>> # Load data from file, apply random speed perturbation
+ >>> class RandomPerturbationFile(torch.utils.data.Dataset):
+ ... \"\"\"Given flist, apply random speed perturbation
+ ...
+ ... Suppose all the input files are at least one second long.
+ ... \"\"\"
+ ... def __init__(self, flist: List[str], sample_rate: int):
+ ... super().__init__()
+ ... self.flist = flist
+ ... self.sample_rate = sample_rate
+ ...
+ ... def __getitem__(self, index):
+ ... speed = 0.5 + 1.5 * random.random()
+ ... effects = [
+ ... ['gain', '-n', '-10'], # apply 10 db attenuation
+ ... ['remix', '-'], # merge all the channels
+ ... ['speed', f'{speed:.5f}'], # duration is now 0.5 ~ 2.0 seconds.
+ ... ['rate', f'{self.sample_rate}'],
+ ... ['pad', '0', '1.5'], # add 1.5 seconds silence at the end
+ ... ['trim', '0', '2'], # get the first 2 seconds
+ ... ]
+ ... waveform, _ = torchaudio.sox_effects.apply_effects_file(
+ ... self.flist[index], effects)
+ ... return waveform
+ ...
+ ... def __len__(self):
+ ... return len(self.flist)
+ ...
+ >>> dataset = RandomPerturbationFile(file_list, sample_rate=8000)
+ >>> loader = torch.utils.data.DataLoader(dataset, batch_size=32)
+ >>> for batch in loader:
+ >>> pass
+ """
+ if not torch.jit.is_scripting():
+ if hasattr(path, "read"):
+ raise RuntimeError(
+ "apply_effects_file function does not support file-like object. "
+ "Please use torchaudio.io.AudioEffector."
+ )
+ path = os.fspath(path)
+ return sox_ext.apply_effects_file(path, effects, normalize, channels_first, format)
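+ # Editor's sketch (assumption, hypothetical path): with ``normalize=False`` the dtype follows the
+ # file's sample type, e.g. ``torch.int16`` for a 16-bit PCM WAV file.
+ #
+ #     >>> waveform, sr = apply_effects_file("speech_int16.wav", [["rate", "8000"]], normalize=False)
+ #     >>> waveform.dtype
+ #     torch.int16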
diff --git a/MLPY/Lib/site-packages/torchaudio/transforms/__init__.py b/MLPY/Lib/site-packages/torchaudio/transforms/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..fd84516a001e93db4085229354b57a64b5213f3d
--- /dev/null
+++ b/MLPY/Lib/site-packages/torchaudio/transforms/__init__.py
@@ -0,0 +1,75 @@
+from ._multi_channel import MVDR, PSD, RTFMVDR, SoudenMVDR
+from ._transforms import (
+ AddNoise,
+ AmplitudeToDB,
+ ComputeDeltas,
+ Convolve,
+ Deemphasis,
+ Fade,
+ FFTConvolve,
+ FrequencyMasking,
+ GriffinLim,
+ InverseMelScale,
+ InverseSpectrogram,
+ LFCC,
+ Loudness,
+ MelScale,
+ MelSpectrogram,
+ MFCC,
+ MuLawDecoding,
+ MuLawEncoding,
+ PitchShift,
+ Preemphasis,
+ Resample,
+ RNNTLoss,
+ SlidingWindowCmn,
+ SpecAugment,
+ SpectralCentroid,
+ Spectrogram,
+ Speed,
+ SpeedPerturbation,
+ TimeMasking,
+ TimeStretch,
+ Vad,
+ Vol,
+)
+
+
+__all__ = [
+ "AddNoise",
+ "AmplitudeToDB",
+ "ComputeDeltas",
+ "Convolve",
+ "Deemphasis",
+ "Fade",
+ "FFTConvolve",
+ "FrequencyMasking",
+ "GriffinLim",
+ "InverseMelScale",
+ "InverseSpectrogram",
+ "LFCC",
+ "Loudness",
+ "MFCC",
+ "MVDR",
+ "MelScale",
+ "MelSpectrogram",
+ "MuLawDecoding",
+ "MuLawEncoding",
+ "PSD",
+ "PitchShift",
+ "Preemphasis",
+ "RNNTLoss",
+ "RTFMVDR",
+ "Resample",
+ "SlidingWindowCmn",
+ "SoudenMVDR",
+ "SpecAugment",
+ "SpectralCentroid",
+ "Spectrogram",
+ "Speed",
+ "SpeedPerturbation",
+ "TimeMasking",
+ "TimeStretch",
+ "Vad",
+ "Vol",
+]
diff --git a/MLPY/Lib/site-packages/torchaudio/transforms/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torchaudio/transforms/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..062217144d5ae2152e7f678c8b54cbed79703450
Binary files /dev/null and b/MLPY/Lib/site-packages/torchaudio/transforms/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torchaudio/transforms/__pycache__/_multi_channel.cpython-39.pyc b/MLPY/Lib/site-packages/torchaudio/transforms/__pycache__/_multi_channel.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..131a3761e904680ce1ea18f42169bf684b1734fa
Binary files /dev/null and b/MLPY/Lib/site-packages/torchaudio/transforms/__pycache__/_multi_channel.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torchaudio/transforms/__pycache__/_transforms.cpython-39.pyc b/MLPY/Lib/site-packages/torchaudio/transforms/__pycache__/_transforms.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a88cc32a673c29eb50d7737fc14a524a4f14fd12
Binary files /dev/null and b/MLPY/Lib/site-packages/torchaudio/transforms/__pycache__/_transforms.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torchaudio/transforms/_multi_channel.py b/MLPY/Lib/site-packages/torchaudio/transforms/_multi_channel.py
new file mode 100644
index 0000000000000000000000000000000000000000..956ccd2ee1526e56f647872a07a5f55957ce2381
--- /dev/null
+++ b/MLPY/Lib/site-packages/torchaudio/transforms/_multi_channel.py
@@ -0,0 +1,467 @@
+# -*- coding: utf-8 -*-
+
+import warnings
+from typing import Optional, Union
+
+import torch
+from torch import Tensor
+from torchaudio import functional as F
+
+
+__all__ = []
+
+
+def _get_mvdr_vector(
+ psd_s: torch.Tensor,
+ psd_n: torch.Tensor,
+ reference_vector: torch.Tensor,
+ solution: str = "ref_channel",
+ diagonal_loading: bool = True,
+ diag_eps: float = 1e-7,
+ eps: float = 1e-8,
+) -> torch.Tensor:
+ r"""Compute the MVDR beamforming weights with ``solution`` argument.
+
+ Args:
+ psd_s (torch.Tensor): The complex-valued power spectral density (PSD) matrix of target speech.
+ Tensor with dimensions `(..., freq, channel, channel)`.
+ psd_n (torch.Tensor): The complex-valued power spectral density (PSD) matrix of noise.
+ Tensor with dimensions `(..., freq, channel, channel)`.
+ reference_vector (torch.Tensor): one-hot reference channel matrix.
+ solution (str, optional): Solution to compute the MVDR beamforming weights.
+ Options: [``ref_channel``, ``stv_evd``, ``stv_power``]. (Default: ``ref_channel``)
+ diagonal_loading (bool, optional): If ``True``, enables applying diagonal loading to ``psd_n``.
+ (Default: ``True``)
+ diag_eps (float, optional): The coefficient multiplied to the identity matrix for diagonal loading.
+ It is only effective when ``diagonal_loading`` is set to ``True``. (Default: ``1e-7``)
+ eps (float, optional): Value to add to the denominator in the beamforming weight formula.
+ (Default: ``1e-8``)
+
+ Returns:
+ torch.Tensor: the mvdr beamforming weight matrix
+ """
+ if solution == "ref_channel":
+ beamform_vector = F.mvdr_weights_souden(psd_s, psd_n, reference_vector, diagonal_loading, diag_eps, eps)
+ else:
+ if solution == "stv_evd":
+ stv = F.rtf_evd(psd_s)
+ else:
+ stv = F.rtf_power(psd_s, psd_n, reference_vector, diagonal_loading=diagonal_loading, diag_eps=diag_eps)
+ beamform_vector = F.mvdr_weights_rtf(stv, psd_n, reference_vector, diagonal_loading, diag_eps, eps)
+
+ return beamform_vector
+
+
+class PSD(torch.nn.Module):
+ r"""Compute cross-channel power spectral density (PSD) matrix.
+
+ .. devices:: CPU CUDA
+
+ .. properties:: Autograd TorchScript
+
+ Args:
+ multi_mask (bool, optional): If ``True``, only accepts multi-channel Time-Frequency masks. (Default: ``False``)
+ normalize (bool, optional): If ``True``, normalize the mask along the time dimension. (Default: ``True``)
+ eps (float, optional): Value to add to the denominator in mask normalization. (Default: ``1e-15``)
+ """
+
+ def __init__(self, multi_mask: bool = False, normalize: bool = True, eps: float = 1e-15):
+ super().__init__()
+ self.multi_mask = multi_mask
+ self.normalize = normalize
+ self.eps = eps
+
+ def forward(self, specgram: torch.Tensor, mask: Optional[torch.Tensor] = None):
+ """
+ Args:
+ specgram (torch.Tensor): Multi-channel complex-valued spectrum.
+ Tensor with dimensions `(..., channel, freq, time)`.
+ mask (torch.Tensor or None, optional): Time-Frequency mask for normalization.
+ Tensor with dimensions `(..., freq, time)` if multi_mask is ``False`` or
+ with dimensions `(..., channel, freq, time)` if multi_mask is ``True``.
+ (Default: ``None``)
+
+ Returns:
+ torch.Tensor: The complex-valued PSD matrix of the input spectrum.
+ Tensor with dimensions `(..., freq, channel, channel)`
+ """
+ if mask is not None:
+ if self.multi_mask:
+ # Averaging mask along channel dimension
+ mask = mask.mean(dim=-3) # (..., freq, time)
+ psd = F.psd(specgram, mask, self.normalize, self.eps)
+
+ return psd
+
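+ # Editor's sketch (assumption): shapes follow the docstring above -- a multi-channel complex
+ # spectrum ``(batch, channel, freq, time)`` and a mask ``(batch, freq, time)`` yield a PSD
+ # matrix of shape ``(batch, freq, channel, channel)``.
+ #
+ #     >>> specgram = torch.randn(2, 4, 257, 100, dtype=torch.cdouble)
+ #     >>> mask = torch.rand(2, 257, 100)
+ #     >>> psd = PSD()(specgram, mask)             # torch.Size([2, 257, 4, 4])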
+
+class MVDR(torch.nn.Module):
+ """Minimum Variance Distortionless Response (MVDR) module that performs MVDR beamforming with Time-Frequency masks.
+
+ .. devices:: CPU CUDA
+
+ .. properties:: Autograd TorchScript
+
+ Based on https://github.com/espnet/espnet/blob/master/espnet2/enh/layers/beamformer.py
+
+ We provide three solutions of MVDR beamforming. One is based on *reference channel selection*
+ :cite:`souden2009optimal` (``solution=ref_channel``).
+
+ .. math::
+ \\textbf{w}_{\\text{MVDR}}(f) =\
+ \\frac{{{\\bf{\\Phi}_{\\textbf{NN}}^{-1}}(f){\\bf{\\Phi}_{\\textbf{SS}}}}(f)}\
+ {\\text{Trace}({{{\\bf{\\Phi}_{\\textbf{NN}}^{-1}}(f) \\bf{\\Phi}_{\\textbf{SS}}}(f))}}\\bm{u}
+
+ where :math:`\\bf{\\Phi}_{\\textbf{SS}}` and :math:`\\bf{\\Phi}_{\\textbf{NN}}` are the covariance\
+ matrices of speech and noise, respectively. :math:`\\bf{u}` is a one-hot vector to determine the\
+ reference channel.
+
+ The other two solutions are based on the steering vector (``solution=stv_evd`` or ``solution=stv_power``).
+
+ .. math::
+ \\textbf{w}_{\\text{MVDR}}(f) =\
+ \\frac{{{\\bf{\\Phi}_{\\textbf{NN}}^{-1}}(f){\\bm{v}}(f)}}\
+ {{\\bm{v}^{\\mathsf{H}}}(f){\\bf{\\Phi}_{\\textbf{NN}}^{-1}}(f){\\bm{v}}(f)}
+
+ where :math:`\\bm{v}` is the acoustic transfer function or the steering vector.\
+ :math:`.^{\\mathsf{H}}` denotes the Hermitian Conjugate operation.
+
+ We apply either *eigenvalue decomposition*
+ :cite:`higuchi2016robust` or the *power method* :cite:`mises1929praktische` to get the
+ steering vector from the PSD matrix of speech.
+
+ After estimating the beamforming weight, the enhanced Short-time Fourier Transform (STFT) is obtained by
+
+ .. math::
+ \\hat{\\bf{S}} = {\\bf{w}^\\mathsf{H}}{\\bf{Y}}, {\\bf{w}} \\in \\mathbb{C}^{M \\times F}
+
+ where :math:`\\bf{Y}` and :math:`\\hat{\\bf{S}}` are the STFT of the multi-channel noisy speech and\
+ the single-channel enhanced speech, respectively.
+
+ For online streaming audio, we provide a *recursive method* :cite:`higuchi2017online` to update the
+ PSD matrices of speech and noise, respectively.
+
+ Args:
+ ref_channel (int, optional): Reference channel for beamforming. (Default: ``0``)
+ solution (str, optional): Solution to compute the MVDR beamforming weights.
+ Options: [``ref_channel``, ``stv_evd``, ``stv_power``]. (Default: ``ref_channel``)
+ multi_mask (bool, optional): If ``True``, only accepts multi-channel Time-Frequency masks. (Default: ``False``)
+ diag_loading (bool, optional): If ``True``, enables applying diagonal loading to the covariance matrix
+ of the noise. (Default: ``True``)
+ diag_eps (float, optional): The coefficient multiplied to the identity matrix for diagonal loading.
+ It is only effective when ``diagonal_loading`` is set to ``True``. (Default: ``1e-7``)
+ online (bool, optional): If ``True``, updates the MVDR beamforming weights based on
+ the previous covariance matrices. (Default: ``False``)
+
+ Note:
+ To improve the numerical stability, the input spectrogram will be converted to double precision
+ (``torch.complex128`` or ``torch.cdouble``) dtype for internal computation. The output spectrogram
+ is converted to the dtype of the input spectrogram to be compatible with other modules.
+
+ Note:
+ If you use ``stv_evd`` solution, the gradient of the same input may not be identical if the
+ eigenvalues of the PSD matrix are not distinct (i.e. some eigenvalues are close or identical).
+ """
+
+ def __init__(
+ self,
+ ref_channel: int = 0,
+ solution: str = "ref_channel",
+ multi_mask: bool = False,
+ diag_loading: bool = True,
+ diag_eps: float = 1e-7,
+ online: bool = False,
+ ):
+ super().__init__()
+ if solution not in [
+ "ref_channel",
+ "stv_evd",
+ "stv_power",
+ ]:
+ raise ValueError(
+ '`solution` must be one of ["ref_channel", "stv_evd", "stv_power"]. Given {}'.format(solution)
+ )
+ self.ref_channel = ref_channel
+ self.solution = solution
+ self.multi_mask = multi_mask
+ self.diag_loading = diag_loading
+ self.diag_eps = diag_eps
+ self.online = online
+ self.psd = PSD(multi_mask)
+
+ psd_s: torch.Tensor = torch.zeros(1)
+ psd_n: torch.Tensor = torch.zeros(1)
+ mask_sum_s: torch.Tensor = torch.zeros(1)
+ mask_sum_n: torch.Tensor = torch.zeros(1)
+ self.register_buffer("psd_s", psd_s)
+ self.register_buffer("psd_n", psd_n)
+ self.register_buffer("mask_sum_s", mask_sum_s)
+ self.register_buffer("mask_sum_n", mask_sum_n)
+
+ def _get_updated_mvdr_vector(
+ self,
+ psd_s: torch.Tensor,
+ psd_n: torch.Tensor,
+ mask_s: torch.Tensor,
+ mask_n: torch.Tensor,
+ reference_vector: torch.Tensor,
+ solution: str = "ref_channel",
+ diagonal_loading: bool = True,
+ diag_eps: float = 1e-7,
+ eps: float = 1e-8,
+ ) -> torch.Tensor:
+ r"""Recursively update the MVDR beamforming vector.
+
+ Args:
+ psd_s (torch.Tensor): The complex-valued power spectral density (PSD) matrix of target speech.
+ Tensor with dimensions `(..., freq, channel, channel)`.
+ psd_n (torch.Tensor): The complex-valued power spectral density (PSD) matrix of noise.
+ Tensor with dimensions `(..., freq, channel, channel)`.
+ mask_s (torch.Tensor): Time-Frequency mask of the target speech.
+ Tensor with dimensions `(..., freq, time)` if multi_mask is ``False``
+ or with dimensions `(..., channel, freq, time)` if multi_mask is ``True``.
+ mask_n (torch.Tensor or None, optional): Time-Frequency mask of the noise.
+ Tensor with dimensions `(..., freq, time)` if multi_mask is ``False``
+ or with dimensions `(..., channel, freq, time)` if multi_mask is ``True``.
+ reference_vector (torch.Tensor): One-hot reference channel matrix.
+ solution (str, optional): Solution to compute the MVDR beamforming weights.
+ Options: [``ref_channel``, ``stv_evd``, ``stv_power``]. (Default: ``ref_channel``)
+ diagonal_loading (bool, optional): If ``True``, enables applying diagonal loading to ``psd_n``.
+ (Default: ``True``)
+ diag_eps (float, optional): The coefficient multiplied to the identity matrix for diagonal loading.
+ It is only effective when ``diagonal_loading`` is set to ``True``. (Default: ``1e-7``)
+ eps (float, optional): Value to add to the denominator in the beamforming weight formula.
+ (Default: ``1e-8``)
+
+ Returns:
+ torch.Tensor: The MVDR beamforming weight matrix.
+ """
+ if self.multi_mask:
+ # Averaging mask along channel dimension
+ mask_s = mask_s.mean(dim=-3) # (..., freq, time)
+ mask_n = mask_n.mean(dim=-3) # (..., freq, time)
+ if self.psd_s.ndim == 1:
+ self.psd_s = psd_s
+ self.psd_n = psd_n
+ self.mask_sum_s = mask_s.sum(dim=-1)
+ self.mask_sum_n = mask_n.sum(dim=-1)
+ return _get_mvdr_vector(psd_s, psd_n, reference_vector, solution, diagonal_loading, diag_eps, eps)
+ else:
+ psd_s = self._get_updated_psd_speech(psd_s, mask_s)
+ psd_n = self._get_updated_psd_noise(psd_n, mask_n)
+ self.psd_s = psd_s
+ self.psd_n = psd_n
+ self.mask_sum_s = self.mask_sum_s + mask_s.sum(dim=-1)
+ self.mask_sum_n = self.mask_sum_n + mask_n.sum(dim=-1)
+ return _get_mvdr_vector(psd_s, psd_n, reference_vector, solution, diagonal_loading, diag_eps, eps)
+
+ def _get_updated_psd_speech(self, psd_s: torch.Tensor, mask_s: torch.Tensor) -> torch.Tensor:
+ r"""Update psd of speech recursively.
+
+ Args:
+ psd_s (torch.Tensor): The complex-valued power spectral density (PSD) matrix of target speech.
+ Tensor with dimensions `(..., freq, channel, channel)`.
+ mask_s (torch.Tensor): Time-Frequency mask of the target speech.
+ Tensor with dimensions `(..., freq, time)`.
+
+ Returns:
+ torch.Tensor: The updated PSD matrix of target speech.
+ """
+ numerator = self.mask_sum_s / (self.mask_sum_s + mask_s.sum(dim=-1))
+ denominator = 1 / (self.mask_sum_s + mask_s.sum(dim=-1))
+ psd_s = self.psd_s * numerator[..., None, None] + psd_s * denominator[..., None, None]
+ return psd_s
+
+ def _get_updated_psd_noise(self, psd_n: torch.Tensor, mask_n: torch.Tensor) -> torch.Tensor:
+ r"""Update psd of noise recursively.
+
+ Args:
+ psd_n (torch.Tensor): The complex-valued power spectral density (PSD) matrix of noise.
+ Tensor with dimensions `(..., freq, channel, channel)`.
+ mask_n (torch.Tensor or None, optional): Time-Frequency mask of the noise.
+ Tensor with dimensions `(..., freq, time)`.
+
+ Returns:
+ torch.Tensor: The updated PSD matrix of noise.
+ """
+ numerator = self.mask_sum_n / (self.mask_sum_n + mask_n.sum(dim=-1))
+ denominator = 1 / (self.mask_sum_n + mask_n.sum(dim=-1))
+ psd_n = self.psd_n * numerator[..., None, None] + psd_n * denominator[..., None, None]
+ return psd_n
+
+ def forward(
+ self, specgram: torch.Tensor, mask_s: torch.Tensor, mask_n: Optional[torch.Tensor] = None
+ ) -> torch.Tensor:
+ """Perform MVDR beamforming.
+
+ Args:
+ specgram (torch.Tensor): Multi-channel complex-valued spectrum.
+ Tensor with dimensions `(..., channel, freq, time)`
+ mask_s (torch.Tensor): Time-Frequency mask of target speech.
+ Tensor with dimensions `(..., freq, time)` if multi_mask is ``False``
+ or with dimensions `(..., channel, freq, time)` if multi_mask is ``True``.
+ mask_n (torch.Tensor or None, optional): Time-Frequency mask of noise.
+ Tensor with dimensions `(..., freq, time)` if multi_mask is ``False``
+ or with dimensions `(..., channel, freq, time)` if multi_mask is ``True``.
+ (Default: None)
+
+ Returns:
+ torch.Tensor: Single-channel complex-valued enhanced spectrum with dimensions `(..., freq, time)`.
+ """
+ dtype = specgram.dtype
+ if specgram.ndim < 3:
+ raise ValueError(f"Expected at least 3D tensor (..., channel, freq, time). Found: {specgram.shape}")
+ if not specgram.is_complex():
+ raise ValueError(
+ f"The type of ``specgram`` tensor must be ``torch.cfloat`` or ``torch.cdouble``.\
+ Found: {specgram.dtype}"
+ )
+ if specgram.dtype == torch.cfloat:
+ specgram = specgram.cdouble() # Convert specgram to ``torch.cdouble``.
+
+ if mask_n is None:
+ warnings.warn("``mask_n`` is not provided, use ``1 - mask_s`` as ``mask_n``.")
+ mask_n = 1 - mask_s
+
+ psd_s = self.psd(specgram, mask_s) # (..., freq, channel, channel)
+ psd_n = self.psd(specgram, mask_n) # (..., freq, channel, channel)
+
+ u = torch.zeros(specgram.size()[:-2], device=specgram.device, dtype=torch.cdouble) # (..., channel)
+ u[..., self.ref_channel].fill_(1)
+
+ if self.online:
+ w_mvdr = self._get_updated_mvdr_vector(
+ psd_s, psd_n, mask_s, mask_n, u, self.solution, self.diag_loading, self.diag_eps
+ )
+ else:
+ w_mvdr = _get_mvdr_vector(psd_s, psd_n, u, self.solution, self.diag_loading, self.diag_eps)
+
+ specgram_enhanced = F.apply_beamforming(w_mvdr, specgram)
+
+ return specgram_enhanced.to(dtype)
+
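+ # Editor's sketch (assumption): end-to-end MVDR beamforming with a speech mask; per the
+ # ``forward`` above, ``mask_n`` defaults to ``1 - mask_s`` when omitted.
+ #
+ #     >>> specgram = torch.randn(2, 4, 257, 100, dtype=torch.cdouble)  # (..., channel, freq, time)
+ #     >>> mask_s = torch.rand(2, 257, 100)
+ #     >>> enhanced = MVDR(ref_channel=0)(specgram, mask_s)             # (2, 257, 100)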
+
+class RTFMVDR(torch.nn.Module):
+ r"""Minimum Variance Distortionless Response (*MVDR* :cite:`capon1969high`) module
+ based on the relative transfer function (RTF) and power spectral density (PSD) matrix of noise.
+
+ .. devices:: CPU CUDA
+
+ .. properties:: Autograd TorchScript
+
+ Given the multi-channel complex-valued spectrum :math:`\textbf{Y}`, the relative transfer function (RTF) matrix
+ or the steering vector of target speech :math:`\bm{v}`, the PSD matrix of noise :math:`\bf{\Phi}_{\textbf{NN}}`, and
+ a one-hot vector that represents the reference channel :math:`\bf{u}`, the module computes the single-channel
+ complex-valued spectrum of the enhanced speech :math:`\hat{\textbf{S}}`. The formula is defined as:
+
+ .. math::
+ \hat{\textbf{S}}(f) = \textbf{w}_{\text{bf}}(f)^{\mathsf{H}} \textbf{Y}(f)
+
+ where :math:`\textbf{w}_{\text{bf}}(f)` is the MVDR beamforming weight for the :math:`f`-th frequency bin,
+ :math:`(.)^{\mathsf{H}}` denotes the Hermitian Conjugate operation.
+
+ The beamforming weight is computed by:
+
+ .. math::
+ \textbf{w}_{\text{MVDR}}(f) =
+ \frac{{{\bf{\Phi}_{\textbf{NN}}^{-1}}(f){\bm{v}}(f)}}
+ {{\bm{v}^{\mathsf{H}}}(f){\bf{\Phi}_{\textbf{NN}}^{-1}}(f){\bm{v}}(f)}
+ """
+
+ def forward(
+ self,
+ specgram: Tensor,
+ rtf: Tensor,
+ psd_n: Tensor,
+ reference_channel: Union[int, Tensor],
+ diagonal_loading: bool = True,
+ diag_eps: float = 1e-7,
+ eps: float = 1e-8,
+ ) -> Tensor:
+ """
+ Args:
+ specgram (torch.Tensor): Multi-channel complex-valued spectrum.
+ Tensor with dimensions `(..., channel, freq, time)`
+ rtf (torch.Tensor): The complex-valued RTF vector of target speech.
+ Tensor with dimensions `(..., freq, channel)`.
+ psd_n (torch.Tensor): The complex-valued power spectral density (PSD) matrix of noise.
+ Tensor with dimensions `(..., freq, channel, channel)`.
+ reference_channel (int or torch.Tensor): Specifies the reference channel.
+ If the dtype is ``int``, it represents the reference channel index.
+ If the dtype is ``torch.Tensor``, its shape is `(..., channel)`, where the ``channel`` dimension
+ is one-hot.
+ diagonal_loading (bool, optional): If ``True``, enables applying diagonal loading to ``psd_n``.
+ (Default: ``True``)
+ diag_eps (float, optional): The coefficient multiplied to the identity matrix for diagonal loading.
+ It is only effective when ``diagonal_loading`` is set to ``True``. (Default: ``1e-7``)
+ eps (float, optional): Value to add to the denominator in the beamforming weight formula.
+ (Default: ``1e-8``)
+
+ Returns:
+ torch.Tensor: Single-channel complex-valued enhanced spectrum with dimensions `(..., freq, time)`.
+ """
+ w_mvdr = F.mvdr_weights_rtf(rtf, psd_n, reference_channel, diagonal_loading, diag_eps, eps)
+ spectrum_enhanced = F.apply_beamforming(w_mvdr, specgram)
+ return spectrum_enhanced
+
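+ # Editor's sketch (assumption): one way to obtain the RTF/steering vector is eigenvalue
+ # decomposition of the speech PSD via ``F.rtf_evd``, as used in ``_get_mvdr_vector`` above.
+ #
+ #     >>> specgram = torch.randn(2, 4, 257, 100, dtype=torch.cdouble)
+ #     >>> mask_s = torch.rand(2, 257, 100)
+ #     >>> psd = PSD()
+ #     >>> psd_s, psd_n = psd(specgram, mask_s), psd(specgram, 1 - mask_s)
+ #     >>> rtf = F.rtf_evd(psd_s)                               # (..., freq, channel)
+ #     >>> enhanced = RTFMVDR()(specgram, rtf, psd_n, reference_channel=0)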
+
+class SoudenMVDR(torch.nn.Module):
+ r"""Minimum Variance Distortionless Response (*MVDR* :cite:`capon1969high`) module
+ based on the method proposed by *Souden et al.* :cite:`souden2009optimal`.
+
+ .. devices:: CPU CUDA
+
+ .. properties:: Autograd TorchScript
+
+ Given the multi-channel complex-valued spectrum :math:`\textbf{Y}`, the power spectral density (PSD) matrix
+ of target speech :math:`\bf{\Phi}_{\textbf{SS}}`, the PSD matrix of noise :math:`\bf{\Phi}_{\textbf{NN}}`, and
+ a one-hot vector that represents the reference channel :math:`\bf{u}`, the module computes the single-channel
+ complex-valued spectrum of the enhanced speech :math:`\hat{\textbf{S}}`. The formula is defined as:
+
+ .. math::
+ \hat{\textbf{S}}(f) = \textbf{w}_{\text{bf}}(f)^{\mathsf{H}} \textbf{Y}(f)
+
+ where :math:`\textbf{w}_{\text{bf}}(f)` is the MVDR beamforming weight for the :math:`f`-th frequency bin.
+
+ The beamforming weight is computed by:
+
+ .. math::
+ \textbf{w}_{\text{MVDR}}(f) =
+ \frac{{{\bf{\Phi}_{\textbf{NN}}^{-1}}(f){\bf{\Phi}_{\textbf{SS}}}}(f)}
+ {\text{Trace}({{{\bf{\Phi}_{\textbf{NN}}^{-1}}(f) \bf{\Phi}_{\textbf{SS}}}(f))}}\bm{u}
+ """
+
+ def forward(
+ self,
+ specgram: Tensor,
+ psd_s: Tensor,
+ psd_n: Tensor,
+ reference_channel: Union[int, Tensor],
+ diagonal_loading: bool = True,
+ diag_eps: float = 1e-7,
+ eps: float = 1e-8,
+ ) -> torch.Tensor:
+ """
+ Args:
+ specgram (torch.Tensor): Multi-channel complex-valued spectrum.
+ Tensor with dimensions `(..., channel, freq, time)`.
+ psd_s (torch.Tensor): The complex-valued power spectral density (PSD) matrix of target speech.
+ Tensor with dimensions `(..., freq, channel, channel)`.
+ psd_n (torch.Tensor): The complex-valued power spectral density (PSD) matrix of noise.
+ Tensor with dimensions `(..., freq, channel, channel)`.
+ reference_channel (int or torch.Tensor): Specifies the reference channel.
+ If the dtype is ``int``, it represents the reference channel index.
+ If the dtype is ``torch.Tensor``, its shape is `(..., channel)`, where the ``channel`` dimension
+ is one-hot.
+ diagonal_loading (bool, optional): If ``True``, enables applying diagonal loading to ``psd_n``.
+ (Default: ``True``)
+ diag_eps (float, optional): The coefficient multiplied to the identity matrix for diagonal loading.
+ It is only effective when ``diagonal_loading`` is set to ``True``. (Default: ``1e-7``)
+ eps (float, optional): Value to add to the denominator in the beamforming weight formula.
+ (Default: ``1e-8``)
+
+ Returns:
+ torch.Tensor: Single-channel complex-valued enhanced spectrum with dimensions `(..., freq, time)`.
+ """
+ w_mvdr = F.mvdr_weights_souden(psd_s, psd_n, reference_channel, diagonal_loading, diag_eps, eps)
+ spectrum_enhanced = F.apply_beamforming(w_mvdr, specgram)
+ return spectrum_enhanced
diff --git a/MLPY/Lib/site-packages/torchaudio/transforms/_transforms.py b/MLPY/Lib/site-packages/torchaudio/transforms/_transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..242fd971a8d1efabb485a4cdbda2a8f1dbf59f02
--- /dev/null
+++ b/MLPY/Lib/site-packages/torchaudio/transforms/_transforms.py
@@ -0,0 +1,2137 @@
+# -*- coding: utf-8 -*-
+
+import math
+import warnings
+from typing import Callable, Optional, Sequence, Tuple, Union
+
+import torch
+from torch import Tensor
+from torch.nn.modules.lazy import LazyModuleMixin
+from torch.nn.parameter import UninitializedParameter
+
+from torchaudio import functional as F
+from torchaudio.functional.functional import (
+ _apply_sinc_resample_kernel,
+ _check_convolve_mode,
+ _fix_waveform_shape,
+ _get_sinc_resample_kernel,
+ _stretch_waveform,
+)
+
+__all__ = []
+
+
+class Spectrogram(torch.nn.Module):
+ r"""Create a spectrogram from a audio signal.
+
+ .. devices:: CPU CUDA
+
+ .. properties:: Autograd TorchScript
+
+ Args:
+ n_fft (int, optional): Size of FFT, creates ``n_fft // 2 + 1`` bins. (Default: ``400``)
+ win_length (int or None, optional): Window size. (Default: ``n_fft``)
+ hop_length (int or None, optional): Length of hop between STFT windows. (Default: ``win_length // 2``)
+ pad (int, optional): Two sided padding of signal. (Default: ``0``)
+ window_fn (Callable[..., Tensor], optional): A function to create a window tensor
+ that is applied/multiplied to each frame/window. (Default: ``torch.hann_window``)
+ power (float or None, optional): Exponent for the magnitude spectrogram,
+ (must be > 0) e.g., 1 for magnitude, 2 for power, etc.
+ If None, then the complex spectrum is returned instead. (Default: ``2``)
+ normalized (bool or str, optional): Whether to normalize by magnitude after stft. If input is str, choices are
+ ``"window"`` and ``"frame_length"``, if specific normalization type is desirable. ``True`` maps to
+ ``"window"``. (Default: ``False``)
+ wkwargs (dict or None, optional): Arguments for window function. (Default: ``None``)
+ center (bool, optional): whether to pad :attr:`waveform` on both sides so
+ that the :math:`t`-th frame is centered at time :math:`t \times \text{hop\_length}`.
+ (Default: ``True``)
+ pad_mode (string, optional): controls the padding method used when
+ :attr:`center` is ``True``. (Default: ``"reflect"``)
+ onesided (bool, optional): controls whether to return half of results to
+ avoid redundancy (Default: ``True``)
+ return_complex (bool, optional):
+ Deprecated and not used.
+
+ Example
+ >>> waveform, sample_rate = torchaudio.load("test.wav", normalize=True)
+ >>> transform = torchaudio.transforms.Spectrogram(n_fft=800)
+ >>> spectrogram = transform(waveform)
+
+ """
+ __constants__ = ["n_fft", "win_length", "hop_length", "pad", "power", "normalized"]
+
+ def __init__(
+ self,
+ n_fft: int = 400,
+ win_length: Optional[int] = None,
+ hop_length: Optional[int] = None,
+ pad: int = 0,
+ window_fn: Callable[..., Tensor] = torch.hann_window,
+ power: Optional[float] = 2.0,
+ normalized: Union[bool, str] = False,
+ wkwargs: Optional[dict] = None,
+ center: bool = True,
+ pad_mode: str = "reflect",
+ onesided: bool = True,
+ return_complex: Optional[bool] = None,
+ ) -> None:
+ super(Spectrogram, self).__init__()
+ torch._C._log_api_usage_once("torchaudio.transforms.Spectrogram")
+ self.n_fft = n_fft
+ # number of FFT bins. the returned STFT result will have n_fft // 2 + 1
+ # number of frequencies due to onesided=True in torch.stft
+ self.win_length = win_length if win_length is not None else n_fft
+ self.hop_length = hop_length if hop_length is not None else self.win_length // 2
+ window = window_fn(self.win_length) if wkwargs is None else window_fn(self.win_length, **wkwargs)
+ self.register_buffer("window", window)
+ self.pad = pad
+ self.power = power
+ self.normalized = normalized
+ self.center = center
+ self.pad_mode = pad_mode
+ self.onesided = onesided
+ if return_complex is not None:
+ warnings.warn(
+ "`return_complex` argument is now deprecated and is not effective."
+ "`torchaudio.transforms.Spectrogram(power=None)` always returns a tensor with "
+ "complex dtype. Please remove the argument in the function call."
+ )
+
+ def forward(self, waveform: Tensor) -> Tensor:
+ r"""
+ Args:
+ waveform (Tensor): Tensor of audio of dimension (..., time).
+
+ Returns:
+ Tensor: Dimension (..., freq, time), where freq is
+ ``n_fft // 2 + 1`` where ``n_fft`` is the number of
+ Fourier bins, and time is the number of window hops (n_frame).
+ """
+ return F.spectrogram(
+ waveform,
+ self.pad,
+ self.window,
+ self.n_fft,
+ self.hop_length,
+ self.win_length,
+ self.power,
+ self.normalized,
+ self.center,
+ self.pad_mode,
+ self.onesided,
+ )
+
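+ # Editor's sketch (assumption): with ``power=None`` the transform returns the complex spectrum,
+ # which can be inverted by ``InverseSpectrogram`` (defined below) to recover the waveform.
+ #
+ #     >>> waveform = torch.randn(1, 16000)
+ #     >>> spec_c = Spectrogram(n_fft=512, power=None)(waveform)           # complex dtype
+ #     >>> recon = InverseSpectrogram(n_fft=512)(spec_c, length=waveform.size(-1))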
+
+class InverseSpectrogram(torch.nn.Module):
+ r"""Create an inverse spectrogram to recover an audio signal from a spectrogram.
+
+ .. devices:: CPU CUDA
+
+ .. properties:: Autograd TorchScript
+
+ Args:
+ n_fft (int, optional): Size of FFT, creates ``n_fft // 2 + 1`` bins. (Default: ``400``)
+ win_length (int or None, optional): Window size. (Default: ``n_fft``)
+ hop_length (int or None, optional): Length of hop between STFT windows. (Default: ``win_length // 2``)
+ pad (int, optional): Two sided padding of signal. (Default: ``0``)
+ window_fn (Callable[..., Tensor], optional): A function to create a window tensor
+ that is applied/multiplied to each frame/window. (Default: ``torch.hann_window``)
+ normalized (bool or str, optional): Whether the stft output was normalized by magnitude. If input is str,
+ choices are ``"window"`` and ``"frame_length"``, dependent on normalization mode. ``True`` maps to
+ ``"window"``. (Default: ``False``)
+ wkwargs (dict or None, optional): Arguments for window function. (Default: ``None``)
+ center (bool, optional): whether the signal in spectrogram was padded on both sides so
+ that the :math:`t`-th frame is centered at time :math:`t \times \text{hop\_length}`.
+ (Default: ``True``)
+ pad_mode (string, optional): controls the padding method used when
+ :attr:`center` is ``True``. (Default: ``"reflect"``)
+ onesided (bool, optional): controls whether spectrogram was used to return half of results to
+ avoid redundancy (Default: ``True``)
+
+ Example
+ >>> batch, freq, time = 2, 257, 100
+ >>> length = 25344
+ >>> spectrogram = torch.randn(batch, freq, time, dtype=torch.cdouble)
+ >>> transform = transforms.InverseSpectrogram(n_fft=512)
+ >>> waveform = transform(spectrogram, length)
+ """
+ __constants__ = ["n_fft", "win_length", "hop_length", "pad", "power", "normalized"]
+
+ def __init__(
+ self,
+ n_fft: int = 400,
+ win_length: Optional[int] = None,
+ hop_length: Optional[int] = None,
+ pad: int = 0,
+ window_fn: Callable[..., Tensor] = torch.hann_window,
+ normalized: Union[bool, str] = False,
+ wkwargs: Optional[dict] = None,
+ center: bool = True,
+ pad_mode: str = "reflect",
+ onesided: bool = True,
+ ) -> None:
+ super(InverseSpectrogram, self).__init__()
+ self.n_fft = n_fft
+ # number of FFT bins. the returned STFT result will have n_fft // 2 + 1
+ # number of frequencies due to onesided=True in torch.stft
+ self.win_length = win_length if win_length is not None else n_fft
+ self.hop_length = hop_length if hop_length is not None else self.win_length // 2
+ window = window_fn(self.win_length) if wkwargs is None else window_fn(self.win_length, **wkwargs)
+ self.register_buffer("window", window)
+ self.pad = pad
+ self.normalized = normalized
+ self.center = center
+ self.pad_mode = pad_mode
+ self.onesided = onesided
+
+ def forward(self, spectrogram: Tensor, length: Optional[int] = None) -> Tensor:
+ r"""
+ Args:
+ spectrogram (Tensor): Complex tensor of audio of dimension (..., freq, time).
+ length (int or None, optional): The output length of the waveform.
+
+ Returns:
+ Tensor: Dimension (..., time), Least squares estimation of the original signal.
+ """
+ return F.inverse_spectrogram(
+ spectrogram,
+ length,
+ self.pad,
+ self.window,
+ self.n_fft,
+ self.hop_length,
+ self.win_length,
+ self.normalized,
+ self.center,
+ self.pad_mode,
+ self.onesided,
+ )
+
+
+class GriffinLim(torch.nn.Module):
+ r"""Compute waveform from a linear scale magnitude spectrogram using the Griffin-Lim transformation.
+
+ .. devices:: CPU CUDA
+
+ .. properties:: Autograd TorchScript
+
+ Implementation ported from
+ *librosa* :cite:`brian_mcfee-proc-scipy-2015`, *A fast Griffin-Lim algorithm* :cite:`6701851`
+ and *Signal estimation from modified short-time Fourier transform* :cite:`1172092`.
+
+ Args:
+ n_fft (int, optional): Size of FFT, creates ``n_fft // 2 + 1`` bins. (Default: ``400``)
+ n_iter (int, optional): Number of iteration for phase recovery process. (Default: ``32``)
+ win_length (int or None, optional): Window size. (Default: ``n_fft``)
+ hop_length (int or None, optional): Length of hop between STFT windows. (Default: ``win_length // 2``)
+ window_fn (Callable[..., Tensor], optional): A function to create a window tensor
+ that is applied/multiplied to each frame/window. (Default: ``torch.hann_window``)
+ power (float, optional): Exponent for the magnitude spectrogram,
+ (must be > 0) e.g., 1 for magnitude, 2 for power, etc. (Default: ``2``)
+ wkwargs (dict or None, optional): Arguments for window function. (Default: ``None``)
+ momentum (float, optional): The momentum parameter for fast Griffin-Lim.
+ Setting this to 0 recovers the original Griffin-Lim method.
+ Values near 1 can lead to faster convergence, but above 1 may not converge. (Default: ``0.99``)
+ length (int, optional): Array length of the expected output. (Default: ``None``)
+ rand_init (bool, optional): Initializes phase randomly if True and to zero otherwise. (Default: ``True``)
+
+ Example
+ >>> batch, freq, time = 2, 257, 100
+ >>> spectrogram = torch.randn(batch, freq, time)
+ >>> transform = transforms.GriffinLim(n_fft=512)
+ >>> waveform = transform(spectrogram)
+ """
+ __constants__ = ["n_fft", "n_iter", "win_length", "hop_length", "power", "length", "momentum", "rand_init"]
+
+ def __init__(
+ self,
+ n_fft: int = 400,
+ n_iter: int = 32,
+ win_length: Optional[int] = None,
+ hop_length: Optional[int] = None,
+ window_fn: Callable[..., Tensor] = torch.hann_window,
+ power: float = 2.0,
+ wkwargs: Optional[dict] = None,
+ momentum: float = 0.99,
+ length: Optional[int] = None,
+ rand_init: bool = True,
+ ) -> None:
+ super(GriffinLim, self).__init__()
+
+ if not (0 <= momentum < 1):
+ raise ValueError("momentum must be in the range [0, 1). Found: {}".format(momentum))
+
+ self.n_fft = n_fft
+ self.n_iter = n_iter
+ self.win_length = win_length if win_length is not None else n_fft
+ self.hop_length = hop_length if hop_length is not None else self.win_length // 2
+ window = window_fn(self.win_length) if wkwargs is None else window_fn(self.win_length, **wkwargs)
+ self.register_buffer("window", window)
+ self.length = length
+ self.power = power
+ self.momentum = momentum
+ self.rand_init = rand_init
+
+ def forward(self, specgram: Tensor) -> Tensor:
+ r"""
+ Args:
+ specgram (Tensor):
+ A magnitude-only STFT spectrogram of dimension (..., freq, frames)
+ where freq is ``n_fft // 2 + 1``.
+
+ Returns:
+ Tensor: waveform of (..., time), where time equals the ``length`` parameter if given.
+ """
+ return F.griffinlim(
+ specgram,
+ self.window,
+ self.n_fft,
+ self.hop_length,
+ self.win_length,
+ self.power,
+ self.n_iter,
+ self.momentum,
+ self.length,
+ self.rand_init,
+ )
+
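+ # Editor's sketch (assumption): Griffin-Lim recovers a waveform from a magnitude-only
+ # spectrogram; ``power`` should match the value used to compute the spectrogram.
+ #
+ #     >>> spec = Spectrogram(n_fft=512, power=2.0)(torch.randn(1, 16000))
+ #     >>> recon = GriffinLim(n_fft=512, power=2.0)(spec)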
+
+class AmplitudeToDB(torch.nn.Module):
+ r"""Turn a tensor from the power/amplitude scale to the decibel scale.
+
+ .. devices:: CPU CUDA
+
+ .. properties:: Autograd TorchScript
+
+ This output depends on the maximum value in the input tensor, and so
+ may return different values for an audio clip split into snippets vs.
+ a full clip.
+
+ Args:
+ stype (str, optional): scale of input tensor (``"power"`` or ``"magnitude"``). Power
+ is the elementwise square of the magnitude. (Default: ``"power"``)
+ top_db (float or None, optional): minimum negative cut-off in decibels. A reasonable
+ number is 80. (Default: ``None``)
+
+ Example
+ >>> waveform, sample_rate = torchaudio.load("test.wav", normalize=True)
+ >>> transform = transforms.AmplitudeToDB(stype="magnitude", top_db=80)
+ >>> waveform_db = transform(waveform)
+ """
+ __constants__ = ["multiplier", "amin", "ref_value", "db_multiplier"]
+
+ def __init__(self, stype: str = "power", top_db: Optional[float] = None) -> None:
+ super(AmplitudeToDB, self).__init__()
+ self.stype = stype
+ if top_db is not None and top_db < 0:
+ raise ValueError("top_db must be positive value")
+ self.top_db = top_db
+ self.multiplier = 10.0 if stype == "power" else 20.0
+ self.amin = 1e-10
+ self.ref_value = 1.0
+ self.db_multiplier = math.log10(max(self.amin, self.ref_value))
+
+ def forward(self, x: Tensor) -> Tensor:
+ r"""Numerically stable implementation from Librosa.
+
+ https://librosa.org/doc/latest/generated/librosa.amplitude_to_db.html
+
+ Args:
+ x (Tensor): Input tensor before being converted to decibel scale.
+
+ Returns:
+ Tensor: Output tensor in decibel scale.
+ """
+ return F.amplitude_to_DB(x, self.multiplier, self.amin, self.db_multiplier, self.top_db)
+
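+ # Editor's sketch (assumption): converting a power mel spectrogram to decibels with an
+ # 80 dB dynamic-range cutoff; ``MelSpectrogram`` is defined later in this module.
+ #
+ #     >>> mel = MelSpectrogram(sample_rate=16000)(torch.randn(1, 16000))  # power scale
+ #     >>> mel_db = AmplitudeToDB(stype="power", top_db=80)(mel)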
+
+class MelScale(torch.nn.Module):
+ r"""Turn a normal STFT into a mel frequency STFT with triangular filter banks.
+
+ .. devices:: CPU CUDA
+
+ .. properties:: Autograd TorchScript
+
+ Args:
+ n_mels (int, optional): Number of mel filterbanks. (Default: ``128``)
+ sample_rate (int, optional): Sample rate of audio signal. (Default: ``16000``)
+ f_min (float, optional): Minimum frequency. (Default: ``0.``)
+ f_max (float or None, optional): Maximum frequency. (Default: ``sample_rate // 2``)
+ n_stft (int, optional): Number of bins in STFT. See ``n_fft`` in :class:`Spectrogram`. (Default: ``201``)
+ norm (str or None, optional): If ``"slaney"``, divide the triangular mel weights by the width of the mel band
+ (area normalization). (Default: ``None``)
+ mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. (Default: ``htk``)
+
+ Example
+ >>> waveform, sample_rate = torchaudio.load("test.wav", normalize=True)
+ >>> spectrogram_transform = transforms.Spectrogram(n_fft=1024)
+ >>> spectrogram = spectrogram_transform(waveform)
+ >>> melscale_transform = transforms.MelScale(sample_rate=sample_rate, n_stft=1024 // 2 + 1)
+ >>> melscale_spectrogram = melscale_transform(spectrogram)
+
+ See also:
+ :py:func:`torchaudio.functional.melscale_fbanks` - The function used to
+ generate the filter banks.
+ """
+ __constants__ = ["n_mels", "sample_rate", "f_min", "f_max"]
+
+ def __init__(
+ self,
+ n_mels: int = 128,
+ sample_rate: int = 16000,
+ f_min: float = 0.0,
+ f_max: Optional[float] = None,
+ n_stft: int = 201,
+ norm: Optional[str] = None,
+ mel_scale: str = "htk",
+ ) -> None:
+ super(MelScale, self).__init__()
+ self.n_mels = n_mels
+ self.sample_rate = sample_rate
+ self.f_max = f_max if f_max is not None else float(sample_rate // 2)
+ self.f_min = f_min
+ self.norm = norm
+ self.mel_scale = mel_scale
+
+ if f_min > self.f_max:
+ raise ValueError("Require f_min: {} <= f_max: {}".format(f_min, self.f_max))
+
+ fb = F.melscale_fbanks(n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate, self.norm, self.mel_scale)
+ self.register_buffer("fb", fb)
+
+ def forward(self, specgram: Tensor) -> Tensor:
+ r"""
+ Args:
+ specgram (Tensor): A spectrogram STFT of dimension (..., freq, time).
+
+ Returns:
+ Tensor: Mel frequency spectrogram of size (..., ``n_mels``, time).
+ """
+
+ # (..., time, freq) dot (freq, n_mels) -> (..., n_mels, time)
+ mel_specgram = torch.matmul(specgram.transpose(-1, -2), self.fb).transpose(-1, -2)
+
+ return mel_specgram
+
+
+class InverseMelScale(torch.nn.Module):
+ r"""Estimate a STFT in normal frequency domain from mel frequency domain.
+
+ .. devices:: CPU CUDA
+
+ It minimizes the euclidean norm between the input mel-spectrogram and the product between
+ the estimated spectrogram and the filter banks using `torch.linalg.lstsq`.
+
+ Args:
+ n_stft (int): Number of bins in STFT. See ``n_fft`` in :class:`Spectrogram`.
+ n_mels (int, optional): Number of mel filterbanks. (Default: ``128``)
+ sample_rate (int, optional): Sample rate of audio signal. (Default: ``16000``)
+ f_min (float, optional): Minimum frequency. (Default: ``0.``)
+ f_max (float or None, optional): Maximum frequency. (Default: ``sample_rate // 2``)
+ norm (str or None, optional): If "slaney", divide the triangular mel weights by the width of the mel band
+ (area normalization). (Default: ``None``)
+ mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. (Default: ``htk``)
+ driver (str, optional): Name of the LAPACK/MAGMA method to be used for `torch.linalg.lstsq`.
+ For CPU inputs the valid values are ``"gels"``, ``"gelsy"``, ``"gelsd"``, ``"gelss"``.
+ For CUDA input, the only valid driver is ``"gels"``, which assumes that A is full-rank.
+ (Default: ``"gels"``)
+
+ Example
+ >>> waveform, sample_rate = torchaudio.load("test.wav", normalize=True)
+ >>> mel_spectrogram_transform = transforms.MelSpectrogram(sample_rate, n_fft=1024)
+ >>> mel_spectrogram = mel_spectrogram_transform(waveform)
+ >>> inverse_melscale_transform = transforms.InverseMelScale(n_stft=1024 // 2 + 1)
+ >>> spectrogram = inverse_melscale_transform(mel_spectrogram)
+ """
+ __constants__ = [
+ "n_stft",
+ "n_mels",
+ "sample_rate",
+ "f_min",
+ "f_max",
+ ]
+
+ def __init__(
+ self,
+ n_stft: int,
+ n_mels: int = 128,
+ sample_rate: int = 16000,
+ f_min: float = 0.0,
+ f_max: Optional[float] = None,
+ norm: Optional[str] = None,
+ mel_scale: str = "htk",
+ driver: str = "gels",
+ ) -> None:
+ super(InverseMelScale, self).__init__()
+ self.n_mels = n_mels
+ self.sample_rate = sample_rate
+ self.f_max = f_max if f_max is not None else float(sample_rate // 2)
+ self.f_min = f_min
+ self.driver = driver
+
+ if f_min > self.f_max:
+ raise ValueError("Require f_min: {} <= f_max: {}".format(f_min, self.f_max))
+
+ if driver not in ["gels", "gelsy", "gelsd", "gelss"]:
+ raise ValueError(f'driver must be one of ["gels", "gelsy", "gelsd", "gelss"]. Found {driver}.')
+
+ fb = F.melscale_fbanks(n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate, norm, mel_scale)
+ self.register_buffer("fb", fb)
+
+ def forward(self, melspec: Tensor) -> Tensor:
+ r"""
+ Args:
+ melspec (Tensor): A Mel frequency spectrogram of dimension (..., ``n_mels``, time)
+
+ Returns:
+ Tensor: Linear scale spectrogram of size (..., freq, time)
+ """
+ # pack batch
+ shape = melspec.size()
+ melspec = melspec.view(-1, shape[-2], shape[-1])
+
+ n_mels, time = shape[-2], shape[-1]
+ freq, _ = self.fb.size() # (freq, n_mels)
+ if self.n_mels != n_mels:
+ raise ValueError("Expected an input with {} mel bins. Found: {}".format(self.n_mels, n_mels))
+
+ specgram = torch.relu(torch.linalg.lstsq(self.fb.transpose(-1, -2)[None], melspec, driver=self.driver).solution)
+
+ # unpack batch
+ specgram = specgram.view(shape[:-2] + (freq, time))
+ return specgram
+
+
+class MelSpectrogram(torch.nn.Module):
+ r"""Create MelSpectrogram for a raw audio signal.
+
+ .. devices:: CPU CUDA
+
+ .. properties:: Autograd TorchScript
+
+ This is a composition of :py:func:`torchaudio.transforms.Spectrogram`
+ and :py:func:`torchaudio.transforms.MelScale`.
+
+ Sources
+ * https://gist.github.com/kastnerkyle/179d6e9a88202ab0a2fe
+ * https://timsainb.github.io/spectrograms-mfccs-and-inversion-in-python.html
+ * http://haythamfayek.com/2016/04/21/speech-processing-for-machine-learning.html
+
+ Args:
+ sample_rate (int, optional): Sample rate of audio signal. (Default: ``16000``)
+ n_fft (int, optional): Size of FFT, creates ``n_fft // 2 + 1`` bins. (Default: ``400``)
+ win_length (int or None, optional): Window size. (Default: ``n_fft``)
+ hop_length (int or None, optional): Length of hop between STFT windows. (Default: ``win_length // 2``)
+ f_min (float, optional): Minimum frequency. (Default: ``0.``)
+ f_max (float or None, optional): Maximum frequency. (Default: ``None``)
+ pad (int, optional): Two sided padding of signal. (Default: ``0``)
+ n_mels (int, optional): Number of mel filterbanks. (Default: ``128``)
+ window_fn (Callable[..., Tensor], optional): A function to create a window tensor
+ that is applied/multiplied to each frame/window. (Default: ``torch.hann_window``)
+ power (float, optional): Exponent for the magnitude spectrogram
+ (must be > 0), e.g., 1 for magnitude, 2 for power, etc. (Default: ``2``)
+ normalized (bool, optional): Whether to normalize by magnitude after stft. (Default: ``False``)
+ wkwargs (Dict[..., ...] or None, optional): Arguments for window function. (Default: ``None``)
+ center (bool, optional): whether to pad :attr:`waveform` on both sides so
+ that the :math:`t`-th frame is centered at time :math:`t \times \text{hop\_length}`.
+ (Default: ``True``)
+ pad_mode (string, optional): controls the padding method used when
+ :attr:`center` is ``True``. (Default: ``"reflect"``)
+ onesided: Deprecated and unused.
+ norm (str or None, optional): If "slaney", divide the triangular mel weights by the width of the mel band
+ (area normalization). (Default: ``None``)
+ mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. (Default: ``htk``)
+
+ Example
+ >>> waveform, sample_rate = torchaudio.load("test.wav", normalize=True)
+ >>> transform = transforms.MelSpectrogram(sample_rate)
+ >>> mel_specgram = transform(waveform) # (channel, n_mels, time)
+
+ See also:
+ :py:func:`torchaudio.functional.melscale_fbanks` - The function used to
+ generate the filter banks.
+ """
+ __constants__ = ["sample_rate", "n_fft", "win_length", "hop_length", "pad", "n_mels", "f_min"]
+
+ def __init__(
+ self,
+ sample_rate: int = 16000,
+ n_fft: int = 400,
+ win_length: Optional[int] = None,
+ hop_length: Optional[int] = None,
+ f_min: float = 0.0,
+ f_max: Optional[float] = None,
+ pad: int = 0,
+ n_mels: int = 128,
+ window_fn: Callable[..., Tensor] = torch.hann_window,
+ power: float = 2.0,
+ normalized: bool = False,
+ wkwargs: Optional[dict] = None,
+ center: bool = True,
+ pad_mode: str = "reflect",
+ onesided: Optional[bool] = None,
+ norm: Optional[str] = None,
+ mel_scale: str = "htk",
+ ) -> None:
+ super(MelSpectrogram, self).__init__()
+ torch._C._log_api_usage_once("torchaudio.transforms.MelSpectrogram")
+
+ if onesided is not None:
+ warnings.warn(
+ "Argument 'onesided' has been deprecated and has no influence on the behavior of this module."
+ )
+
+ self.sample_rate = sample_rate
+ self.n_fft = n_fft
+ self.win_length = win_length if win_length is not None else n_fft
+ self.hop_length = hop_length if hop_length is not None else self.win_length // 2
+ self.pad = pad
+ self.power = power
+ self.normalized = normalized
+ self.n_mels = n_mels # number of mel frequency bins
+ self.f_max = f_max
+ self.f_min = f_min
+ self.spectrogram = Spectrogram(
+ n_fft=self.n_fft,
+ win_length=self.win_length,
+ hop_length=self.hop_length,
+ pad=self.pad,
+ window_fn=window_fn,
+ power=self.power,
+ normalized=self.normalized,
+ wkwargs=wkwargs,
+ center=center,
+ pad_mode=pad_mode,
+ onesided=True,
+ )
+ self.mel_scale = MelScale(
+ self.n_mels, self.sample_rate, self.f_min, self.f_max, self.n_fft // 2 + 1, norm, mel_scale
+ )
+
+ def forward(self, waveform: Tensor) -> Tensor:
+ r"""
+ Args:
+ waveform (Tensor): Tensor of audio of dimension (..., time).
+
+ Returns:
+ Tensor: Mel frequency spectrogram of size (..., ``n_mels``, time).
+ """
+ specgram = self.spectrogram(waveform)
+ mel_specgram = self.mel_scale(specgram)
+ return mel_specgram
+
+
+class MFCC(torch.nn.Module):
+ r"""Create the Mel-frequency cepstrum coefficients from an audio signal.
+
+ .. devices:: CPU CUDA
+
+ .. properties:: Autograd TorchScript
+
+ By default, this calculates the MFCC on the DB-scaled Mel spectrogram.
+ This is not the textbook implementation, but is implemented here to
+ give consistency with librosa.
+
+ This output depends on the maximum value in the input spectrogram, and so
+ may return different values for an audio clip split into snippets vs. a
+ full clip.
+
+ Args:
+ sample_rate (int, optional): Sample rate of audio signal. (Default: ``16000``)
+ n_mfcc (int, optional): Number of mfc coefficients to retain. (Default: ``40``)
+ dct_type (int, optional): type of DCT (discrete cosine transform) to use. (Default: ``2``)
+ norm (str, optional): norm to use. (Default: ``"ortho"``)
+ log_mels (bool, optional): whether to use log-mel spectrograms instead of db-scaled. (Default: ``False``)
+ melkwargs (dict or None, optional): arguments for MelSpectrogram. (Default: ``None``)
+
+ Example
+ >>> waveform, sample_rate = torchaudio.load("test.wav", normalize=True)
+ >>> transform = transforms.MFCC(
+ >>> sample_rate=sample_rate,
+ >>> n_mfcc=13,
+ >>> melkwargs={"n_fft": 400, "hop_length": 160, "n_mels": 23, "center": False},
+ >>> )
+ >>> mfcc = transform(waveform)
+
+ See also:
+ :py:func:`torchaudio.functional.melscale_fbanks` - The function used to
+ generate the filter banks.
+ """
+ __constants__ = ["sample_rate", "n_mfcc", "dct_type", "top_db", "log_mels"]
+
+ def __init__(
+ self,
+ sample_rate: int = 16000,
+ n_mfcc: int = 40,
+ dct_type: int = 2,
+ norm: str = "ortho",
+ log_mels: bool = False,
+ melkwargs: Optional[dict] = None,
+ ) -> None:
+ super(MFCC, self).__init__()
+ supported_dct_types = [2]
+ if dct_type not in supported_dct_types:
+ raise ValueError("DCT type not supported: {}".format(dct_type))
+ self.sample_rate = sample_rate
+ self.n_mfcc = n_mfcc
+ self.dct_type = dct_type
+ self.norm = norm
+ self.top_db = 80.0
+ self.amplitude_to_DB = AmplitudeToDB("power", self.top_db)
+
+ melkwargs = melkwargs or {}
+ self.MelSpectrogram = MelSpectrogram(sample_rate=self.sample_rate, **melkwargs)
+
+ if self.n_mfcc > self.MelSpectrogram.n_mels:
+ raise ValueError("Cannot select more MFCC coefficients than # mel bins")
+ dct_mat = F.create_dct(self.n_mfcc, self.MelSpectrogram.n_mels, self.norm)
+ self.register_buffer("dct_mat", dct_mat)
+ self.log_mels = log_mels
+
+ def forward(self, waveform: Tensor) -> Tensor:
+ r"""
+ Args:
+ waveform (Tensor): Tensor of audio of dimension (..., time).
+
+ Returns:
+ Tensor: specgram_mel_db of size (..., ``n_mfcc``, time).
+ """
+ mel_specgram = self.MelSpectrogram(waveform)
+ if self.log_mels:
+ log_offset = 1e-6
+ mel_specgram = torch.log(mel_specgram + log_offset)
+ else:
+ mel_specgram = self.amplitude_to_DB(mel_specgram)
+
+ # (..., time, n_mels) dot (n_mels, n_mfcc) -> (..., n_mfcc, time)
+ mfcc = torch.matmul(mel_specgram.transpose(-1, -2), self.dct_mat).transpose(-1, -2)
+ return mfcc
+
+
+class LFCC(torch.nn.Module):
+ r"""Create the linear-frequency cepstrum coefficients from an audio signal.
+
+ .. devices:: CPU CUDA
+
+ .. properties:: Autograd TorchScript
+
+ By default, this calculates the LFCC on the DB-scaled linear filtered spectrogram.
+ This is not the textbook implementation, but is implemented here to
+ give consistency with librosa.
+
+ This output depends on the maximum value in the input spectrogram, and so
+ may return different values for an audio clip split into snippets vs. a
+ full clip.
+
+ Args:
+ sample_rate (int, optional): Sample rate of audio signal. (Default: ``16000``)
+ n_filter (int, optional): Number of linear filters to apply. (Default: ``128``)
+ n_lfcc (int, optional): Number of lfc coefficients to retain. (Default: ``40``)
+ f_min (float, optional): Minimum frequency. (Default: ``0.``)
+ f_max (float or None, optional): Maximum frequency. (Default: ``None``)
+ dct_type (int, optional): type of DCT (discrete cosine transform) to use. (Default: ``2``)
+ norm (str, optional): norm to use. (Default: ``"ortho"``)
+ log_lf (bool, optional): whether to use log-lf spectrograms instead of db-scaled. (Default: ``False``)
+ speckwargs (dict or None, optional): arguments for Spectrogram. (Default: ``None``)
+
+ Example
+ >>> waveform, sample_rate = torchaudio.load("test.wav", normalize=True)
+ >>> transform = transforms.LFCC(
+ >>> sample_rate=sample_rate,
+ >>> n_lfcc=13,
+ >>> speckwargs={"n_fft": 400, "hop_length": 160, "center": False},
+ >>> )
+ >>> lfcc = transform(waveform)
+
+ See also:
+ :py:func:`torchaudio.functional.linear_fbanks` - The function used to
+ generate the filter banks.
+ """
+ __constants__ = ["sample_rate", "n_filter", "n_lfcc", "dct_type", "top_db", "log_lf"]
+
+ def __init__(
+ self,
+ sample_rate: int = 16000,
+ n_filter: int = 128,
+ f_min: float = 0.0,
+ f_max: Optional[float] = None,
+ n_lfcc: int = 40,
+ dct_type: int = 2,
+ norm: str = "ortho",
+ log_lf: bool = False,
+ speckwargs: Optional[dict] = None,
+ ) -> None:
+ super(LFCC, self).__init__()
+ supported_dct_types = [2]
+ if dct_type not in supported_dct_types:
+ raise ValueError("DCT type not supported: {}".format(dct_type))
+ self.sample_rate = sample_rate
+ self.f_min = f_min
+ self.f_max = f_max if f_max is not None else float(sample_rate // 2)
+ self.n_filter = n_filter
+ self.n_lfcc = n_lfcc
+ self.dct_type = dct_type
+ self.norm = norm
+ self.top_db = 80.0
+ self.amplitude_to_DB = AmplitudeToDB("power", self.top_db)
+
+ speckwargs = speckwargs or {}
+ self.Spectrogram = Spectrogram(**speckwargs)
+
+ if self.n_lfcc > self.Spectrogram.n_fft:
+ raise ValueError("Cannot select more LFCC coefficients than # fft bins")
+
+ filter_mat = F.linear_fbanks(
+ n_freqs=self.Spectrogram.n_fft // 2 + 1,
+ f_min=self.f_min,
+ f_max=self.f_max,
+ n_filter=self.n_filter,
+ sample_rate=self.sample_rate,
+ )
+ self.register_buffer("filter_mat", filter_mat)
+
+ dct_mat = F.create_dct(self.n_lfcc, self.n_filter, self.norm)
+ self.register_buffer("dct_mat", dct_mat)
+ self.log_lf = log_lf
+
+ def forward(self, waveform: Tensor) -> Tensor:
+ r"""
+ Args:
+ waveform (Tensor): Tensor of audio of dimension (..., time).
+
+ Returns:
+ Tensor: Linear Frequency Cepstral Coefficients of size (..., ``n_lfcc``, time).
+ """
+ specgram = self.Spectrogram(waveform)
+
+ # (..., time, freq) dot (freq, n_filter) -> (..., n_filter, time)
+ specgram = torch.matmul(specgram.transpose(-1, -2), self.filter_mat).transpose(-1, -2)
+
+ if self.log_lf:
+ log_offset = 1e-6
+ specgram = torch.log(specgram + log_offset)
+ else:
+ specgram = self.amplitude_to_DB(specgram)
+
+ # (..., time, n_filter) dot (n_filter, n_lfcc) -> (..., n_lfcc, time)
+ lfcc = torch.matmul(specgram.transpose(-1, -2), self.dct_mat).transpose(-1, -2)
+ return lfcc
+
+
+class MuLawEncoding(torch.nn.Module):
+ r"""Encode signal based on mu-law companding.
+
+ .. devices:: CPU CUDA
+
+ .. properties:: TorchScript
+
+ For more info see the
+ `Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_
+
+ This algorithm assumes the signal has been scaled to between -1 and 1 and
+ returns a signal encoded with values from 0 to ``quantization_channels - 1``.
+
+ Args:
+ quantization_channels (int, optional): Number of channels. (Default: ``256``)
+
+ Example
+ >>> waveform, sample_rate = torchaudio.load("test.wav", normalize=True)
+ >>> transform = torchaudio.transforms.MuLawEncoding(quantization_channels=512)
+ >>> mulawtrans = transform(waveform)
+
+ """
+ __constants__ = ["quantization_channels"]
+
+ def __init__(self, quantization_channels: int = 256) -> None:
+ super(MuLawEncoding, self).__init__()
+ self.quantization_channels = quantization_channels
+
+ def forward(self, x: Tensor) -> Tensor:
+ r"""
+ Args:
+ x (Tensor): A signal to be encoded.
+
+ Returns:
+ Tensor: An encoded signal.
+ """
+ return F.mu_law_encoding(x, self.quantization_channels)
+
+
+class MuLawDecoding(torch.nn.Module):
+ r"""Decode mu-law encoded signal.
+
+ .. devices:: CPU CUDA
+
+ .. properties:: TorchScript
+
+ For more info see the
+ `Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_
+
+ This expects an input with values between 0 and ``quantization_channels - 1``
+ and returns a signal scaled between -1 and 1.
+
+ Args:
+ quantization_channels (int, optional): Number of channels. (Default: ``256``)
+
+ Example
+ >>> waveform, sample_rate = torchaudio.load("test.wav", normalize=True)
+ >>> transform = torchaudio.transforms.MuLawDecoding(quantization_channels=512)
+ >>> mulawtrans = transform(waveform)
+ """
+ __constants__ = ["quantization_channels"]
+
+ def __init__(self, quantization_channels: int = 256) -> None:
+ super(MuLawDecoding, self).__init__()
+ self.quantization_channels = quantization_channels
+
+ def forward(self, x_mu: Tensor) -> Tensor:
+ r"""
+ Args:
+ x_mu (Tensor): A mu-law encoded signal which needs to be decoded.
+
+ Returns:
+ Tensor: The signal decoded.
+ """
+ return F.mu_law_decoding(x_mu, self.quantization_channels)
+
+
+class Resample(torch.nn.Module):
+ r"""Resample a signal from one frequency to another. A resampling method can be given.
+
+ .. devices:: CPU CUDA
+
+ .. properties:: Autograd TorchScript
+
+ Note:
+ If resampling on waveforms of higher precision than float32, there may be a small loss of precision
+ because the kernel is cached once as float32. If high precision resampling is important for your application,
+ the functional form will retain higher precision, but run slower because it does not cache the kernel.
+ Alternatively, you could rewrite a transform that caches a higher precision kernel.
+
+ Args:
+ orig_freq (int, optional): The original frequency of the signal. (Default: ``16000``)
+ new_freq (int, optional): The desired frequency. (Default: ``16000``)
+ resampling_method (str, optional): The resampling method to use.
+ Options: [``sinc_interp_hann``, ``sinc_interp_kaiser``] (Default: ``"sinc_interp_hann"``)
+ lowpass_filter_width (int, optional): Controls the sharpness of the filter, more == sharper
+ but less efficient. (Default: ``6``)
+ rolloff (float, optional): The roll-off frequency of the filter, as a fraction of the Nyquist.
+ Lower values reduce aliasing, but also attenuate some of the highest frequencies. (Default: ``0.99``)
+ beta (float or None, optional): The shape parameter used for the Kaiser window.
+ dtype (torch.dtype, optional):
+ Determines the precision with which the resampling kernel is pre-computed and cached. If not provided,
+ the kernel is computed with ``torch.float64`` then cached as ``torch.float32``.
+ If you need higher precision, provide ``torch.float64``, and the pre-computed kernel is computed and
+ cached as ``torch.float64``. If you use resample with lower precision, then instead of providing this
+ argument, please use ``Resample.to(dtype)``, so that the kernel generation is still
+ carried out on ``torch.float64``.
+
+ Example
+ >>> waveform, sample_rate = torchaudio.load("test.wav", normalize=True)
+ >>> transform = transforms.Resample(sample_rate, sample_rate/10)
+ >>> waveform = transform(waveform)
+ """
+
+ def __init__(
+ self,
+ orig_freq: int = 16000,
+ new_freq: int = 16000,
+ resampling_method: str = "sinc_interp_hann",
+ lowpass_filter_width: int = 6,
+ rolloff: float = 0.99,
+ beta: Optional[float] = None,
+ *,
+ dtype: Optional[torch.dtype] = None,
+ ) -> None:
+ super().__init__()
+
+ self.orig_freq = orig_freq
+ self.new_freq = new_freq
+ self.gcd = math.gcd(int(self.orig_freq), int(self.new_freq))
+ self.resampling_method = resampling_method
+ self.lowpass_filter_width = lowpass_filter_width
+ self.rolloff = rolloff
+ self.beta = beta
+
+ if self.orig_freq != self.new_freq:
+ kernel, self.width = _get_sinc_resample_kernel(
+ self.orig_freq,
+ self.new_freq,
+ self.gcd,
+ self.lowpass_filter_width,
+ self.rolloff,
+ self.resampling_method,
+ beta,
+ dtype=dtype,
+ )
+ self.register_buffer("kernel", kernel)
+
+ def forward(self, waveform: Tensor) -> Tensor:
+ r"""
+ Args:
+ waveform (Tensor): Tensor of audio of dimension (..., time).
+
+ Returns:
+ Tensor: Output signal of dimension (..., time).
+ """
+ if self.orig_freq == self.new_freq:
+ return waveform
+ return _apply_sinc_resample_kernel(waveform, self.orig_freq, self.new_freq, self.gcd, self.kernel, self.width)
+
+
+class ComputeDeltas(torch.nn.Module):
+ r"""Compute delta coefficients of a tensor, usually a spectrogram.
+
+ .. devices:: CPU CUDA
+
+ .. properties:: Autograd TorchScript
+
+ See `torchaudio.functional.compute_deltas` for more details.
+
+ Args:
+ win_length (int, optional): The window length used for computing delta. (Default: ``5``)
+ mode (str, optional): Mode parameter passed to padding. (Default: ``"replicate"``)
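+
+ Example
+ >>> # Illustrative usage; the deltas are computed over a spectrogram as in the other examples.
+ >>> waveform, sample_rate = torchaudio.load("test.wav", normalize=True)
+ >>> specgram = torchaudio.transforms.Spectrogram()(waveform) # (channel, freq, time)
+ >>> transform = torchaudio.transforms.ComputeDeltas(win_length=5)
+ >>> deltas = transform(specgram) # (channel, freq, time)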
+ """
+ __constants__ = ["win_length"]
+
+ def __init__(self, win_length: int = 5, mode: str = "replicate") -> None:
+ super(ComputeDeltas, self).__init__()
+ self.win_length = win_length
+ self.mode = mode
+
+ def forward(self, specgram: Tensor) -> Tensor:
+ r"""
+ Args:
+ specgram (Tensor): Tensor of audio of dimension (..., freq, time).
+
+ Returns:
+ Tensor: Tensor of deltas of dimension (..., freq, time).
+ """
+ return F.compute_deltas(specgram, win_length=self.win_length, mode=self.mode)
+
+
+class TimeStretch(torch.nn.Module):
+ r"""Stretch stft in time without modifying pitch for a given rate.
+
+ .. devices:: CPU CUDA
+
+ .. properties:: Autograd TorchScript
+
+ Proposed in *SpecAugment* :cite:`specaugment`.
+
+ Args:
+ hop_length (int or None, optional): Length of hop between STFT windows.
+ (Default: ``n_fft // 2``, where ``n_fft == (n_freq - 1) * 2``)
+ n_freq (int, optional): number of filter banks from stft. (Default: ``201``)
+ fixed_rate (float or None, optional): rate to speed up or slow down by.
+ If None is provided, rate must be passed to the forward method. (Default: ``None``)
+
+ .. note::
+
+ The expected input is a raw, complex-valued spectrogram.
+
+ Example
+ >>> spectrogram = torchaudio.transforms.Spectrogram(power=None)
+ >>> stretch = torchaudio.transforms.TimeStretch()
+ >>>
+ >>> original = spectrogram(waveform)
+ >>> stretched_1_2 = stretch(original, 1.2)
+ >>> stretched_0_9 = stretch(original, 0.9)
+
+ .. image:: https://download.pytorch.org/torchaudio/doc-assets/specaugment_time_stretch.png
+ :width: 600
+ :alt: The visualization of stretched spectrograms.
+ """
+ __constants__ = ["fixed_rate"]
+
+ def __init__(self, hop_length: Optional[int] = None, n_freq: int = 201, fixed_rate: Optional[float] = None) -> None:
+ super(TimeStretch, self).__init__()
+
+ self.fixed_rate = fixed_rate
+
+ n_fft = (n_freq - 1) * 2
+ hop_length = hop_length if hop_length is not None else n_fft // 2
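+ # Expected phase advance (in radians) from one frame to the next for each frequency bin; shape (n_freq, 1).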
+ self.register_buffer("phase_advance", torch.linspace(0, math.pi * hop_length, n_freq)[..., None])
+
+ def forward(self, complex_specgrams: Tensor, overriding_rate: Optional[float] = None) -> Tensor:
+ r"""
+ Args:
+ complex_specgrams (Tensor):
+ A tensor of dimension `(..., freq, num_frame)` with complex dtype.
+ overriding_rate (float or None, optional): speed up to apply to this batch.
+ If no rate is passed, use ``self.fixed_rate``. (Default: ``None``)
+
+ Returns:
+ Tensor:
+ Stretched spectrogram. The resulting tensor is of the corresponding complex dtype
+ as the input spectrogram, and the number of frames is changed to ``ceil(num_frame / rate)``.
+ """
+ if not torch.is_complex(complex_specgrams):
+ warnings.warn(
+ "The input to TimeStretch must be complex type. "
+ "Providing non-complex tensor produces invalid results.",
+ stacklevel=4,
+ )
+
+ if overriding_rate is None:
+ if self.fixed_rate is None:
+ raise ValueError("If no fixed_rate is specified, must pass a valid rate to the forward method.")
+ rate = self.fixed_rate
+ else:
+ rate = overriding_rate
+ return F.phase_vocoder(complex_specgrams, rate, self.phase_advance)
+
+
+class Fade(torch.nn.Module):
+ r"""Add a fade in and/or fade out to an waveform.
+
+ .. devices:: CPU CUDA
+
+ .. properties:: Autograd TorchScript
+
+ Args:
+ fade_in_len (int, optional): Length of fade-in (time frames). (Default: ``0``)
+ fade_out_len (int, optional): Length of fade-out (time frames). (Default: ``0``)
+ fade_shape (str, optional): Shape of fade. Must be one of: ``"quarter_sine"``,
+ ``"half_sine"``, ``"linear"``, ``"logarithmic"``, ``"exponential"``.
+ (Default: ``"linear"``)
+
+ Example
+ >>> waveform, sample_rate = torchaudio.load("test.wav", normalize=True)
+ >>> transform = transforms.Fade(fade_in_len=sample_rate, fade_out_len=2 * sample_rate, fade_shape="linear")
+ >>> faded_waveform = transform(waveform)
+ """
+
+ def __init__(self, fade_in_len: int = 0, fade_out_len: int = 0, fade_shape: str = "linear") -> None:
+ super(Fade, self).__init__()
+ self.fade_in_len = fade_in_len
+ self.fade_out_len = fade_out_len
+ self.fade_shape = fade_shape
+
+ def forward(self, waveform: Tensor) -> Tensor:
+ r"""
+ Args:
+ waveform (Tensor): Tensor of audio of dimension `(..., time)`.
+
+ Returns:
+ Tensor: Tensor of audio of dimension `(..., time)`.
+ """
+ waveform_length = waveform.size()[-1]
+ device = waveform.device
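+ # The fade-in and fade-out gain curves are multiplied element-wise with the waveform; each is 1 outside its fade region.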
+ return self._fade_in(waveform_length, device) * self._fade_out(waveform_length, device) * waveform
+
+ def _fade_in(self, waveform_length: int, device: torch.device) -> Tensor:
+ fade = torch.linspace(0, 1, self.fade_in_len, device=device)
+ ones = torch.ones(waveform_length - self.fade_in_len, device=device)
+
+ if self.fade_shape == "linear":
+ fade = fade
+
+ if self.fade_shape == "exponential":
+ fade = torch.pow(2, (fade - 1)) * fade
+
+ if self.fade_shape == "logarithmic":
+ fade = torch.log10(0.1 + fade) + 1
+
+ if self.fade_shape == "quarter_sine":
+ fade = torch.sin(fade * math.pi / 2)
+
+ if self.fade_shape == "half_sine":
+ fade = torch.sin(fade * math.pi - math.pi / 2) / 2 + 0.5
+
+ return torch.cat((fade, ones)).clamp_(0, 1)
+
+ def _fade_out(self, waveform_length: int, device: torch.device) -> Tensor:
+ fade = torch.linspace(0, 1, self.fade_out_len, device=device)
+ ones = torch.ones(waveform_length - self.fade_out_len, device=device)
+
+ if self.fade_shape == "linear":
+ fade = -fade + 1
+
+ if self.fade_shape == "exponential":
+ fade = torch.pow(2, -fade) * (1 - fade)
+
+ if self.fade_shape == "logarithmic":
+ fade = torch.log10(1.1 - fade) + 1
+
+ if self.fade_shape == "quarter_sine":
+ fade = torch.sin(fade * math.pi / 2 + math.pi / 2)
+
+ if self.fade_shape == "half_sine":
+ fade = torch.sin(fade * math.pi + math.pi / 2) / 2 + 0.5
+
+ return torch.cat((ones, fade)).clamp_(0, 1)
+
+
+class _AxisMasking(torch.nn.Module):
+ r"""Apply masking to a spectrogram.
+
+ Args:
+ mask_param (int): Maximum possible length of the mask.
+ axis (int): What dimension the mask is applied on (assuming the tensor is 3D).
+ For frequency masking, axis = 1.
+ For time masking, axis = 2.
+ iid_masks (bool): Applies iid masks to each of the examples in the batch dimension.
+ This option is applicable only when the dimension of the input tensor is >= 3.
+ p (float, optional): maximum proportion of columns that can be masked. (Default: 1.0)
+ """
+ __constants__ = ["mask_param", "axis", "iid_masks", "p"]
+
+ def __init__(self, mask_param: int, axis: int, iid_masks: bool, p: float = 1.0) -> None:
+ super(_AxisMasking, self).__init__()
+ self.mask_param = mask_param
+ self.axis = axis
+ self.iid_masks = iid_masks
+ self.p = p
+
+ def forward(self, specgram: Tensor, mask_value: float = 0.0) -> Tensor:
+ r"""
+ Args:
+ specgram (Tensor): Tensor of dimension `(..., freq, time)`.
+ mask_value (float): Value to assign to the masked columns.
+
+ Returns:
+ Tensor: Masked spectrogram of dimensions `(..., freq, time)`.
+ """
+ # If the iid_masks flag is set and specgram has a batch dimension,
+ # self.axis + specgram.dim() - 3 gives the time/frequency dimension (one of the last two dimensions)
+ # for input tensors whose dimension is not 3.
+ if self.iid_masks:
+ return F.mask_along_axis_iid(
+ specgram, self.mask_param, mask_value, self.axis + specgram.dim() - 3, p=self.p
+ )
+ else:
+ return F.mask_along_axis(specgram, self.mask_param, mask_value, self.axis + specgram.dim() - 3, p=self.p)
+
+
+class FrequencyMasking(_AxisMasking):
+ r"""Apply masking to a spectrogram in the frequency domain.
+
+ .. devices:: CPU CUDA
+
+ .. properties:: Autograd TorchScript
+
+ Proposed in *SpecAugment* :cite:`specaugment`.
+
+ Args:
+ freq_mask_param (int): maximum possible length of the mask.
+ Indices uniformly sampled from [0, freq_mask_param).
+ iid_masks (bool, optional): whether to apply different masks to each
+ example/channel in the batch. (Default: ``False``)
+ This option is applicable only when the input tensor is >= 3D.
+
+ Example
+ >>> spectrogram = torchaudio.transforms.Spectrogram()
+ >>> masking = torchaudio.transforms.FrequencyMasking(freq_mask_param=80)
+ >>>
+ >>> original = spectrogram(waveform)
+ >>> masked = masking(original)
+
+ .. image:: https://download.pytorch.org/torchaudio/doc-assets/specaugment_freq_masking1.png
+ :alt: The original spectrogram
+
+ .. image:: https://download.pytorch.org/torchaudio/doc-assets/specaugment_freq_masking2.png
+ :alt: The spectrogram masked along frequency axis
+ """
+
+ def __init__(self, freq_mask_param: int, iid_masks: bool = False) -> None:
+ super(FrequencyMasking, self).__init__(freq_mask_param, 1, iid_masks)
+
+
+class TimeMasking(_AxisMasking):
+ r"""Apply masking to a spectrogram in the time domain.
+
+ .. devices:: CPU CUDA
+
+ .. properties:: Autograd TorchScript
+
+ Proposed in *SpecAugment* :cite:`specaugment`.
+
+ Args:
+ time_mask_param (int): maximum possible length of the mask.
+ Indices uniformly sampled from [0, time_mask_param).
+ iid_masks (bool, optional): whether to apply different masks to each
+ example/channel in the batch. (Default: ``False``)
+ This option is applicable only when the input tensor is >= 3D.
+ p (float, optional): maximum proportion of time steps that can be masked.
+ Must be within range [0.0, 1.0]. (Default: 1.0)
+
+ Example
+ >>> spectrogram = torchaudio.transforms.Spectrogram()
+ >>> masking = torchaudio.transforms.TimeMasking(time_mask_param=80)
+ >>>
+ >>> original = spectrogram(waveform)
+ >>> masked = masking(original)
+
+ .. image:: https://download.pytorch.org/torchaudio/doc-assets/specaugment_time_masking1.png
+ :alt: The original spectrogram
+
+ .. image:: https://download.pytorch.org/torchaudio/doc-assets/specaugment_time_masking2.png
+ :alt: The spectrogram masked along time axis
+ """
+
+ def __init__(self, time_mask_param: int, iid_masks: bool = False, p: float = 1.0) -> None:
+ if not 0.0 <= p <= 1.0:
+ raise ValueError(f"The value of p must be between 0.0 and 1.0 ({p} given).")
+ super(TimeMasking, self).__init__(time_mask_param, 2, iid_masks, p=p)
+
+
+class SpecAugment(torch.nn.Module):
+ r"""Apply time and frequency masking to a spectrogram.
+ Args:
+ n_time_masks (int): Number of time masks. If its value is zero, no time masking will be applied.
+ time_mask_param (int): Maximum possible length of the time mask.
+ n_freq_masks (int): Number of frequency masks. If its value is zero, no frequency masking will be applied.
+ freq_mask_param (int): Maximum possible length of the frequency mask.
+ iid_masks (bool, optional): Applies iid masks to each of the examples in the batch dimension.
+ This option is applicable only when the input tensor is 4D. (Default: ``True``)
+ p (float, optional): maximum proportion of time steps that can be masked.
+ Must be within range [0.0, 1.0]. (Default: 1.0)
+ zero_masking (bool, optional): If ``True``, use 0 as the mask value,
+ else use mean of the input tensor. (Default: ``False``)
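+
+ Example
+ >>> # Illustrative parameter values; the spectrogram is computed as in the other examples.
+ >>> waveform, sample_rate = torchaudio.load("test.wav", normalize=True)
+ >>> specgram = torchaudio.transforms.Spectrogram()(waveform)
+ >>> specaug = torchaudio.transforms.SpecAugment(
+ >>> n_time_masks=2, time_mask_param=40, n_freq_masks=2, freq_mask_param=30
+ >>> )
+ >>> masked = specaug(specgram)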
+ """
+ __constants__ = [
+ "n_time_masks",
+ "time_mask_param",
+ "n_freq_masks",
+ "freq_mask_param",
+ "iid_masks",
+ "p",
+ "zero_masking",
+ ]
+
+ def __init__(
+ self,
+ n_time_masks: int,
+ time_mask_param: int,
+ n_freq_masks: int,
+ freq_mask_param: int,
+ iid_masks: bool = True,
+ p: float = 1.0,
+ zero_masking: bool = False,
+ ) -> None:
+ super(SpecAugment, self).__init__()
+ self.n_time_masks = n_time_masks
+ self.time_mask_param = time_mask_param
+ self.n_freq_masks = n_freq_masks
+ self.freq_mask_param = freq_mask_param
+ self.iid_masks = iid_masks
+ self.p = p
+ self.zero_masking = zero_masking
+
+ def forward(self, specgram: Tensor) -> Tensor:
+ r"""
+ Args:
+ specgram (Tensor): Tensor of shape `(..., freq, time)`.
+
+ Returns:
+ Tensor: Masked spectrogram of shape `(..., freq, time)`.
+ """
+ if self.zero_masking:
+ mask_value = 0.0
+ else:
+ mask_value = specgram.mean()
+ time_dim = specgram.dim() - 1
+ freq_dim = time_dim - 1
+
+ if specgram.dim() > 2 and self.iid_masks is True:
+ for _ in range(self.n_time_masks):
+ specgram = F.mask_along_axis_iid(specgram, self.time_mask_param, mask_value, time_dim, p=self.p)
+ for _ in range(self.n_freq_masks):
+ specgram = F.mask_along_axis_iid(specgram, self.freq_mask_param, mask_value, freq_dim, p=self.p)
+ else:
+ for _ in range(self.n_time_masks):
+ specgram = F.mask_along_axis(specgram, self.time_mask_param, mask_value, time_dim, p=self.p)
+ for _ in range(self.n_freq_masks):
+ specgram = F.mask_along_axis(specgram, self.freq_mask_param, mask_value, freq_dim, p=self.p)
+
+ return specgram
+
+
+class Loudness(torch.nn.Module):
+ r"""Measure audio loudness according to the ITU-R BS.1770-4 recommendation.
+
+ .. devices:: CPU CUDA
+
+ .. properties:: TorchScript
+
+ Args:
+ sample_rate (int): Sample rate of audio signal.
+
+ Example
+ >>> waveform, sample_rate = torchaudio.load("test.wav", normalize=True)
+ >>> transform = transforms.Loudness(sample_rate)
+ >>> loudness = transform(waveform)
+
+ Reference:
+ - https://www.itu.int/rec/R-REC-BS.1770-4-201510-I/en
+ """
+ __constants__ = ["sample_rate"]
+
+ def __init__(self, sample_rate: int):
+ super(Loudness, self).__init__()
+ self.sample_rate = sample_rate
+
+ def forward(self, waveform: Tensor):
+ r"""
+ Args:
+ waveform (torch.Tensor): audio waveform of dimension `(..., channels, time)`
+
+ Returns:
+ Tensor: loudness estimates (LKFS)
+ """
+ return F.loudness(waveform, self.sample_rate)
+
+
+class Vol(torch.nn.Module):
+ r"""Adjust volume of waveform.
+
+ .. devices:: CPU CUDA
+
+ .. properties:: Autograd TorchScript
+
+ Args:
+ gain (float): Interpreted according to the given gain_type:
+ If ``gain_type`` = ``amplitude``, ``gain`` is a positive amplitude ratio.
+ If ``gain_type`` = ``power``, ``gain`` is a power (voltage squared).
+ If ``gain_type`` = ``db``, ``gain`` is in decibels.
+ gain_type (str, optional): Type of gain. One of: ``amplitude``, ``power``, ``db`` (Default: ``amplitude``)
+
+ Example
+ >>> waveform, sample_rate = torchaudio.load("test.wav", normalize=True)
+ >>> transform = transforms.Vol(gain=0.5, gain_type="amplitude")
+ >>> quieter_waveform = transform(waveform)
+ """
+
+ def __init__(self, gain: float, gain_type: str = "amplitude"):
+ super(Vol, self).__init__()
+ self.gain = gain
+ self.gain_type = gain_type
+
+ if gain_type in ["amplitude", "power"] and gain < 0:
+ raise ValueError("If gain_type = amplitude or power, gain must be positive.")
+
+ def forward(self, waveform: Tensor) -> Tensor:
+ r"""
+ Args:
+ waveform (Tensor): Tensor of audio of dimension `(..., time)`.
+
+ Returns:
+ Tensor: Tensor of audio of dimension `(..., time)`.
+ """
+ if self.gain_type == "amplitude":
+ waveform = waveform * self.gain
+
+ if self.gain_type == "db":
+ waveform = F.gain(waveform, self.gain)
+
+ if self.gain_type == "power":
+ waveform = F.gain(waveform, 10 * math.log10(self.gain))
+
+ return torch.clamp(waveform, -1, 1)
+
+
+class SlidingWindowCmn(torch.nn.Module):
+ r"""
+ Apply sliding-window cepstral mean (and optionally variance) normalization per utterance.
+
+ .. devices:: CPU CUDA
+
+ .. properties:: Autograd TorchScript
+
+ Args:
+ cmn_window (int, optional): Window in frames for running average CMN computation (int, default = 600)
+ min_cmn_window (int, optional): Minimum CMN window used at start of decoding (adds latency only at start).
+ Only applicable if center == false, ignored if center == true (int, default = 100)
+ center (bool, optional): If true, use a window centered on the current frame
+ (to the extent possible, modulo end effects). If false, window is to the left. (bool, default = false)
+ norm_vars (bool, optional): If true, normalize variance to one. (bool, default = false)
+
+ Example
+ >>> waveform, sample_rate = torchaudio.load("test.wav", normalize=True)
+ >>> transform = transforms.SlidingWindowCmn(cmn_window=1000)
+ >>> cmn_waveform = transform(waveform)
+ """
+
+ def __init__(
+ self, cmn_window: int = 600, min_cmn_window: int = 100, center: bool = False, norm_vars: bool = False
+ ) -> None:
+ super().__init__()
+ self.cmn_window = cmn_window
+ self.min_cmn_window = min_cmn_window
+ self.center = center
+ self.norm_vars = norm_vars
+
+ def forward(self, specgram: Tensor) -> Tensor:
+ r"""
+ Args:
+ specgram (Tensor): Tensor of spectrogram of dimension `(..., time, freq)`.
+
+ Returns:
+ Tensor: Tensor of spectrogram of dimension `(..., time, freq)`.
+ """
+ cmn_specgram = F.sliding_window_cmn(specgram, self.cmn_window, self.min_cmn_window, self.center, self.norm_vars)
+ return cmn_specgram
+
+
+class Vad(torch.nn.Module):
+ r"""Voice Activity Detector. Similar to SoX implementation.
+
+ .. devices:: CPU CUDA
+
+ .. properties:: TorchScript
+
+ Attempts to trim silence and quiet background sounds from the ends of recordings of speech.
+ The algorithm currently uses a simple cepstral power measurement to detect voice,
+ so may be fooled by other things, especially music.
+
+ The effect can trim only from the front of the audio,
+ so in order to trim from the back, the reverse effect must also be used.
+
+ Args:
+ sample_rate (int): Sample rate of audio signal.
+ trigger_level (float, optional): The measurement level used to trigger activity detection.
+ This may need to be changed depending on the noise level, signal level,
+ and other characteristics of the input audio. (Default: 7.0)
+ trigger_time (float, optional): The time constant (in seconds)
+ used to help ignore short bursts of sound. (Default: 0.25)
+ search_time (float, optional): The amount of audio (in seconds)
+ to search for quieter/shorter bursts of audio to include prior
+ to the detected trigger point. (Default: 1.0)
+ allowed_gap (float, optional): The allowed gap (in seconds) between
+ quieter/shorter bursts of audio to include prior
+ to the detected trigger point. (Default: 0.25)
+ pre_trigger_time (float, optional): The amount of audio (in seconds) to preserve
+ before the trigger point and any found quieter/shorter bursts. (Default: 0.0)
+ boot_time (float, optional): The algorithm (internally) uses adaptive noise
+ estimation/reduction in order to detect the start of the wanted audio.
+ This option sets the time for the initial noise estimate. (Default: 0.35)
+ noise_up_time (float, optional): Time constant used by the adaptive noise estimator
+ for when the noise level is increasing. (Default: 0.1)
+ noise_down_time (float, optional): Time constant used by the adaptive noise estimator
+ for when the noise level is decreasing. (Default: 0.01)
+ noise_reduction_amount (float, optional): Amount of noise reduction to use in
+ the detection algorithm (e.g. 0, 0.5, ...). (Default: 1.35)
+ measure_freq (float, optional): Frequency of the algorithm’s
+ processing/measurements. (Default: 20.0)
+ measure_duration (float or None, optional): Measurement duration.
+ (Default: Twice the measurement period; i.e. with overlap.)
+ measure_smooth_time (float, optional): Time constant used to smooth
+ spectral measurements. (Default: 0.4)
+ hp_filter_freq (float, optional): "Brick-wall" frequency of high-pass filter applied
+ at the input to the detector algorithm. (Default: 50.0)
+ lp_filter_freq (float, optional): "Brick-wall" frequency of low-pass filter applied
+ at the input to the detector algorithm. (Default: 6000.0)
+ hp_lifter_freq (float, optional): "Brick-wall" frequency of high-pass lifter used
+ in the detector algorithm. (Default: 150.0)
+ lp_lifter_freq (float, optional): "Brick-wall" frequency of low-pass lifter used
+ in the detector algorithm. (Default: 2000.0)
+
+ Example
+ >>> waveform, sample_rate = torchaudio.load("test.wav", normalize=True)
+ >>> waveform_reversed, sample_rate = apply_effects_tensor(waveform, sample_rate, [["reverse"]])
+ >>> transform = transforms.Vad(sample_rate=sample_rate, trigger_level=7.5)
+ >>> waveform_reversed_front_trim = transform(waveform_reversed)
+ >>> waveform_end_trim, sample_rate = apply_effects_tensor(
+ >>> waveform_reversed_front_trim, sample_rate, [["reverse"]]
+ >>> )
+
+ Reference:
+ - http://sox.sourceforge.net/sox.html
+ """
+
+ def __init__(
+ self,
+ sample_rate: int,
+ trigger_level: float = 7.0,
+ trigger_time: float = 0.25,
+ search_time: float = 1.0,
+ allowed_gap: float = 0.25,
+ pre_trigger_time: float = 0.0,
+ boot_time: float = 0.35,
+ noise_up_time: float = 0.1,
+ noise_down_time: float = 0.01,
+ noise_reduction_amount: float = 1.35,
+ measure_freq: float = 20.0,
+ measure_duration: Optional[float] = None,
+ measure_smooth_time: float = 0.4,
+ hp_filter_freq: float = 50.0,
+ lp_filter_freq: float = 6000.0,
+ hp_lifter_freq: float = 150.0,
+ lp_lifter_freq: float = 2000.0,
+ ) -> None:
+ super().__init__()
+
+ self.sample_rate = sample_rate
+ self.trigger_level = trigger_level
+ self.trigger_time = trigger_time
+ self.search_time = search_time
+ self.allowed_gap = allowed_gap
+ self.pre_trigger_time = pre_trigger_time
+ self.boot_time = boot_time
+ self.noise_up_time = noise_up_time
+ self.noise_down_time = noise_down_time
+ self.noise_reduction_amount = noise_reduction_amount
+ self.measure_freq = measure_freq
+ self.measure_duration = measure_duration
+ self.measure_smooth_time = measure_smooth_time
+ self.hp_filter_freq = hp_filter_freq
+ self.lp_filter_freq = lp_filter_freq
+ self.hp_lifter_freq = hp_lifter_freq
+ self.lp_lifter_freq = lp_lifter_freq
+
+ def forward(self, waveform: Tensor) -> Tensor:
+ r"""
+ Args:
+ waveform (Tensor): Tensor of audio of dimension `(channels, time)` or `(time)`
+ Tensor of shape `(channels, time)` is treated as a multi-channel recording
+ of the same event and the resulting output will be trimmed to the earliest
+ voice activity in any channel.
+ """
+ return F.vad(
+ waveform=waveform,
+ sample_rate=self.sample_rate,
+ trigger_level=self.trigger_level,
+ trigger_time=self.trigger_time,
+ search_time=self.search_time,
+ allowed_gap=self.allowed_gap,
+ pre_trigger_time=self.pre_trigger_time,
+ boot_time=self.boot_time,
+ noise_up_time=self.noise_up_time,
+ noise_down_time=self.noise_down_time,
+ noise_reduction_amount=self.noise_reduction_amount,
+ measure_freq=self.measure_freq,
+ measure_duration=self.measure_duration,
+ measure_smooth_time=self.measure_smooth_time,
+ hp_filter_freq=self.hp_filter_freq,
+ lp_filter_freq=self.lp_filter_freq,
+ hp_lifter_freq=self.hp_lifter_freq,
+ lp_lifter_freq=self.lp_lifter_freq,
+ )
+
+
+class SpectralCentroid(torch.nn.Module):
+ r"""Compute the spectral centroid for each channel along the time axis.
+
+ .. devices:: CPU CUDA
+
+ .. properties:: Autograd TorchScript
+
+ The spectral centroid is defined as the average of the frequency values,
+ weighted by their magnitudes.
+
+ Args:
+ sample_rate (int): Sample rate of audio signal.
+ n_fft (int, optional): Size of FFT, creates ``n_fft // 2 + 1`` bins. (Default: ``400``)
+ win_length (int or None, optional): Window size. (Default: ``n_fft``)
+ hop_length (int or None, optional): Length of hop between STFT windows. (Default: ``win_length // 2``)
+ pad (int, optional): Two sided padding of signal. (Default: ``0``)
+ window_fn (Callable[..., Tensor], optional): A function to create a window tensor
+ that is applied/multiplied to each frame/window. (Default: ``torch.hann_window``)
+ wkwargs (dict or None, optional): Arguments for window function. (Default: ``None``)
+
+ Example
+ >>> waveform, sample_rate = torchaudio.load("test.wav", normalize=True)
+ >>> transform = transforms.SpectralCentroid(sample_rate)
+ >>> spectral_centroid = transform(waveform) # (channel, time)
+ """
+ __constants__ = ["sample_rate", "n_fft", "win_length", "hop_length", "pad"]
+
+ def __init__(
+ self,
+ sample_rate: int,
+ n_fft: int = 400,
+ win_length: Optional[int] = None,
+ hop_length: Optional[int] = None,
+ pad: int = 0,
+ window_fn: Callable[..., Tensor] = torch.hann_window,
+ wkwargs: Optional[dict] = None,
+ ) -> None:
+ super(SpectralCentroid, self).__init__()
+ self.sample_rate = sample_rate
+ self.n_fft = n_fft
+ self.win_length = win_length if win_length is not None else n_fft
+ self.hop_length = hop_length if hop_length is not None else self.win_length // 2
+ window = window_fn(self.win_length) if wkwargs is None else window_fn(self.win_length, **wkwargs)
+ self.register_buffer("window", window)
+ self.pad = pad
+
+ def forward(self, waveform: Tensor) -> Tensor:
+ r"""
+ Args:
+ waveform (Tensor): Tensor of audio of dimension `(..., time)`.
+
+ Returns:
+ Tensor: Spectral Centroid of size `(..., time)`.
+ """
+
+ return F.spectral_centroid(
+ waveform, self.sample_rate, self.pad, self.window, self.n_fft, self.hop_length, self.win_length
+ )
+
+
+class PitchShift(LazyModuleMixin, torch.nn.Module):
+ r"""Shift the pitch of a waveform by ``n_steps`` steps.
+
+ .. devices:: CPU CUDA
+
+ .. properties:: TorchScript
+
+ Args:
+ sample_rate (int): Sample rate of the input waveform.
+ n_steps (int): The (fractional) steps to shift the waveform.
+ bins_per_octave (int, optional): The number of steps per octave (Default : ``12``).
+ n_fft (int, optional): Size of FFT, creates ``n_fft // 2 + 1`` bins (Default: ``512``).
+ win_length (int or None, optional): Window size. If None, then ``n_fft`` is used. (Default: ``None``).
+ hop_length (int or None, optional): Length of hop between STFT windows. If None, then ``win_length // 4``
+ is used (Default: ``None``).
+ window_fn (Callable[..., Tensor], optional): A function to create a window tensor
+ that is applied/multiplied to each frame/window. (Default: ``torch.hann_window``)
+ wkwargs (dict or None, optional): Arguments for window function. (Default: ``None``)
+
+ Example
+ >>> waveform, sample_rate = torchaudio.load("test.wav", normalize=True)
+ >>> transform = transforms.PitchShift(sample_rate, 4)
+ >>> waveform_shift = transform(waveform) # (channel, time)
+ """
+ __constants__ = ["sample_rate", "n_steps", "bins_per_octave", "n_fft", "win_length", "hop_length"]
+
+ kernel: UninitializedParameter
+ width: int
+
+ def __init__(
+ self,
+ sample_rate: int,
+ n_steps: int,
+ bins_per_octave: int = 12,
+ n_fft: int = 512,
+ win_length: Optional[int] = None,
+ hop_length: Optional[int] = None,
+ window_fn: Callable[..., Tensor] = torch.hann_window,
+ wkwargs: Optional[dict] = None,
+ ) -> None:
+ super().__init__()
+ self.n_steps = n_steps
+ self.bins_per_octave = bins_per_octave
+ self.sample_rate = sample_rate
+ self.n_fft = n_fft
+ self.win_length = win_length if win_length is not None else n_fft
+ self.hop_length = hop_length if hop_length is not None else self.win_length // 4
+ window = window_fn(self.win_length) if wkwargs is None else window_fn(self.win_length, **wkwargs)
+ self.register_buffer("window", window)
+ rate = 2.0 ** (-float(n_steps) / bins_per_octave)
+ self.orig_freq = int(sample_rate / rate)
+ self.gcd = math.gcd(int(self.orig_freq), int(sample_rate))
+
+ if self.orig_freq != sample_rate:
+ self.width = -1
+ self.kernel = UninitializedParameter(device=None, dtype=None)
+
+ def initialize_parameters(self, input):
+ if self.has_uninitialized_params():
+ if self.orig_freq != self.sample_rate:
+ with torch.no_grad():
+ kernel, self.width = _get_sinc_resample_kernel(
+ self.orig_freq,
+ self.sample_rate,
+ self.gcd,
+ dtype=input.dtype,
+ device=input.device,
+ )
+ self.kernel.materialize(kernel.shape)
+ self.kernel.copy_(kernel)
+
+ def forward(self, waveform: Tensor) -> Tensor:
+ r"""
+ Args:
+ waveform (Tensor): Tensor of audio of dimension `(..., time)`.
+
+ Returns:
+ Tensor: The pitch-shifted audio of shape `(..., time)`.
+ """
+ shape = waveform.size()
+
+ waveform_stretch = _stretch_waveform(
+ waveform,
+ self.n_steps,
+ self.bins_per_octave,
+ self.n_fft,
+ self.win_length,
+ self.hop_length,
+ self.window,
+ )
+
+ if self.orig_freq != self.sample_rate:
+ waveform_shift = _apply_sinc_resample_kernel(
+ waveform_stretch,
+ self.orig_freq,
+ self.sample_rate,
+ self.gcd,
+ self.kernel,
+ self.width,
+ )
+ else:
+ waveform_shift = waveform_stretch
+
+ return _fix_waveform_shape(
+ waveform_shift,
+ shape,
+ )
+
+
+class RNNTLoss(torch.nn.Module):
+ """Compute the RNN Transducer loss from *Sequence Transduction with Recurrent Neural Networks*
+ :cite:`graves2012sequence`.
+
+ .. devices:: CPU CUDA
+
+ .. properties:: Autograd TorchScript
+
+ The RNN Transducer loss extends the CTC loss by defining a distribution over output
+ sequences of all lengths, and by jointly modelling both input-output and output-output
+ dependencies.
+
+ Args:
+ blank (int, optional): blank label (Default: ``-1``)
+ clamp (float, optional): clamp for gradients (Default: ``-1``)
+ reduction (string, optional): Specifies the reduction to apply to the output:
+ ``"none"`` | ``"mean"`` | ``"sum"``. (Default: ``"mean"``)
+ fused_log_softmax (bool): set to False if calling log_softmax outside of loss (Default: ``True``)
+
+ Example
+ >>> # Hypothetical values
+ >>> logits = torch.tensor([[[[0.1, 0.6, 0.1, 0.1, 0.1],
+ >>> [0.1, 0.1, 0.6, 0.1, 0.1],
+ >>> [0.1, 0.1, 0.2, 0.8, 0.1]],
+ >>> [[0.1, 0.6, 0.1, 0.1, 0.1],
+ >>> [0.1, 0.1, 0.2, 0.1, 0.1],
+ >>> [0.7, 0.1, 0.2, 0.1, 0.1]]]],
+ >>> dtype=torch.float32,
+ >>> requires_grad=True)
+ >>> targets = torch.tensor([[1, 2]], dtype=torch.int)
+ >>> logit_lengths = torch.tensor([2], dtype=torch.int)
+ >>> target_lengths = torch.tensor([2], dtype=torch.int)
+ >>> transform = transforms.RNNTLoss(blank=0)
+ >>> loss = transform(logits, targets, logit_lengths, target_lengths)
+ >>> loss.backward()
+ """
+
+ def __init__(
+ self,
+ blank: int = -1,
+ clamp: float = -1.0,
+ reduction: str = "mean",
+ fused_log_softmax: bool = True,
+ ):
+ super().__init__()
+ self.blank = blank
+ self.clamp = clamp
+ self.reduction = reduction
+ self.fused_log_softmax = fused_log_softmax
+
+ def forward(
+ self,
+ logits: Tensor,
+ targets: Tensor,
+ logit_lengths: Tensor,
+ target_lengths: Tensor,
+ ):
+ """
+ Args:
+ logits (Tensor): Tensor of dimension `(batch, max seq length, max target length + 1, class)`
+ containing output from joiner
+ targets (Tensor): Tensor of dimension `(batch, max target length)` containing zero-padded targets
+ logit_lengths (Tensor): Tensor of dimension `(batch)` containing lengths of each sequence from encoder
+ target_lengths (Tensor): Tensor of dimension `(batch)` containing lengths of targets for each sequence
+ Returns:
+ Tensor: Loss with the reduction option applied. If ``reduction`` is ``"none"``, then size (batch),
+ otherwise scalar.
+ """
+ return F.rnnt_loss(
+ logits,
+ targets,
+ logit_lengths,
+ target_lengths,
+ self.blank,
+ self.clamp,
+ self.reduction,
+ self.fused_log_softmax,
+ )
+
+
+class Convolve(torch.nn.Module):
+ r"""
+ Convolves inputs along their last dimension using the direct method.
+ Note that, in contrast to :class:`torch.nn.Conv1d`, which actually applies the valid cross-correlation
+ operator, this module applies the true `convolution`_ operator.
+
+ .. devices:: CPU CUDA
+
+ .. properties:: Autograd TorchScript
+
+ Args:
+ mode (str, optional): Must be one of ("full", "valid", "same").
+
+ * "full": Returns the full convolution result, with shape `(..., N + M - 1)`, where
+ `N` and `M` are the trailing dimensions of the two inputs. (Default)
+ * "valid": Returns the segment of the full convolution result corresponding to where
+ the two inputs overlap completely, with shape `(..., max(N, M) - min(N, M) + 1)`.
+ * "same": Returns the center segment of the full convolution result, with shape `(..., N)`.
+
+ .. _convolution:
+ https://en.wikipedia.org/wiki/Convolution
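+
+ Example
+ >>> # Illustrative shapes; leading dimensions of the two operands must be broadcast-able.
+ >>> x = torch.randn(2, 16000)
+ >>> y = torch.randn(1, 128)
+ >>> convolve = torchaudio.transforms.Convolve(mode="same")
+ >>> z = convolve(x, y) # (2, 16000)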
+ """
+
+ def __init__(self, mode: str = "full") -> None:
+ _check_convolve_mode(mode)
+
+ super().__init__()
+ self.mode = mode
+
+ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+ r"""
+ Args:
+ x (torch.Tensor): First convolution operand, with shape `(..., N)`.
+ y (torch.Tensor): Second convolution operand, with shape `(..., M)`
+ (leading dimensions must be broadcast-able with those of ``x``).
+
+ Returns:
+ torch.Tensor: Result of convolving ``x`` and ``y``, with shape `(..., L)`, where
+ the leading dimensions match those of ``x`` and `L` is dictated by ``mode``.
+ """
+ return F.convolve(x, y, mode=self.mode)
+
+
+class FFTConvolve(torch.nn.Module):
+ r"""
+ Convolves inputs along their last dimension using FFT. For inputs with large last dimensions, this module
+ is generally much faster than :class:`Convolve`.
+ Note that, in contrast to :class:`torch.nn.Conv1d`, which actually applies the valid cross-correlation
+ operator, this module applies the true `convolution`_ operator.
+ Also note that this module can only output float tensors (int tensor inputs will be cast to float).
+
+ .. devices:: CPU CUDA
+
+ .. properties:: Autograd TorchScript
+
+ Args:
+ mode (str, optional): Must be one of ("full", "valid", "same").
+
+ * "full": Returns the full convolution result, with shape `(..., N + M - 1)`, where
+ `N` and `M` are the trailing dimensions of the two inputs. (Default)
+ * "valid": Returns the segment of the full convolution result corresponding to where
+ the two inputs overlap completely, with shape `(..., max(N, M) - min(N, M) + 1)`.
+ * "same": Returns the center segment of the full convolution result, with shape `(..., N)`.
+
+ .. _convolution:
+ https://en.wikipedia.org/wiki/Convolution
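+
+ Example
+ >>> # Illustrative shapes; "full" mode returns length N + M - 1 along the last dimension.
+ >>> x = torch.randn(2, 16000)
+ >>> y = torch.randn(1, 128)
+ >>> fftconvolve = torchaudio.transforms.FFTConvolve(mode="full")
+ >>> z = fftconvolve(x, y) # (2, 16127)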
+ """
+
+ def __init__(self, mode: str = "full") -> None:
+ _check_convolve_mode(mode)
+
+ super().__init__()
+ self.mode = mode
+
+ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+ r"""
+ Args:
+ x (torch.Tensor): First convolution operand, with shape `(..., N)`.
+ y (torch.Tensor): Second convolution operand, with shape `(..., M)`
+ (leading dimensions must be broadcast-able with those of ``x``).
+
+ Returns:
+ torch.Tensor: Result of convolving ``x`` and ``y``, with shape `(..., L)`, where
+ the leading dimensions match those of ``x`` and `L` is dictated by ``mode``.
+ """
+ return F.fftconvolve(x, y, mode=self.mode)
+
+
+def _source_target_sample_rate(orig_freq: int, speed: float) -> Tuple[int, int]:
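+ # For example (illustrative): orig_freq=16000, speed=1.1 gives (17600, 16000);
+ # dividing by gcd=1600 yields (11, 10), i.e. resample from a relative rate of 11 down to 10.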
+ source_sample_rate = int(speed * orig_freq)
+ target_sample_rate = int(orig_freq)
+ gcd = math.gcd(source_sample_rate, target_sample_rate)
+ return source_sample_rate // gcd, target_sample_rate // gcd
+
+
+class Speed(torch.nn.Module):
+ r"""Adjusts waveform speed.
+
+ .. devices:: CPU CUDA
+
+ .. properties:: Autograd TorchScript
+
+ Args:
+ orig_freq (int): Original frequency of the signals in ``waveform``.
+ factor (float): Factor by which to adjust speed of input. Values greater than 1.0
+ compress ``waveform`` in time, whereas values less than 1.0 stretch ``waveform`` in time.
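+
+ Example
+ >>> waveform, sample_rate = torchaudio.load("test.wav", normalize=True)
+ >>> speed = torchaudio.transforms.Speed(orig_freq=sample_rate, factor=1.1) # illustrative factor: 10% faster
+ >>> sped_up_waveform, _ = speed(waveform)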
+ """
+
+ def __init__(self, orig_freq, factor) -> None:
+ super().__init__()
+
+ self.orig_freq = orig_freq
+ self.factor = factor
+
+ self.source_sample_rate, self.target_sample_rate = _source_target_sample_rate(orig_freq, factor)
+ self.resampler = Resample(orig_freq=self.source_sample_rate, new_freq=self.target_sample_rate)
+
+ def forward(self, waveform, lengths: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+ r"""
+ Args:
+ waveform (torch.Tensor): Input signals, with shape `(..., time)`.
+ lengths (torch.Tensor or None, optional): Valid lengths of signals in ``waveform``, with shape `(...)`.
+ If ``None``, all elements in ``waveform`` are treated as valid. (Default: ``None``)
+
+ Returns:
+ (torch.Tensor, torch.Tensor or None):
+ torch.Tensor
+ Speed-adjusted waveform, with shape `(..., new_time)`.
+ torch.Tensor or None
+ If ``lengths`` is not ``None``, valid lengths of signals in speed-adjusted waveform,
+ with shape `(...)`; otherwise, ``None``.
+ """
+
+ if lengths is None:
+ out_lengths = None
+ else:
+ out_lengths = torch.ceil(lengths * self.target_sample_rate / self.source_sample_rate).to(lengths.dtype)
+
+ return self.resampler(waveform), out_lengths
+
+
+class SpeedPerturbation(torch.nn.Module):
+ r"""Applies the speed perturbation augmentation introduced in
+ *Audio augmentation for speech recognition* :cite:`ko15_interspeech`. For a given input,
+ the module samples a speed-up factor from ``factors`` uniformly at random and adjusts
+ the speed of the input by that factor.
+
+ .. devices:: CPU CUDA
+
+ .. properties:: Autograd TorchScript
+
+ Args:
+ orig_freq (int): Original frequency of the signals in ``waveform``.
+ factors (Sequence[float]): Factors by which to adjust speed of input. Values greater than 1.0
+ compress ``waveform`` in time, whereas values less than 1.0 stretch ``waveform`` in time.
+
+ Example
+ >>> speed_perturb = SpeedPerturbation(16000, [0.9, 1.1, 1.0, 1.0, 1.0])
+ >>> # waveform speed will be adjusted by factor 0.9 with 20% probability,
+ >>> # 1.1 with 20% probability, and 1.0 (i.e. kept the same) with 60% probability.
+ >>> speed_perturbed_waveform = speed_perturb(waveform, lengths)
+ """
+
+ def __init__(self, orig_freq: int, factors: Sequence[float]) -> None:
+ super().__init__()
+
+ self.speeders = torch.nn.ModuleList([Speed(orig_freq=orig_freq, factor=factor) for factor in factors])
+
+ def forward(
+ self, waveform: torch.Tensor, lengths: Optional[torch.Tensor] = None
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+ r"""
+ Args:
+ waveform (torch.Tensor): Input signals, with shape `(..., time)`.
+ lengths (torch.Tensor or None, optional): Valid lengths of signals in ``waveform``, with shape `(...)`.
+ If ``None``, all elements in ``waveform`` are treated as valid. (Default: ``None``)
+
+ Returns:
+ (torch.Tensor, torch.Tensor or None):
+ torch.Tensor
+ Speed-adjusted waveform, with shape `(..., new_time)`.
+ torch.Tensor or None
+ If ``lengths`` is not ``None``, valid lengths of signals in speed-adjusted waveform,
+ with shape `(...)`; otherwise, ``None``.
+ """
+
+ idx = int(torch.randint(len(self.speeders), ()))
+ # NOTE: we do this because TorchScript doesn't allow for
+ # indexing ModuleList instances with non-literals.
+ for speeder_idx, speeder in enumerate(self.speeders):
+ if idx == speeder_idx:
+ return speeder(waveform, lengths)
+ raise RuntimeError("Speeder not found; execution should have never reached here.")
+
+
+class AddNoise(torch.nn.Module):
+ r"""Scales and adds noise to waveform per signal-to-noise ratio.
+ See :meth:`torchaudio.functional.add_noise` for more details.
+
+ .. devices:: CPU CUDA
+
+ .. properties:: Autograd TorchScript
+ """
+
+ def forward(
+ self, waveform: torch.Tensor, noise: torch.Tensor, snr: torch.Tensor, lengths: Optional[torch.Tensor] = None
+ ) -> torch.Tensor:
+ r"""
+ Args:
+ waveform (torch.Tensor): Input waveform, with shape `(..., L)`.
+ noise (torch.Tensor): Noise, with shape `(..., L)` (same shape as ``waveform``).
+ snr (torch.Tensor): Signal-to-noise ratios in dB, with shape `(...,)`.
+ lengths (torch.Tensor or None, optional): Valid lengths of signals in ``waveform`` and ``noise``,
+ with shape `(...,)` (leading dimensions must match those of ``waveform``). If ``None``, all
+ elements in ``waveform`` and ``noise`` are treated as valid. (Default: ``None``)
+
+ Returns:
+ torch.Tensor: Result of scaling and adding ``noise`` to ``waveform``, with shape `(..., L)`
+ (same shape as ``waveform``).
+ """
+ return F.add_noise(waveform, noise, snr, lengths)
+
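A minimal usage sketch of ``AddNoise``, assuming it is exported as ``torchaudio.transforms.AddNoise``; the SNR tensor carries one value in dB per leading-dimension element, as described in the docstring.

    import torch
    from torchaudio.transforms import AddNoise

    add_noise = AddNoise()
    waveform = torch.randn(2, 16000)         # (batch, time)
    noise = torch.randn(2, 16000)            # same shape as waveform
    snr = torch.tensor([10.0, 3.0])          # per-example SNR in dB
    noisy = add_noise(waveform, noise, snr)  # noise is rescaled so each example hits its SNR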
+
+class Preemphasis(torch.nn.Module):
+ r"""Pre-emphasizes a waveform along its last dimension.
+ See :meth:`torchaudio.functional.preemphasis` for more details.
+
+ .. devices:: CPU CUDA
+
+ .. properties:: Autograd TorchScript
+
+ Args:
+ coeff (float, optional): Pre-emphasis coefficient. Typically between 0.0 and 1.0.
+ (Default: 0.97)
+ """
+
+ def __init__(self, coeff: float = 0.97) -> None:
+ super().__init__()
+ self.coeff = coeff
+
+ def forward(self, waveform: torch.Tensor) -> torch.Tensor:
+ r"""
+ Args:
+ waveform (torch.Tensor): Waveform, with shape `(..., N)`.
+
+ Returns:
+ torch.Tensor: Pre-emphasized waveform, with shape `(..., N)`.
+ """
+ return F.preemphasis(waveform, coeff=self.coeff)
+
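For intuition, a hand-rolled sketch of the filter that ``F.preemphasis`` applies, assuming the usual definition ``y[n] = x[n] - coeff * x[n - 1]`` with ``y[0] = x[0]`` (the exact boundary handling is an assumption here, not read off this diff).

    import torch

    x = torch.tensor([1.0, 2.0, 3.0, 4.0])
    coeff = 0.97
    y = x.clone()
    y[1:] = x[1:] - coeff * x[:-1]  # -> tensor([1.0000, 1.0300, 1.0600, 1.0900])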
+
+class Deemphasis(torch.nn.Module):
+ r"""De-emphasizes a waveform along its last dimension.
+ See :meth:`torchaudio.functional.deemphasis` for more details.
+
+ .. devices:: CPU CUDA
+
+ .. properties:: Autograd TorchScript
+
+ Args:
+ coeff (float, optional): De-emphasis coefficient. Typically between 0.0 and 1.0.
+ (Default: 0.97)
+ """
+
+ def __init__(self, coeff: float = 0.97) -> None:
+ super().__init__()
+ self.coeff = coeff
+
+ def forward(self, waveform: torch.Tensor) -> torch.Tensor:
+ r"""
+ Args:
+ waveform (torch.Tensor): Waveform, with shape `(..., N)`.
+
+ Returns:
+ torch.Tensor: De-emphasized waveform, with shape `(..., N)`.
+ """
+ return F.deemphasis(waveform, coeff=self.coeff)
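Assuming ``F.deemphasis`` implements the inverse recursion ``y[n] = x[n] + coeff * y[n - 1]``, the two modules above should round-trip a signal up to numerical error; a small sketch:

    import torch
    from torchaudio.transforms import Deemphasis, Preemphasis

    x = torch.randn(1, 1000)
    pre, de = Preemphasis(coeff=0.97), Deemphasis(coeff=0.97)
    roundtrip = de(pre(x))
    print(torch.allclose(roundtrip, x, atol=1e-4))  # expected: True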
diff --git a/MLPY/Lib/site-packages/torchaudio/utils/__init__.py b/MLPY/Lib/site-packages/torchaudio/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0aff17a50e2fb49176c884e5d4b11970ae76cbab
--- /dev/null
+++ b/MLPY/Lib/site-packages/torchaudio/utils/__init__.py
@@ -0,0 +1,11 @@
+from torio.utils import ffmpeg_utils
+
+from . import sox_utils
+from .download import download_asset
+
+
+__all__ = [
+ "download_asset",
+ "sox_utils",
+ "ffmpeg_utils",
+]
diff --git a/MLPY/Lib/site-packages/torchaudio/utils/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torchaudio/utils/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..96329bad617c1280eaa84e2c0afec84e098d6408
Binary files /dev/null and b/MLPY/Lib/site-packages/torchaudio/utils/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torchaudio/utils/__pycache__/download.cpython-39.pyc b/MLPY/Lib/site-packages/torchaudio/utils/__pycache__/download.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fbcc7d931832e3c0eb473946d2768b3d1bae427d
Binary files /dev/null and b/MLPY/Lib/site-packages/torchaudio/utils/__pycache__/download.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torchaudio/utils/__pycache__/ffmpeg_utils.cpython-39.pyc b/MLPY/Lib/site-packages/torchaudio/utils/__pycache__/ffmpeg_utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..70834785c266eb66e74c74142cdfd4c6826f96ae
Binary files /dev/null and b/MLPY/Lib/site-packages/torchaudio/utils/__pycache__/ffmpeg_utils.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torchaudio/utils/__pycache__/sox_utils.cpython-39.pyc b/MLPY/Lib/site-packages/torchaudio/utils/__pycache__/sox_utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8ce5766281b495cf2d54492313186023a7bf72fc
Binary files /dev/null and b/MLPY/Lib/site-packages/torchaudio/utils/__pycache__/sox_utils.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torchaudio/utils/download.py b/MLPY/Lib/site-packages/torchaudio/utils/download.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f62c7ef1f56ba0ee25888e5fc14dcb2c665ba6a
--- /dev/null
+++ b/MLPY/Lib/site-packages/torchaudio/utils/download.py
@@ -0,0 +1,89 @@
+import hashlib
+import logging
+from os import PathLike
+from pathlib import Path
+from typing import Union
+
+import torch
+from torchaudio._internal import download_url_to_file
+
+_LG = logging.getLogger(__name__)
+
+
+def _get_local_path(key):
+ path = Path(torch.hub.get_dir()) / "torchaudio" / Path(key)
+ path.parent.mkdir(parents=True, exist_ok=True)
+ return path
+
+
+def _download(key, path, progress):
+ url = f"https://download.pytorch.org/torchaudio/{key}"
+ download_url_to_file(url, path, progress=progress)
+
+
+def _get_hash(path, hash, chunk_size=1028):
+ m = hashlib.sha256()
+ with open(path, "rb") as file:
+ data = file.read(chunk_size)
+ while data:
+ m.update(data)
+ data = file.read(chunk_size)
+ return m.hexdigest()
+
+
+def download_asset(
+ key: str,
+ hash: str = "",
+ path: Union[str, PathLike] = "",
+ *,
+ progress: bool = True,
+) -> str:
+ """Download and store torchaudio assets to local file system.
+
+ If a file already exists at the download path, the download is skipped and that path
+ is returned; hash validation is still performed when ``hash`` is provided.
+
+ Args:
+ key (str): The asset identifier.
+ hash (str, optional):
+ The value of SHA256 hash of the asset. If provided, it is used to verify
+ the downloaded / cached object. If not provided, then no hash validation
+ is performed. This means if a file exists at the download path, then the path
+ is returned as-is without verifying the identity of the file.
+ path (path-like object, optional):
+ By default, the downloaded asset is saved in a directory under
+ :py:func:`torch.hub.get_dir` and intermediate directories based on the given `key`
+ are created.
+ This argument can be used to overwrite the target location.
+ When this argument is provided, all the intermediate directories have to be
+ created beforehand.
+ progress (bool): Whether to show progress bar for downloading. Default: ``True``.
+
+ Note:
+ Currently, the valid key values are the routes on ``download.pytorch.org/torchaudio``,
+ but this is an implementation detail.
+
+ Returns:
+ str: The path to the asset on the local file system.
+ """
+ path = path or _get_local_path(key)
+
+ if path.exists():
+ _LG.info("The local file (%s) exists. Skipping the download.", path)
+ else:
+ _LG.info("Downloading %s to %s", key, path)
+ _download(key, path, progress=progress)
+
+ if hash:
+ _LG.info("Verifying the hash value.")
+ digest = _get_hash(path, hash)
+
+ if digest != hash:
+ raise ValueError(
+ f"The hash value of the downloaded file ({path}), '{digest}' does not match "
+ f"the provided hash value, '{hash}'."
+ )
+
+ _LG.info("Hash validated.")
+
+ return str(path)
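A usage sketch of ``download_asset``; the key below is one used by the torchaudio tutorials and is assumed to still be hosted under ``download.pytorch.org/torchaudio``.

    from torchaudio.utils import download_asset

    path = download_asset("tutorial-assets/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav")
    print(path)  # a cached file under <torch.hub.get_dir()>/torchaudio/tutorial-assets/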
diff --git a/MLPY/Lib/site-packages/torchaudio/utils/ffmpeg_utils.py b/MLPY/Lib/site-packages/torchaudio/utils/ffmpeg_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..8a4bb3c4e3b621ff7b48062d1c4d3374a4459c90
--- /dev/null
+++ b/MLPY/Lib/site-packages/torchaudio/utils/ffmpeg_utils.py
@@ -0,0 +1,11 @@
+"""Module to change the configuration of FFmpeg libraries (such as libavformat).
+
+It affects functionalities in :py:mod:`torchaudio.io` (and indirectly :py:func:`torchaudio.load`).
+"""
+
+
+# This file is just for BC.
+def __getattr__(item):
+ from torio.utils import ffmpeg_utils
+
+ return getattr(ffmpeg_utils, item)
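The shim above relies on PEP 562 module-level ``__getattr__``: every attribute lookup is forwarded to ``torio.utils.ffmpeg_utils`` at access time. A small check, assuming ``get_versions`` exists on the torio side as in current releases:

    from torchaudio.utils import ffmpeg_utils
    from torio.utils import ffmpeg_utils as torio_ffmpeg_utils

    # both names resolve to the very same function object
    assert ffmpeg_utils.get_versions is torio_ffmpeg_utils.get_versions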
diff --git a/MLPY/Lib/site-packages/torchaudio/utils/sox_utils.py b/MLPY/Lib/site-packages/torchaudio/utils/sox_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..8975b4216f54e3ece63483bf91b49f10385f5785
--- /dev/null
+++ b/MLPY/Lib/site-packages/torchaudio/utils/sox_utils.py
@@ -0,0 +1,99 @@
+"""Module to change the configuration of libsox, which is used by I/O functions like
+:py:mod:`~torchaudio.backend.sox_io_backend` and :py:mod:`~torchaudio.sox_effects`.
+"""
+
+from typing import Dict, List
+
+import torchaudio
+
+sox_ext = torchaudio._extension.lazy_import_sox_ext()
+
+
+def set_seed(seed: int):
+ """Set libsox's PRNG
+
+ Args:
+ seed (int): Seed value. Valid range is int32.
+
+ See Also:
+ http://sox.sourceforge.net/sox.html
+ """
+ sox_ext.set_seed(seed)
+
+
+def set_verbosity(verbosity: int):
+ """Set libsox's verbosity
+
+ Args:
+ verbosity (int): Set verbosity level of libsox.
+
+ * ``1`` failure messages
+ * ``2`` warnings
+ * ``3`` details of processing
+ * ``4``-``6`` increasing levels of debug messages
+
+ See Also:
+ http://sox.sourceforge.net/sox.html
+ """
+ sox_ext.set_verbosity(verbosity)
+
+
+def set_buffer_size(buffer_size: int):
+ """Set buffer size for sox effect chain
+
+ Args:
+ buffer_size (int): Set the size in bytes of the buffers used for processing audio.
+
+ See Also:
+ http://sox.sourceforge.net/sox.html
+ """
+ sox_ext.set_buffer_size(buffer_size)
+
+
+def set_use_threads(use_threads: bool):
+ """Set multithread option for sox effect chain
+
+ Args:
+ use_threads (bool): When ``True``, enables ``libsox``'s parallel effects channels processing.
+ To use multithreading, the underlying ``libsox`` has to be compiled with OpenMP support.
+
+ See Also:
+ http://sox.sourceforge.net/sox.html
+ """
+ sox_ext.set_use_threads(use_threads)
+
+
+def list_effects() -> Dict[str, str]:
+ """List the available sox effect names
+
+ Returns:
+ Dict[str, str]: Mapping from ``effect name`` to ``usage``
+ """
+ return dict(sox_ext.list_effects())
+
+
+def list_read_formats() -> List[str]:
+ """List the supported audio formats for read
+
+ Returns:
+ List[str]: List of supported audio formats
+ """
+ return sox_ext.list_read_formats()
+
+
+def list_write_formats() -> List[str]:
+ """List the supported audio formats for write
+
+ Returns:
+ List[str]: List of supported audio formats
+ """
+ return sox_ext.list_write_formats()
+
+
+def get_buffer_size() -> int:
+ """Get buffer size for sox effect chain
+
+ Returns:
+ int: Size in bytes of the buffers used for processing audio.
+ """
+ return sox_ext.get_buffer_size()
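A usage sketch of the helpers above; this assumes a torchaudio build in which the sox extension is available (it is typically not shipped on Windows).

    from torchaudio.utils import sox_utils

    sox_utils.set_verbosity(1)                   # failure messages only
    print(sox_utils.get_buffer_size())           # current buffer size in bytes
    print(sorted(sox_utils.list_effects())[:5])  # a few of the available effect names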
diff --git a/MLPY/Lib/site-packages/torchaudio/version.py b/MLPY/Lib/site-packages/torchaudio/version.py
new file mode 100644
index 0000000000000000000000000000000000000000..986b1484a35bfb5827cb008f95365b37742d5b01
--- /dev/null
+++ b/MLPY/Lib/site-packages/torchaudio/version.py
@@ -0,0 +1,2 @@
+__version__ = '2.3.1+cpu'
+git_version = '3edcf69e78a3c9a3077a11159861422440ec7d4a'
diff --git a/MLPY/Lib/site-packages/torchgen/__init__.py b/MLPY/Lib/site-packages/torchgen/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a8b61af2c4b58ff14ab7e3b24bf22e8ec6a95da0
--- /dev/null
+++ b/MLPY/Lib/site-packages/torchgen/__init__.py
@@ -0,0 +1,10 @@
+"""torchgen
+
+This module contains codegeneration utilities for PyTorch. It is used to
+build PyTorch from source, but may also be used for out-of-tree projects
+that extend PyTorch.
+
+Note well that we provide no BC guarantees for torchgen. If you're interested
+in using torchgen and want the PyTorch team to be aware, please reach out
+on GitHub.
+"""
diff --git a/MLPY/Lib/site-packages/torchgen/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torchgen/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0f82e03da7a1bbcbd6d63d059c3eee7a5f7a15f3
Binary files /dev/null and b/MLPY/Lib/site-packages/torchgen/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torchgen/__pycache__/code_template.cpython-39.pyc b/MLPY/Lib/site-packages/torchgen/__pycache__/code_template.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7bae7cf86389ef923393e40238a8b319b57782f8
Binary files /dev/null and b/MLPY/Lib/site-packages/torchgen/__pycache__/code_template.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torchgen/__pycache__/context.cpython-39.pyc b/MLPY/Lib/site-packages/torchgen/__pycache__/context.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ca68294e00b4e9596b0326f739ef247138b4da35
Binary files /dev/null and b/MLPY/Lib/site-packages/torchgen/__pycache__/context.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torchgen/__pycache__/gen.cpython-39.pyc b/MLPY/Lib/site-packages/torchgen/__pycache__/gen.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b35b2c4d863d0a8eaaad56c19df4d890a73e1c58
Binary files /dev/null and b/MLPY/Lib/site-packages/torchgen/__pycache__/gen.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torchgen/__pycache__/gen_aoti_c_shim.cpython-39.pyc b/MLPY/Lib/site-packages/torchgen/__pycache__/gen_aoti_c_shim.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..79714eccfd7dd56ef5541fd77992c97383f71923
Binary files /dev/null and b/MLPY/Lib/site-packages/torchgen/__pycache__/gen_aoti_c_shim.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torchgen/__pycache__/gen_backend_stubs.cpython-39.pyc b/MLPY/Lib/site-packages/torchgen/__pycache__/gen_backend_stubs.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9135cfa518f47e2ec2ca17a880d0bf35def33a59
Binary files /dev/null and b/MLPY/Lib/site-packages/torchgen/__pycache__/gen_backend_stubs.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torchgen/__pycache__/gen_executorch.cpython-39.pyc b/MLPY/Lib/site-packages/torchgen/__pycache__/gen_executorch.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3db6e22165c607a059580ce12d0f32867d96b3f8
Binary files /dev/null and b/MLPY/Lib/site-packages/torchgen/__pycache__/gen_executorch.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torchgen/__pycache__/gen_functionalization_type.cpython-39.pyc b/MLPY/Lib/site-packages/torchgen/__pycache__/gen_functionalization_type.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f828ff7edf70e3dc964624f8538f630efc6e052e
Binary files /dev/null and b/MLPY/Lib/site-packages/torchgen/__pycache__/gen_functionalization_type.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torchgen/__pycache__/gen_lazy_tensor.cpython-39.pyc b/MLPY/Lib/site-packages/torchgen/__pycache__/gen_lazy_tensor.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..789762bc114d2ab9481070536369870d0fbf1167
Binary files /dev/null and b/MLPY/Lib/site-packages/torchgen/__pycache__/gen_lazy_tensor.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torchgen/__pycache__/gen_vmap_plumbing.cpython-39.pyc b/MLPY/Lib/site-packages/torchgen/__pycache__/gen_vmap_plumbing.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bd62533b71869a1539302814d8d63e36f3572a58
Binary files /dev/null and b/MLPY/Lib/site-packages/torchgen/__pycache__/gen_vmap_plumbing.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torchgen/__pycache__/local.cpython-39.pyc b/MLPY/Lib/site-packages/torchgen/__pycache__/local.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fd6fee321c21af0277b1e205294c9a27f5d398fe
Binary files /dev/null and b/MLPY/Lib/site-packages/torchgen/__pycache__/local.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torchgen/__pycache__/model.cpython-39.pyc b/MLPY/Lib/site-packages/torchgen/__pycache__/model.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..04561eb0a4cdc41f4ac86738f7f384cf7a221de9
Binary files /dev/null and b/MLPY/Lib/site-packages/torchgen/__pycache__/model.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torchgen/__pycache__/native_function_generation.cpython-39.pyc b/MLPY/Lib/site-packages/torchgen/__pycache__/native_function_generation.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4d0d7af416babb98d9d96b66a4df9e3d6834a230
Binary files /dev/null and b/MLPY/Lib/site-packages/torchgen/__pycache__/native_function_generation.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torchgen/__pycache__/utils.cpython-39.pyc b/MLPY/Lib/site-packages/torchgen/__pycache__/utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d01da46b5bc71625cc115f3a353042efa3f153d8
Binary files /dev/null and b/MLPY/Lib/site-packages/torchgen/__pycache__/utils.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torchgen/__pycache__/yaml_utils.cpython-39.pyc b/MLPY/Lib/site-packages/torchgen/__pycache__/yaml_utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6b5d843276b84b07e1367660c12ee292fe05f2e1
Binary files /dev/null and b/MLPY/Lib/site-packages/torchgen/__pycache__/yaml_utils.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torchgen/api/__init__.py b/MLPY/Lib/site-packages/torchgen/api/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/MLPY/Lib/site-packages/torchgen/api/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torchgen/api/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3f717f5c6679af213a9cbc53097dbcf3eb16d125
Binary files /dev/null and b/MLPY/Lib/site-packages/torchgen/api/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torchgen/api/__pycache__/autograd.cpython-39.pyc b/MLPY/Lib/site-packages/torchgen/api/__pycache__/autograd.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ec72dc2d4644cfdf082407fba2cd49257d71797a
Binary files /dev/null and b/MLPY/Lib/site-packages/torchgen/api/__pycache__/autograd.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torchgen/api/__pycache__/cpp.cpython-39.pyc b/MLPY/Lib/site-packages/torchgen/api/__pycache__/cpp.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ee95402324f6540e9196b2d7e11f62ef6080a04f
Binary files /dev/null and b/MLPY/Lib/site-packages/torchgen/api/__pycache__/cpp.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torchgen/api/__pycache__/dispatcher.cpython-39.pyc b/MLPY/Lib/site-packages/torchgen/api/__pycache__/dispatcher.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6b870cc7919cf983af356ccebaa3389dbbc0d6a0
Binary files /dev/null and b/MLPY/Lib/site-packages/torchgen/api/__pycache__/dispatcher.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torchgen/api/__pycache__/functionalization.cpython-39.pyc b/MLPY/Lib/site-packages/torchgen/api/__pycache__/functionalization.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c6df5d09a406a3b746a33eb4b6d8f866ad2adf6c
Binary files /dev/null and b/MLPY/Lib/site-packages/torchgen/api/__pycache__/functionalization.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torchgen/api/__pycache__/lazy.cpython-39.pyc b/MLPY/Lib/site-packages/torchgen/api/__pycache__/lazy.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1a4377be4af138d7df7e729a92e04eeaf9295691
Binary files /dev/null and b/MLPY/Lib/site-packages/torchgen/api/__pycache__/lazy.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torchgen/api/__pycache__/meta.cpython-39.pyc b/MLPY/Lib/site-packages/torchgen/api/__pycache__/meta.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..45724da7a063750d30619224af6298acc14c9739
Binary files /dev/null and b/MLPY/Lib/site-packages/torchgen/api/__pycache__/meta.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torchgen/api/__pycache__/native.cpython-39.pyc b/MLPY/Lib/site-packages/torchgen/api/__pycache__/native.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5217ea5b59ca98416ed706f4e5a2fee2babb3a53
Binary files /dev/null and b/MLPY/Lib/site-packages/torchgen/api/__pycache__/native.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torchgen/api/__pycache__/python.cpython-39.pyc b/MLPY/Lib/site-packages/torchgen/api/__pycache__/python.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b9d874250872dfe19afd2e0228783e408a6ec357
Binary files /dev/null and b/MLPY/Lib/site-packages/torchgen/api/__pycache__/python.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torchgen/api/__pycache__/structured.cpython-39.pyc b/MLPY/Lib/site-packages/torchgen/api/__pycache__/structured.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..850b3aa893efe116e9c7f1b1e8dc0ecc50139489
Binary files /dev/null and b/MLPY/Lib/site-packages/torchgen/api/__pycache__/structured.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torchgen/api/__pycache__/translate.cpython-39.pyc b/MLPY/Lib/site-packages/torchgen/api/__pycache__/translate.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3adbd825adf8a3cf1e40802ee3451cac3bf54aa3
Binary files /dev/null and b/MLPY/Lib/site-packages/torchgen/api/__pycache__/translate.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torchgen/api/__pycache__/ufunc.cpython-39.pyc b/MLPY/Lib/site-packages/torchgen/api/__pycache__/ufunc.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9120ae85398231d1f44d46af89746c4813590d34
Binary files /dev/null and b/MLPY/Lib/site-packages/torchgen/api/__pycache__/ufunc.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torchgen/api/__pycache__/unboxing.cpython-39.pyc b/MLPY/Lib/site-packages/torchgen/api/__pycache__/unboxing.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c36d391bf48a4153b70313e757f7028a0f7fb1ae
Binary files /dev/null and b/MLPY/Lib/site-packages/torchgen/api/__pycache__/unboxing.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torchgen/api/autograd.py b/MLPY/Lib/site-packages/torchgen/api/autograd.py
new file mode 100644
index 0000000000000000000000000000000000000000..8ac5011c90a765261c917e94ddb20ca443536f17
--- /dev/null
+++ b/MLPY/Lib/site-packages/torchgen/api/autograd.py
@@ -0,0 +1,853 @@
+import re
+from dataclasses import dataclass
+from typing import cast, Dict, List, Match, Optional, Sequence, Set, Tuple
+
+from torchgen import local
+
+from torchgen.api import cpp
+from torchgen.api.types import BaseCType, Binding, NamedCType, tensorListT
+from torchgen.model import (
+ BaseTy,
+ BaseType,
+ FunctionSchema,
+ ListType,
+ NativeFunction,
+ NativeFunctionsViewGroup,
+ SchemaKind,
+ Type,
+)
+from torchgen.utils import IDENT_REGEX
+
+
+# Represents a saved attribute involved in backward calculation.
+# Note that it can be a derived property of an input argument, e.g.:
+# we could save `other.scalar_type()` instead of the entire `other` tensor.
+@dataclass(frozen=True)
+class SavedAttribute:
+ # The NamedCType holds the updated name and cpp type of the attribute.
+ # For the name, a suffix is appended if it's a derived property, e.g.: `other_scalar_type`.
+ nctype: NamedCType
+
+ # The expression to read the derived property at save time, e.g.:
+ # `other.scalar_type()`.
+ expr: str
+
+
+# Represents a backward formula that calculates derivatives for one
+# or more tensors.
+@dataclass(frozen=True)
+class Derivative:
+ # The formula string (legit C++ expression).
+ # Note that expressions against input arguments have been replaced with the
+ # corresponding saved attributes.
+ # E.g.:
+ # raw formula: `mul_tensor_backward(grad, self, other.scalar_type())`
+ # here: `mul_tensor_backward(grad, self, other_scalar_type)`
+ formula: str
+
+ # The formula string before input argument replacement
+ original_formula: str
+
+ # Names of the arguments for which this formula calculates derivatives.
+ var_names: Tuple[str, ...]
+
+ # Saved inputs that are referenced by the formula.
+ saved_inputs: Tuple[SavedAttribute, ...]
+
+ # Saved outputs that are referenced by the formula.
+ saved_outputs: Tuple[SavedAttribute, ...]
+
+ # Gradients that are referenced by name in the formula.
+ named_gradients: Set[str]
+
+
+# Represents a forward formula that calculates forward derivatives
+# for one tensor.
+@dataclass(frozen=True)
+class ForwardDerivative:
+ # The formula string (legit C++ expression).
+ # Note that special keywords such as "linear" or "element_wise" have been
+ # replaced by the automatically generated formula.
+ formula: str
+
+ # Name of the output arguments for which this formula calculates forward
+ # derivatives
+ var_names: Tuple[str, ...]
+
+ # Type of the output arguments for which this formula calculates forward
+ # derivatives
+ var_types: Tuple[Type, ...]
+
+ # Inputs for which the forward derivatives are required for this formula
+ required_inputs_fw_grad: Optional[Tuple[str, ...]]
+
+ # Inputs for which the primal is required for this formula
+ required_inputs_primal: Optional[Tuple[str, ...]]
+
+ # Flag to specify if this formula requires the original value of self
+ # This is only used by inplace operations
+ required_original_self_value: bool
+
+ # If this formula is specified in derivatives.yaml or if we are re-using the
+ # out of place formula for inplace
+ is_reusing_outplace_formula: bool
+
+
+# Represents differentiability info for a NativeFunction.
+@dataclass(frozen=True)
+class DifferentiabilityInfo:
+ # The base name read from derivatives.yaml.
+ name: str
+
+ # The matching native function.
+ #
+ # There can be multiple NativeFunction having the same base name:
+ # - different overloads with different types of input arguments;
+ # - in-place/out/functional variants of the same function;
+ #
+ # We first use the schema string (under the 'name' key) in derivatives.yaml
+ # to find the NativeFunction having the same schema string.
+ # Then we find the in-place/out/functional variants of the matching function.
+ # Among these variants, we choose the one having the same name as the
+ # derivatives.yaml entry. If there is no exact match, then we choose the
+ # in-place variant.
+ # TODO: maybe the logic to search for all variants is no longer necessary?
+ func: NativeFunction
+
+ # The name of the generated autograd function.
+ # It's set only if we will calculate a derivative, i.e.
+ # 'args_with_derivatives' is not empty.
+ op: Optional[str]
+
+ # The derivatives formulae for this function.
+ # Note that the length of this sequence is the number of differentiable inputs
+ derivatives: Sequence[Derivative]
+
+ # The forward derivatives formulae for this function.
+ # Note that the length of this sequence is the number of differentiable outputs
+ forward_derivatives: Sequence[ForwardDerivative]
+
+ # The union of 'saved_inputs' of all 'derivatives'.
+ all_saved_inputs: Sequence[SavedAttribute]
+
+ # The union of 'saved_outputs' of all 'derivatives'.
+ all_saved_outputs: Sequence[SavedAttribute]
+
+ # All named gradients that are available for use, in the same
+ # order as in the grads vector.
+ available_named_gradients: Sequence[str]
+
+ # The named gradients that are used in any of the derivatives.
+ # Invariant: all(name in available_named_gradients for name in used_named_gradients)
+ used_named_gradients: Set[str]
+
+ # The function's input arguments for which it calculates derivatives.
+ # It's the union of 'var_names' of all 'derivatives', sorted by the
+ # argument order in the function schema.
+ args_with_derivatives: Sequence[Binding]
+
+ # Names of arguments whose derivative formula is 'non_differentiable'.
+ non_differentiable_arg_names: Sequence[str]
+
+ # Raw data read from derivatives.yaml.
+ output_differentiability: Optional[List[bool]]
+
+ # output_differentiability in derivatives.yaml can be a list of
+ # conditions that express if the output is differentiable. In this case,
+ # the number of conditions must match the number of outputs
+ # (NB: we only support one condition right now).
+ # output_differentiability gets populated with True for each condition,
+ # while output_differentiability_conditions gets populated with the conditions
+ output_differentiability_conditions: Optional[List[str]]
+
+ @property
+ def has_derivatives(self) -> bool:
+ return len(self.args_with_derivatives) > 0
+
+ # Generates a new DifferentiabilityInfo using the exact same set of derivative information,
+ # but with a new operator name.
+ # This is used when generating "copy" variants of view ops,
+ # which are able to use the exact same derivative formula as the original view op
+ # See Note [Codegen'd {view}_copy Operators]
+ def create_view_copy_from_view_derivative(
+ self, g: NativeFunctionsViewGroup
+ ) -> Optional["DifferentiabilityInfo"]:
+ if g.view_copy is None:
+ return None
+ f = g.view_copy
+
+ name_split_by_period = self.name.split(".", maxsplit=2)
+ # Append a "_copy" to the base name of the operator (but keep the overload name the same)
+ view_copy_name = f"{name_split_by_period[0]}_copy." + ".".join(
+ name_split_by_period[1:]
+ )
+ view_copy_op_name = None if self.op is None else f"{self.op}_copy"
+
+ return DifferentiabilityInfo(
+ # Use the "_copy" version of name/func/op
+ name=view_copy_name,
+ func=f,
+ op=view_copy_op_name,
+ # But keep all derivative info the same
+ derivatives=self.derivatives,
+ forward_derivatives=self.forward_derivatives,
+ all_saved_inputs=self.all_saved_inputs,
+ all_saved_outputs=self.all_saved_outputs,
+ available_named_gradients=self.available_named_gradients,
+ used_named_gradients=self.used_named_gradients,
+ args_with_derivatives=self.args_with_derivatives,
+ non_differentiable_arg_names=self.non_differentiable_arg_names,
+ output_differentiability=self.output_differentiability,
+ output_differentiability_conditions=self.output_differentiability_conditions,
+ )
+
+
+def uses_ident(info: Optional[DifferentiabilityInfo], ident: str) -> bool:
+ if info is None:
+ return False
+ for derivative in info.derivatives:
+ formula = derivative.formula
+ if re.search(IDENT_REGEX.format(ident), formula):
+ return True
+ return False
+
+
+def uses_retain_variables(info: Optional[DifferentiabilityInfo]) -> bool:
+ return uses_ident(info, "retain_variables")
+
+
+def uses_single_grad(info: Optional[DifferentiabilityInfo]) -> bool:
+ return uses_ident(info, "grad")
+
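The helpers above do whole-word matching of identifiers inside derivative formulas. A self-contained illustration, where ``IDENT_REGEX`` mirrors ``torchgen.utils.IDENT_REGEX`` as understood here (treat the exact pattern as an assumption):

    import re

    IDENT_REGEX = r"(^|\W){}($|\W)"  # assumed to match torchgen.utils.IDENT_REGEX

    formula = "mul_tensor_backward(grad, self, other_scalar_type)"
    print(bool(re.search(IDENT_REGEX.format("grad"), formula)))   # True: "grad" occurs as a word
    print(bool(re.search(IDENT_REGEX.format("grads"), formula)))  # False: "grads" never occurs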
+
+# Represents a differentiable `Argument`.
+# How is it different from the `Argument` type?
+# - It's processed Arguments which are differentiable and only used in the
+# context of the autograd codegen;
+# - It can represent SelfArgument or regular Argument but not TensorOptionsArgument;
+@dataclass(frozen=True)
+class DifferentiableInput:
+ name: str
+ type: Type
+
+ # TODO: only to keep it byte-for-byte compatible with the old codegen, should remove.
+ cpp_type: str
+
+
+# Represents a differentiable `Return`.
+# How is it different from the `Return` type?
+# - The name in `Return` is optional. Here it is always populated using the same
+# `cpp.return_names()` method.
+# TODO: some cpp naming logic (e.g. resolving name conflict) might be irrelevant?
+# - It's processed Returns which are differentiable, in compliance with the
+# `output_differentiability` field defined in derivatives.yaml (if specified),
+# and are only used in the context of the autograd codegen;
+@dataclass(frozen=True)
+class DifferentiableOutput:
+ name: str
+ type: Type
+
+ # TODO: only to keep it byte-for-byte compatible with the old codegen, should remove.
+ cpp_type: str
+
+
+@dataclass(frozen=True)
+class NativeFunctionWithDifferentiabilityInfo:
+ func: NativeFunction
+ info: Optional[Dict[str, DifferentiabilityInfo]]
+ fw_derivatives: Optional[Dict[str, Sequence[ForwardDerivative]]]
+
+
+# TODO: Update comment below since it is out of date.
+def dispatch_strategy(fn: NativeFunctionWithDifferentiabilityInfo) -> str:
+ """How are we going to call the underlying implementation of a
+ declaration? There are two strategies:
+ - use_derived: we want to call the implementation on CPUDoubleType
+ (or a similar, derived Type instance). Because these derived
+ instances deal in Tensors, not Variables (it's a completely different
+ object, so it doesn't dispatch back to VariableType), code on
+ this dispatch path needs to wrap/unwrap tensors. If the
+ derived implementation takes and returns tensors, the
+ implementation is usually differentiable (although we also use
+ the derived dispatch path for non-differentiable functions
+ that we still want to dispatch on the derived Type instance;
+ e.g., size())
+ - use_type: we want to call the implementation on Type, because
+ it is implemented concretely, and the functions it invokes will
+ get dispatched back to VariableType (which will ensure that they
+ are differentiable.)
+ """
+ # fn is derived as long as any of its per-key differentiability infos
+ # has_derivatives. dispatch_strategy() is used to guard generation of fns in VariableType
+ # and ADInplaceOrViewType. We want to generate these functions as long as a
+ # derivative is defined for ANY dispatch key.
+ if fn.func.is_abstract or (
+ fn.info is not None and any(info.has_derivatives for info in fn.info.values())
+ ):
+ # If the function is abstract (not implemented on at::Type), we must
+ # call the implementation on the derived type with unpacked tensors.
+
+ # If the function has a derivative specified and is concrete, we could
+ # call either implementation. We prefer calling the derived
+ # type's implementation with unpacked tensors because it is more
+ # performant in some cases: any internal calls to other ATen functions
+ # won't have the history tracked.
+
+ # If the function has a type dispatched argument (i.e. is a factory),
+ # we prefer calling the derived type's implementation both because it is
+ # more performant and to ensure factory functions return tensors with _version
+ # of 0 (probably not strictly necessary, but nice to have to keeps versions simple
+ # to understand.
+
+ return "use_derived"
+ else:
+ # If the function is concrete (we don't have to override it) and we
+ # didn't declare it in derivatives.yaml, we'll assume that it is
+ # actually implemented out of differentiable functions. (This
+ # assumption might not hold, but then you'll see gradcheck fail.)
+ return "use_type"
+
+
+def is_foreach_func(f: NativeFunction) -> bool:
+ return f.func.name.name.base.startswith("_foreach_")
+
+
+# note(crcrpar): Most foreach functions can reference an out-place `torch` function whose schema kind
+# is functional for their backward derivatives (and forward derivatives in the future), i.e.,
+# such a reference can be found in `functional_info_by_signature`. There are, however, some exceptions:
+_foreach_with_inplace_ref = {"_foreach_zero_"}
+_foreach_with_tensor_overload = {
+ "_foreach_add.Tensor",
+ "_foreach_mul.Tensor",
+ "_foreach_div.Tensor",
+}
+
+
+# Checks if `function_schema` is a native, non-foreach function that `f`, a foreach function,
+# references to generate derivatives.
+def is_reference_for_foreach(
+ f: NativeFunction,
+ function_schema: FunctionSchema,
+) -> bool:
+ return (
+ f.func.name.name.base.split("_foreach_")[-1] == function_schema.name.name.base
+ and (
+ not function_schema.name.name.inplace
+ or str(f.func.name) in _foreach_with_inplace_ref
+ )
+ and all(
+ ref_arg.type in (arg.type, getattr(arg.type, "elem", None))
+ for arg, ref_arg in zip(
+ f.func.arguments.flat_non_out,
+ function_schema.arguments.flat_non_out,
+ )
+ )
+ )
+
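A tiny worked check of the naming convention used above (string level only; the real check also compares in-placeness and argument types):

    # a foreach op such as "_foreach_add" references the plain "add" schema
    assert "_foreach_add".split("_foreach_")[-1] == "add"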
+
+# TODO(crcrpar): Avoid hard coding "Default" ideally.
+def gen_foreach_derivativeinfo(
+ foreach_function: NativeFunction,
+ functional_info_by_signature: Dict[
+ FunctionSchema, Dict[str, DifferentiabilityInfo]
+ ],
+ non_functional_info_by_signature: Dict[
+ FunctionSchema, Dict[str, DifferentiabilityInfo]
+ ],
+ dispatch_key: str = "Default",
+) -> Tuple[Optional[DifferentiabilityInfo], bool]:
+ """Generate DifferentiabilityInfo for out-place foreach function, return the existing one for in-place.
+
+ The second return value indicates whether the info is generated in this function.
+ """
+ ref_diff_info: Optional[DifferentiabilityInfo] = None
+
+ for function_schema, diff_info in functional_info_by_signature.items():
+ if not is_reference_for_foreach(foreach_function, function_schema):
+ continue
+ ref_diff_info = diff_info[dispatch_key]
+ if ref_diff_info is not None:
+ break
+ # note(crcrpar): It seems like `zero`'s info isn't available in functional_info_by_signature
+ # while the info of `zero_` is in non_functional_info_by_signature
+ if (
+ ref_diff_info is None
+ and foreach_function.func.kind() == SchemaKind.inplace
+ and str(foreach_function.func.name) in _foreach_with_inplace_ref
+ ):
+ for function_schema, diff_info in non_functional_info_by_signature.items():
+ if not is_reference_for_foreach(foreach_function, function_schema):
+ continue
+ ref_diff_info = diff_info[dispatch_key]
+ if ref_diff_info is not None:
+ break
+ if ref_diff_info is None:
+ return None, False
+
+ # The in-place (non out-of-place) variant reuses the existing Derivative as-is.
+ if foreach_function.func.kind() == SchemaKind.inplace:
+ return ref_diff_info, False
+
+ map_refarg2foreacharg, map_name2arg = {}, {}
+ for i, (arg, ref_arg) in enumerate(
+ zip(
+ foreach_function.func.arguments.flat_non_out,
+ function_schema.arguments.flat_non_out,
+ )
+ ):
+ map_refarg2foreacharg[ref_arg.name] = arg.name
+ map_name2arg[arg.name] = arg
+
+ all_saved_inputs, all_saved_outputs, all_var_names = [], [], []
+ modified_derivative_formulas = []
+ for i, derivative in enumerate(ref_diff_info.derivatives):
+ modified_formula = derivative.formula.replace("grad", "grads[i]").replace(
+ "result", "result[i]"
+ )
+ saved_inputs, saved_outputs = [], []
+ # note(crcrpar): This context seems necessary to call `cpp.argument_type`
+ with local.parametrize(
+ use_const_ref_for_mutable_tensors=foreach_function.use_const_ref_for_mutable_tensors,
+ use_ilistref_for_tensor_lists=foreach_function.part_of_structured_group,
+ ):
+ for ref_input in derivative.saved_inputs:
+ ref_input_jit_name = ref_input.expr.split(".")[0]
+ mapped_name = map_refarg2foreacharg[ref_input_jit_name]
+ if isinstance(map_name2arg[mapped_name].type, ListType):
+ mapped_expr = mapped_name + "[i]"
+ else:
+ mapped_expr = mapped_name
+ new_expr = ref_input.expr.replace(ref_input_jit_name, mapped_expr)
+ modified_formula = modified_formula.replace(
+ cast(str, ref_input.nctype.name), new_expr
+ )
+
+ nctype = cpp.argument_type(map_name2arg[mapped_name], binds=mapped_name)
+ canonical_nctype = NamedCType(
+ nctype.name, nctype.type.remove_const_ref()
+ )
+ saved_inputs.append(
+ SavedAttribute(nctype=canonical_nctype, expr=mapped_name)
+ )
+ for ref_output in derivative.saved_outputs:
+ if ref_output.nctype.name == "result":
+ saved_outputs.append(
+ SavedAttribute(
+ nctype=NamedCType(
+ name="result", type=BaseCType(tensorListT)
+ ),
+ expr="result",
+ )
+ )
+ else:
+ raise RuntimeError("")
+ var_names = [map_refarg2foreacharg[var] for var in derivative.var_names]
+ all_var_names.extend(var_names)
+ all_saved_inputs.extend(saved_inputs)
+ all_saved_outputs.extend(saved_outputs)
+ modified_derivative = Derivative(
+ formula=modified_formula,
+ original_formula=derivative.formula,
+ var_names=tuple(var_names),
+ saved_inputs=tuple(saved_inputs),
+ saved_outputs=tuple(saved_outputs),
+ named_gradients=set(),
+ )
+ modified_derivative_formulas.append(modified_derivative)
+
+ with local.parametrize(
+ use_const_ref_for_mutable_tensors=foreach_function.use_const_ref_for_mutable_tensors,
+ use_ilistref_for_tensor_lists=foreach_function.part_of_structured_group,
+ ):
+ args_with_derivatives = [
+ Binding(
+ name=arg.name,
+ nctype=cpp.argument_type(arg, binds=arg.name),
+ argument=arg,
+ default=None,
+ )
+ for arg in foreach_function.func.arguments.flat_non_out
+ if arg.name in all_var_names
+ ]
+
+ forward_derivatives: List[ForwardDerivative] = []
+ fw_derivative: ForwardDerivative
+ for fw_derivative in ref_diff_info.forward_derivatives:
+ var_names: List[str] = list(fw_derivative.var_names) # type: ignore[no-redef]
+ var_types: List[Type] = list(fw_derivative.var_types)
+ required_inputs_fw_grad: List[str] = []
+ required_inputs_primal: List[str] = []
+ if fw_derivative.required_inputs_fw_grad is not None:
+ required_inputs_fw_grad = list(fw_derivative.required_inputs_fw_grad)
+ if fw_derivative.required_inputs_primal:
+ required_inputs_primal = list(fw_derivative.required_inputs_primal)
+ modified_formula = fw_derivative.formula
+
+ # Foreach's result is TensorList
+ if "result" in modified_formula:
+ modified_formula = fw_derivative.formula.replace("result", "result[i]")
+
+ for foreach_arg, ref_arg in zip(
+ foreach_function.func.arguments.flat_non_out,
+ ref_diff_info.func.func.arguments.flat_non_out,
+ ):
+ # Modify reference forward formula
+ if (
+ isinstance(foreach_arg.type, ListType)
+ and not foreach_arg.type.is_tensor_like()
+ ):
+ # Assuming ScalarList
+ modified_formula = modified_formula.replace(
+ ref_arg.name, foreach_arg.name + "[i]"
+ )
+ elif foreach_arg.type.is_tensor_like():
+ # Assuming TensorList / Tensor
+ # assert isinstance(foreach_arg.type, ListType), f"{foreach_function.func.name}, {foreach_arg.type}"
+ assert isinstance(foreach_arg.type, ListType) or (
+ foreach_arg.type == BaseType(BaseTy.Tensor)
+ and str(foreach_function.func.name) in _foreach_with_tensor_overload
+ ), f"{foreach_function.func.name}, {foreach_arg.type}"
+ for suffix in ("_p", "_t"):
+ curr_expr = ref_arg.name + suffix
+ if curr_expr in modified_formula:
+ new_expr = foreach_arg.name + suffix
+ modified_formula = modified_formula.replace(curr_expr, new_expr)
+ else:
+ # Assuming Scalar
+ if foreach_arg.name != ref_arg.name:
+ modified_formula = modified_formula.replace(
+ ref_arg.name, foreach_arg.name
+ )
+
+ # note(crcrpar): there should exist a cooler way...
+ for i, name in enumerate(var_names):
+ if name == ref_arg.name:
+ var_names[i] = foreach_arg.name
+ var_types[i] = foreach_arg.type
+ for i, name in enumerate(required_inputs_fw_grad):
+ if name == ref_arg.name:
+ required_inputs_fw_grad[i] = foreach_arg.name
+ for i, name in enumerate(required_inputs_primal):
+ if name == ref_arg.name:
+ required_inputs_primal[i] = foreach_arg.name
+ forward_derivatives.append(
+ ForwardDerivative(
+ formula=modified_formula,
+ var_names=tuple(var_names),
+ var_types=tuple(var_types),
+ required_inputs_fw_grad=tuple(required_inputs_fw_grad),
+ required_inputs_primal=tuple(required_inputs_primal),
+ required_original_self_value=fw_derivative.required_original_self_value,
+ is_reusing_outplace_formula=fw_derivative.is_reusing_outplace_formula,
+ )
+ )
+
+ return (
+ DifferentiabilityInfo(
+ name=foreach_function.func.name.name.base,
+ func=foreach_function,
+ op=f"Foreach{ref_diff_info.op}{foreach_function.func.name.overload_name}",
+ derivatives=modified_derivative_formulas,
+ forward_derivatives=forward_derivatives,
+ all_saved_inputs=tuple(set(all_saved_inputs)),
+ all_saved_outputs=tuple(set(all_saved_outputs)),
+ available_named_gradients=(),
+ used_named_gradients=set(),
+ args_with_derivatives=args_with_derivatives,
+ non_differentiable_arg_names=[],
+ output_differentiability=None,
+ output_differentiability_conditions=None,
+ ),
+ True,
+ )
+
+
+def match_differentiability_info(
+ native_functions: List[NativeFunction],
+ differentiability_infos: Dict[FunctionSchema, Dict[str, DifferentiabilityInfo]],
+) -> List[NativeFunctionWithDifferentiabilityInfo]:
+ """Sets the "derivative" key on declarations to matching autograd function
+ In-place functions will use the out-of-place derivative definition if there
+ is no in-place specific derivative.
+ """
+
+ functional_info_by_signature = {
+ schema.signature(strip_default=True): info_dict
+ for schema, info_dict in differentiability_infos.items()
+ if schema.kind() == SchemaKind.functional
+ }
+ non_functional_info_by_signature = {
+ schema.signature(strip_default=True): info_dict
+ for schema, info_dict in differentiability_infos.items()
+ if schema.kind() != SchemaKind.functional
+ }
+
+ def find_info(
+ f: NativeFunction,
+ ) -> Tuple[Optional[Dict[str, DifferentiabilityInfo]], bool]:
+ # Don't bother matching info to generated out= variants
+ if "generated" in f.tags and f.func.kind() == SchemaKind.out:
+ return None, False
+
+ # (1) Check for an exact match
+ if f.func in differentiability_infos:
+ return differentiability_infos[f.func], True
+
+ # (2) If no exact match, check if the out-of-place variant
+ # of this operator has a match.
+ # i.e mul() for mul_() or mul_out()
+ # note(crcrpar): Check foreach or not because in-place foreach functions use backward defined for the existing
+ # native functions instead of the out-place counterparts.
+ f_sig = f.func.signature(strip_default=True)
+ if f_sig in functional_info_by_signature and not is_foreach_func(f):
+ return functional_info_by_signature[f_sig], False
+
+ # (3) Some operators have a derivative explicitly defined for the mutable
+ # variant, but get a code-generated out-of-place variant which does *not*
+ # come with a derivative formula.
+ # For the generated out-of-place variant, use the mutable variant's formula
+ # if it exists.
+ if "generated" in f.tags and f_sig in non_functional_info_by_signature:
+ info_dict = non_functional_info_by_signature[f_sig]
+ # See https://github.com/pytorch/pytorch/pull/76320/files#r874816389
+ assert not any(
+ any("self" in str(inpt.nctype.name) for inpt in info.all_saved_inputs)
+ for info in info_dict.values()
+ ), f"""\
+Attempted to convert a derivative formula for a mutable operator
+ to be used automatically by its functional variant ("{str(f.func)}").
+ This is not currently supported (we'd need to fix up the formula in the codegen)."""
+ return info_dict, False
+
+ # (4) Generate derivative information of foreach functions if none is defined in `derivatives.yaml`
+ if is_foreach_func(f):
+ assert f.func not in differentiability_infos
+ diff_info, is_generated = gen_foreach_derivativeinfo(
+ f,
+ functional_info_by_signature,
+ non_functional_info_by_signature,
+ )
+ if diff_info is None:
+ return None, False
+ # TODO(crcrpar): Avoid hard coding "Default" ideally.
+ diff_info_dict = {"Default": diff_info}
+ if is_generated:
+ differentiability_infos[f.func] = diff_info_dict
+ functional_info_by_signature[f.func] = diff_info_dict
+ return diff_info_dict, is_generated
+
+ return None, False
+
+ result: List[NativeFunctionWithDifferentiabilityInfo] = []
+ for f in native_functions:
+ info_dict, is_exact_match = find_info(f)
+
+ # Currently, the '.strides()' to 'strides_or_error' replacement does not support
+ # 'self' derivatives of an inplace function, so we must check for this case.
+ if f.func.kind() == SchemaKind.inplace and (info_dict is not None):
+ for info in info_dict.values():
+ for derivative in info.derivatives:
+ if "self" in derivative.var_names:
+ for saved_input in derivative.saved_inputs:
+ assert "strides_or_error" not in saved_input.expr, (
+ "Calling '.strides()' in the 'self' derivative formula of an "
+ f"in-place function is not supported: {f.func}"
+ )
+
+ if not info_dict:
+ result.append(
+ NativeFunctionWithDifferentiabilityInfo(
+ func=f, info=None, fw_derivatives=None
+ )
+ )
+ continue
+
+ fw_derivative_dict: Dict[str, Sequence[ForwardDerivative]] = {}
+ for key, info in info_dict.items():
+ if not info.forward_derivatives:
+ fw_derivative_dict[key] = []
+ continue
+
+ forward_derivatives = info.forward_derivatives
+
+ # For functions that have a single def for out-of-place and inplace (like abs())
+ if f.func.kind() == SchemaKind.inplace:
+ # For inplace functions there is a little bit of work to do:
+ # 1) Validate the formula and make sure the input that is modified is not used:
+ # - If there is a formula for the inplace variant of the function (is_exact_match == True) then
+ # we make sure that the original value of the input that is being modified inplace (self_p) is
+ # not used in the formula. Note that the formula can use "original_self_p" here and that would
+ # trigger a clone of the original input.
+ # - If we are re-using the out of place formula (is_exact_match == False) then we replace every
+ # occurrence of self_p and self_t by original_self_p and original_self_t. These will be
+ # populated by cloned version of the original input (either the clone done by the backward AD
+ # logic if self is also used in a backward formula or a special clone that we add).
+ # 2) At this point, there cannot be a self_p in the formula.
+ # 3) Change "result" into "self_p" as by design, in the inplace function codegen, the result is
+ # simply called self (as it is modified inplace).
+ # 4) Update the required primals data in case it used to contain "result" but should now contain
+ # "self"
+ # 5) If it is not an exact match, the user formula is not modifying the existing forward grad
+ # inplace as it should. So add some code that makes sure that we do so if the forward grad
+ # already exists.
+
+ assert (
+ len(info.forward_derivatives) == 1
+ ) # Only single output inplace should exist
+ fw_info = info.forward_derivatives[0]
+ formula = fw_info.formula
+
+ def replace_self_with_original_self(formula: str, postfix: str) -> str:
+ def repl(m: Match[str]) -> str:
+ return f"{m.group(1)}original_self{postfix}{m.group(2)}"
+
+ return re.sub(IDENT_REGEX.format(f"self{postfix}"), repl, formula)
+
+ if re.search(IDENT_REGEX.format("self_p"), formula):
+ if is_exact_match:
+ # For manually defined formulas, don't allow the original value to be used
+ raise RuntimeError(
+ f'The formula for "{f.func.name}" is using the original value of self '
+ "that is being modified inplace. This would lead to wrong forward gradients. "
+ 'Please use "result" in the formula only.'
+ )
+ else:
+ # When the original formula is out of place, we save a clone of the primal
+ # value to be able to access this value if needed
+ # replace "self_p"/"self_t" from the formula by "original_self_p"/"original_self_t"
+ formula = replace_self_with_original_self(formula, "_p")
+ formula = replace_self_with_original_self(formula, "_t")
+
+ # replace "result" from the formula by "self_p"
+ def repl(m: Match[str]) -> str:
+ return f"{m.group(1)}self_p{m.group(2)}"
+
+ formula = re.sub(IDENT_REGEX.format("result"), repl, formula)
+
+ required_primals = fw_info.required_inputs_primal
+ if re.search(IDENT_REGEX.format("self_p"), formula):
+ required_primals = (
+ required_primals + ("self",) if required_primals else ("self",)
+ )
+
+ if not is_exact_match:
+ # NOTE [In-place forward AD formula Optimization]
+ #
+ # This optimization transforms the formula to directly do inplace, i.e.
+ # instead of self_t.copy_(self_t.op()) we do self_t.op_() when the following are met:
+ #
+ # 1) the formula satisfies the pattern: "self_t.op(*args)"
+ # 2) "op" in (1) needs to be the same as the op the derivative is for
+ #
+ # (2) may seem too strict, but currently the only ops that satisfy (1) also satisfy (2)
+ # If there is a need, we can relax (2) to allow any op that has an in-place variant
+ is_single_method_on_self_t = False
+ directly_do_inplace = False
+ op_name: Optional[str] = None
+ between_parens: Optional[str] = None
+ match = re.fullmatch(r"self_t.([\w]*)\((.*)\)", formula)
+ if match:
+ op_name, between_parens = match.group(1), match.group(2)
+
+ # We want to...
+ # Match: self_t.op1(other_p.op2(arg))
+ # Avoid: self_t.op1(args) + self_t.op2(args)
+ # Avoid: self_t.op1(other_p.op2(arg)) + self_t.op2(args)
+ def check_parens_nest_level_gt_zero(s: str) -> bool:
+ level = 1
+ for ch in s:
+ if ch == ")":
+ level -= 1
+ if level == 0:
+ return False
+ if ch == "(":
+ level += 1
+ return True
+
+ is_single_method_on_self_t = check_parens_nest_level_gt_zero(
+ between_parens
+ )
+ directly_do_inplace = (
+ is_single_method_on_self_t and op_name == info.name
+ )
+
+ if directly_do_inplace:
+ assert op_name is not None
+ assert between_parens is not None
+ formula = f"self_t_raw.defined() ? self_t_raw.{op_name}_({between_parens}) : {formula}"
+ else:
+ # Make sure that the forward grad is modified inplace when the original formula
+ # is out of place
+ formula = f"self_t_raw.defined() ? self_t_raw.copy_({formula}) : {formula}"
+
+ required_original_self_value = bool(
+ re.search(IDENT_REGEX.format("original_self_p"), formula)
+ ) or bool(re.search(IDENT_REGEX.format("original_self_t"), formula))
+
+ forward_derivatives = [
+ ForwardDerivative(
+ formula=formula,
+ var_names=("self",),
+ var_types=fw_info.var_types,
+ required_inputs_fw_grad=fw_info.required_inputs_fw_grad,
+ required_inputs_primal=required_primals,
+ required_original_self_value=required_original_self_value,
+ is_reusing_outplace_formula=not is_exact_match,
+ ),
+ ]
+
+ fw_derivative_dict[key] = forward_derivatives
+
+ result.append(
+ NativeFunctionWithDifferentiabilityInfo(
+ func=f, info=info_dict, fw_derivatives=fw_derivative_dict
+ )
+ )
+
+ return result
+
+
+def is_differentiable(
+ name: str, type: Type, info: Optional[DifferentiabilityInfo]
+) -> bool:
+ return type.is_tensor_like() and (
+ info is None or name not in info.non_differentiable_arg_names
+ )
+
+
+def gen_differentiable_outputs(
+ fn: NativeFunctionWithDifferentiabilityInfo, key: str = "Default"
+) -> List[DifferentiableOutput]:
+ f = fn.func
+ info = fn.info[key] if fn.info else None
+ outputs: List[DifferentiableOutput] = [
+ DifferentiableOutput(
+ name=name,
+ type=ret.type,
+ cpp_type=cpp.return_type(ret, symint=True).cpp_type(),
+ )
+ for name, ret in zip(cpp.return_names(f), f.func.returns)
+ ]
+ output_differentiability = info.output_differentiability if info else None
+ if output_differentiability is not None:
+ if len(output_differentiability) != len(outputs):
+ raise RuntimeError(
+ f"The length of output_differentiability ({len(output_differentiability)}), "
+ f"does not match the number of outputs ({len(outputs)})."
+ )
+ differentiable_outputs: List[DifferentiableOutput] = []
+ if False in output_differentiability and f.func.kind() == SchemaKind.inplace:
+ raise RuntimeError(
+ "output_differentiability=False for inplace operation (version_counter won't get updated)"
+ )
+ for differentiable, output in zip(output_differentiability, outputs):
+ if differentiable:
+ differentiable_outputs.append(output)
+ return differentiable_outputs
+ candidate_differentiable_outputs = list(
+ filter(lambda r: is_differentiable(r.name, r.type, info), outputs)
+ )
+ if uses_single_grad(info):
+ return candidate_differentiable_outputs[:1]
+ else:
+ return candidate_differentiable_outputs
diff --git a/MLPY/Lib/site-packages/torchgen/api/cpp.py b/MLPY/Lib/site-packages/torchgen/api/cpp.py
new file mode 100644
index 0000000000000000000000000000000000000000..7ccb3f28170295b883e2b36ddadf5e638d3a6ba8
--- /dev/null
+++ b/MLPY/Lib/site-packages/torchgen/api/cpp.py
@@ -0,0 +1,467 @@
+from typing import List, Optional, Sequence, Set, Union
+
+from torchgen import local
+from torchgen.api.types import (
+ ArgName,
+ ArrayCType,
+ ArrayRefCType,
+ BaseCType,
+ BaseTypeToCppMapping,
+ Binding,
+ boolT,
+ ConstRefCType,
+ CType,
+ dimnameListT,
+ intArrayRefT,
+ iTensorListRefT,
+ ListCType,
+ longT,
+ MutRefCType,
+ NamedCType,
+ OptionalCType,
+ optionalIntArrayRefT,
+ optionalSymIntArrayRefT,
+ scalarT,
+ SpecialArgName,
+ symIntArrayRefT,
+ SymIntT,
+ tensorListT,
+ tensorOptionsT,
+ tensorT,
+ TupleCType,
+ VectorCType,
+ voidT,
+)
+from torchgen.model import (
+ Argument,
+ Arguments,
+ BaseTy,
+ BaseType,
+ FunctionSchema,
+ ListType,
+ NativeFunction,
+ OptionalType,
+ Return,
+ SelfArgument,
+ TensorOptionsArguments,
+ Type,
+)
+from torchgen.utils import assert_never
+
+# This file describes the translation of JIT schema to the public C++
+# API, which is what people use when they call functions like at::add.
+#
+# Prominent characteristics of the C++ API:
+#
+# - dtype, layout, device and pin_memory are collected into
+# a single C++ type TensorOptions (the native functions API
+# also has this, but tensor options is really most relevant
+# for the C++ API; it makes calling kwarg factory functions
+# pleasant)
+#
+# - defaulting lives here (in fact, the dispatcher is completely
+# oblivious of defaults!)
+#
+# BTW: policy on name collisions: we try not to have types with
+# collisions, but functions are fair game to collide
+
+
+def name(
+ func: FunctionSchema,
+ *,
+ faithful_name_for_out_overloads: bool = False,
+ symint_overload: bool = False,
+) -> str:
+ name = str(func.name.name)
+ if symint_overload:
+ name += "_symint"
+ if func.is_out_fn():
+ if faithful_name_for_out_overloads:
+ name += "_outf"
+ else:
+ name += "_out"
+
+ return name
+
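+# Illustrative sketch (not part of the original module): how `name` renders an
+# out-variant schema. The schema string below is a hypothetical example, and
+# FunctionSchema.parse is assumed to accept native_functions.yaml-style schema strings.
+#
+#   >>> from torchgen.model import FunctionSchema
+#   >>> s = FunctionSchema.parse(
+#   ...     "add.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)")
+#   >>> name(s)                                         # 'add_out'
+#   >>> name(s, faithful_name_for_out_overloads=True)   # 'add_outf'
+#   >>> name(s, symint_overload=True)                   # 'add_symint_out'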
+
+# Translation of "value types" in JIT schema to C++ API type. Value
+# types look the same no matter if they are argument types or return
+# types. Returns None if the type in question is not a value type.
+def valuetype_type(
+ t: Type,
+ *,
+ binds: ArgName,
+ remove_non_owning_ref_types: bool = False,
+ symint: bool = False,
+) -> Optional[NamedCType]:
+ if isinstance(t, BaseType):
+ if t.name == BaseTy.Tensor or t.name == BaseTy.Scalar:
+ return None
+ elif str(t) == "SymInt":
+ if symint:
+ return NamedCType(binds, BaseCType(SymIntT))
+ else:
+ return NamedCType(binds, BaseCType(longT))
+ if remove_non_owning_ref_types:
+ if t.name == BaseTy.str:
+ raise AssertionError(
+ "string ref->value conversion: not implemented yet"
+ )
+ # All other BaseType currently map directly to BaseCppTypes.
+ return NamedCType(binds, BaseCType(BaseTypeToCppMapping[t.name]))
+ elif isinstance(t, OptionalType):
+ elem = valuetype_type(t.elem, binds=binds, symint=symint)
+ if elem is None:
+ return None
+ return NamedCType(binds, OptionalCType(elem.type))
+ elif isinstance(t, ListType):
+ if str(t.elem) == "bool":
+ assert t.size is not None
+ return NamedCType(binds, ArrayCType(BaseCType(boolT), t.size))
+ else:
+ return None
+ else:
+ raise AssertionError(f"unrecognized type {repr(t)}")
+
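+# Illustrative sketch (not part of the original module): value-type translation.
+# Tensor and Scalar are not value types, so they return None and are handled by
+# argumenttype_type instead; BaseTypeToCppMapping is assumed to map BaseTy.int
+# to int64_t (longT).
+#
+#   >>> from torchgen.model import BaseTy, BaseType
+#   >>> valuetype_type(BaseType(BaseTy.Tensor), binds="self")   # None
+#   >>> valuetype_type(BaseType(BaseTy.int), binds="dim")       # NamedCType('dim', BaseCType(longT))
+#   >>> valuetype_type(BaseType(BaseTy.SymInt), binds="n", symint=True)
+#   # NamedCType('n', BaseCType(SymIntT))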
+
+# Translation of types occurring in JIT arguments to a C++ argument type.
+# If remove_non_owning_ref_types is set, we'll guarantee that the output CType is not a non-owning reference type.
+# For example, we'll return std::vector instead of IntArrayRef.
+# See Note [translation from C++ reference to value types]
+def argumenttype_type(
+ t: Type,
+ *,
+ mutable: bool,
+ binds: ArgName,
+ remove_non_owning_ref_types: bool = False,
+ symint: bool = False,
+) -> NamedCType:
+ # If it's a value type, do the value type translation
+ r = valuetype_type(
+ t,
+ binds=binds,
+ symint=symint,
+ remove_non_owning_ref_types=remove_non_owning_ref_types,
+ )
+ if r is not None:
+ return r
+
+ if isinstance(t, BaseType):
+ if t.name == BaseTy.Tensor:
+ if mutable and not local.use_const_ref_for_mutable_tensors():
+ return NamedCType(binds, MutRefCType(BaseCType(tensorT)))
+ else:
+ return NamedCType(binds, ConstRefCType(BaseCType(tensorT)))
+ elif t.name == BaseTy.Scalar:
+ return NamedCType(binds, ConstRefCType(BaseCType(scalarT)))
+ else:
+ raise AssertionError(f"base type should have been value type {t}")
+ elif isinstance(t, OptionalType):
+ if str(t.elem) == "Tensor":
+ if mutable and not local.use_const_ref_for_mutable_tensors():
+ return NamedCType(
+ binds, MutRefCType(BaseCType(tensorT))
+ ) # TODO: fix this discrepancy
+ else:
+ return NamedCType(
+ binds, ConstRefCType(OptionalCType(BaseCType(tensorT)))
+ )
+ elif str(t.elem) == "Scalar":
+ return NamedCType(binds, ConstRefCType(OptionalCType(BaseCType(scalarT))))
+ elif isinstance(t.elem, ListType) and str(t.elem.elem) == "int":
+ return NamedCType(binds, BaseCType(optionalIntArrayRefT))
+ elif isinstance(t.elem, ListType) and str(t.elem.elem) == "SymInt":
+ if symint:
+ return NamedCType(binds, BaseCType(optionalSymIntArrayRefT))
+ else:
+ return NamedCType(binds, BaseCType(optionalIntArrayRefT))
+ elem = argumenttype_type(t.elem, mutable=mutable, binds=binds, symint=symint)
+ return NamedCType(binds, OptionalCType(elem.type))
+ elif isinstance(t, ListType):
+ # TODO: remove these special cases, ArrayRef fallthrough works fine
+ if str(t.elem) == "int":
+ if remove_non_owning_ref_types:
+ return NamedCType(binds, VectorCType(BaseCType(longT)))
+ else:
+ return NamedCType(binds, BaseCType(intArrayRefT))
+ if str(t.elem) == "SymInt":
+ if remove_non_owning_ref_types:
+ if symint:
+ return NamedCType(binds, VectorCType(BaseCType(SymIntT)))
+ else:
+ return NamedCType(binds, VectorCType(BaseCType(longT)))
+ else:
+ if symint:
+ return NamedCType(binds, BaseCType(symIntArrayRefT))
+ else:
+ return NamedCType(binds, BaseCType(intArrayRefT))
+ if str(t.elem) == "Tensor":
+ if local.use_ilistref_for_tensor_lists():
+ return NamedCType(binds, ConstRefCType(BaseCType(iTensorListRefT)))
+ else:
+ return NamedCType(binds, BaseCType(tensorListT))
+ elif str(t.elem) == "Scalar":
+ return NamedCType(binds, ArrayRefCType(BaseCType(scalarT)))
+ elif str(t.elem) == "Dimname":
+ return NamedCType(binds, BaseCType(dimnameListT))
+ elif str(t.elem) == "Tensor?":
+ return NamedCType(
+ binds, ConstRefCType(ListCType(OptionalCType(BaseCType(tensorT))))
+ )
+ elem = argumenttype_type(t.elem, mutable=mutable, binds=binds, symint=symint)
+ return NamedCType(binds, ArrayRefCType(elem.type))
+ else:
+ raise AssertionError(f"unrecognized type {repr(t)}")
+
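+# Illustrative sketch (not part of the original module): a few representative
+# argument translations, written informally in JIT-schema notation. The C++
+# spellings on the right are approximate renderings of the returned CTypes.
+#
+#   Tensor        -> const at::Tensor &             ConstRefCType(BaseCType(tensorT))
+#   Tensor(a!)    -> at::Tensor &                   (unless use_const_ref_for_mutable_tensors)
+#   Tensor?       -> const c10::optional<Tensor> &  ConstRefCType(OptionalCType(...))
+#   int[]         -> at::IntArrayRef                (std::vector<int64_t> when
+#                                                    remove_non_owning_ref_types=True)
+#   SymInt[]      -> c10::SymIntArrayRef when symint=True, else at::IntArrayRef
+#   Scalar[]      -> at::ArrayRef<at::Scalar>       ArrayRefCType(BaseCType(scalarT))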
+
+# Translate a JIT argument into its C++ type
+def argument_type(a: Argument, *, binds: ArgName, symint: bool = False) -> NamedCType:
+ return argumenttype_type(a.type, mutable=a.is_write, symint=symint, binds=binds)
+
+
+# Translation of a (non-multi) return type from JIT to C++
+# N.B: returntype_type returns a CType, not a NamedCType.
+# This is mostly because of the mismatch between return types and return names.
+# e.g. a function with a return type of 'void' has 0 return names,
+# and a function with a return type of 'std::tuple' has >1 return name.
+def returntype_type(t: Type, *, mutable: bool, symint: bool = False) -> CType:
+ # placeholder is ignored
+ # NB: symint is ALWAYS respected for return types. So symint argument
+ # here is IGNORED
+ r = valuetype_type(t, binds="__placeholder__", symint=True)
+ if r is not None:
+ return r.type
+
+ if isinstance(t, BaseType):
+ if t.name == BaseTy.Tensor:
+ if mutable:
+ if local.use_const_ref_for_mutable_tensors():
+ return ConstRefCType(BaseCType(tensorT))
+ else:
+ return MutRefCType(BaseCType(tensorT))
+ else:
+ # Note [Tensor Copy Returns]
+ # Currently, we use "Argument.is_write" to determine
+ # whether or not Tensor return types should be copies or references.
+ # If that ever changes, take a look at other locations of this note!
+ return BaseCType(tensorT)
+ elif t.name == BaseTy.Scalar:
+ return BaseCType(scalarT)
+ elif isinstance(t, ListType):
+ assert (
+ not mutable
+ ), "Native functions should never return a mutable tensor list. They should return void."
+ elem = returntype_type(t.elem, mutable=False)
+ assert t.size is None, f"fixed size list returns not supported: {t}"
+ return VectorCType(elem)
+ elif isinstance(t, OptionalType):
+ elem = returntype_type(t.elem, mutable=mutable)
+ if str(t.elem) == "Tensor":
+ return OptionalCType(elem)
+
+ raise AssertionError(f"unrecognized return type {t}")
+
+
+# Translation of a single return to its C++ type
+def return_type(r: Return, *, symint: bool = False) -> CType:
+ return returntype_type(r.type, mutable=r.is_write, symint=symint)
+
+
+# Translation of a full (possibly multi) return from JIT to its C++ type
+def returns_type(rs: Sequence[Return], *, symint: bool = False) -> CType:
+ if len(rs) == 0:
+ return BaseCType(voidT)
+ elif len(rs) == 1:
+ return return_type(rs[0], symint=symint)
+ else:
+ return TupleCType([return_type(r, symint=symint) for r in rs])
+
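+# Illustrative sketch (not part of the original module): how full return types
+# are assembled from the individual Return translations above (C++ spellings
+# are approximate renderings of the CTypes).
+#
+#   ()                   -> void                           BaseCType(voidT)
+#   -> Tensor            -> at::Tensor (by value)          BaseCType(tensorT)
+#   -> Tensor(a!)        -> at::Tensor & (mutable return)  MutRefCType(BaseCType(tensorT))
+#   -> (Tensor, Tensor)  -> ::std::tuple<Tensor, Tensor>   TupleCType([...])
+#   -> Tensor[]          -> ::std::vector<Tensor>          VectorCType(BaseCType(tensorT))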
+
+def return_names(f: NativeFunction, *, fallback_name: str = "result") -> Sequence[str]:
+ returns: List[str] = []
+ for i, r in enumerate(f.func.returns):
+ # If we have an inplace function, the return argument is
+ # implicitly named self.
+ # TODO: Consider incorporating this into the data model
+ if f.func.name.name.inplace:
+ assert i == 0, "illegal inplace function with multiple returns"
+ name = "self"
+ # If this is an out function, the name is the name of the
+ # corresponding out argument (r.name will get recorded
+ # in field_name later.)
+ elif f.func.is_out_fn():
+ name = f.func.arguments.out[i].name
+ # If the return argument is explicitly named...
+ elif r.name:
+ name_conflict = any(
+ r.name == a.name for a in f.func.schema_order_arguments()
+ )
+ if name_conflict and not f.func.is_out_fn():
+ name = f"{r.name}_return"
+ else:
+ name = r.name
+ # If there is no explicit name, we just use the fallback name (default: "result"),
+ # unless it's a multi-return, in which case the outputs are named result0,
+ # result1, etc. (zero-indexed)
+ else:
+ name = fallback_name if len(f.func.returns) == 1 else f"{fallback_name}{i}"
+ returns.append(name)
+ return returns
+
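+# Illustrative sketch (not part of the original module): the naming rules above
+# applied to a few hypothetical schemas.
+#
+#   add_(Tensor(a!) self, Tensor other) -> Tensor(a!)   -> ["self"]
+#   add.out(..., *, Tensor(a!) out) -> Tensor(a!)       -> ["out"]
+#   max.dim(...) -> (Tensor values, Tensor indices)     -> ["values", "indices"]
+#   add.Tensor(Tensor self, Tensor other) -> Tensor     -> ["result"]
+#   unnamed multi-return                                -> ["result0", "result1", ...]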
+
+JIT_TO_CPP_DEFAULT = {
+ "False": "false",
+ "True": "true",
+ "None": "c10::nullopt", # UGH this one is type directed
+ "Mean": "at::Reduction::Mean",
+ "[]": "{}",
+ "contiguous_format": "MemoryFormat::Contiguous",
+ "long": "at::kLong",
+}
+
+
+# Convert a JIT default into C++ expression representing the default
+def default_expr(d: str, t: Type, *, symint: bool) -> str:
+ if d == "None" and str(t) == "Tensor?":
+ return "{}"
+ if isinstance(t, BaseType) and t.name is BaseTy.str:
+ # Schema allows single quotes but C++ needs double
+ if len(d) >= 2 and d[0] == "'" and d[-1] == "'":
+ s = ""
+ i = 1
+ while i + 1 < len(d):
+ if d[i] != "\\":
+ if d[i] == '"':
+ s += '\\"'
+ else:
+ s += d[i]
+ i += 1
+ else:
+ if d[i + 1] == "'":
+ s += "'"
+ else:
+ s += d[i : i + 2]
+ i += 2
+
+ return f'"{s}"'
+
+ if isinstance(t, OptionalType):
+ if d == "None":
+ return "c10::nullopt"
+
+ return default_expr(d, t.elem, symint=symint)
+
+ if isinstance(t, ListType):
+ if d.startswith("[") and d.endswith("]"):
+ return "{" + d[1:-1] + "}"
+ elif symint and d.isdigit() and str(t.elem) == "SymInt":
+ return f"c10::SymInt({d})"
+ elif t.size is None:
+ # NOTE: Sized lists can have scalar defaults
+ raise ValueError(f"Expected a list default '[...]' but found: '{d}'")
+
+ return JIT_TO_CPP_DEFAULT.get(d, d)
+
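+# Illustrative sketch (not part of the original module): a few representative
+# default conversions (JIT default on the left, generated C++ expression on the right).
+#
+#   "None"  for Tensor?     -> "{}"
+#   "None"  for int?        -> "c10::nullopt"
+#   "True" / "False"        -> "true" / "false"
+#   "[0, 1]" for int[]      -> "{0, 1}"
+#   "'constant'" for str    -> '"constant"'
+#   "Mean"                  -> "at::Reduction::Mean"   (via JIT_TO_CPP_DEFAULT)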
+
+# Convert an argument into its C++ API form
+
+
+def argument(
+ a: Union[Argument, TensorOptionsArguments, SelfArgument],
+ *,
+ cpp_no_default_args: Set[str],
+ method: bool,
+ faithful: bool,
+ symint: bool = False,
+ has_tensor_options: bool,
+) -> List[Binding]:
+ def sub_argument(
+ a: Union[Argument, TensorOptionsArguments, SelfArgument]
+ ) -> List[Binding]:
+ return argument(
+ a,
+ cpp_no_default_args=cpp_no_default_args,
+ method=method,
+ faithful=faithful,
+ symint=symint,
+ has_tensor_options=has_tensor_options,
+ )
+
+ if isinstance(a, Argument):
+ binds: ArgName
+ if a.name == "memory_format" and has_tensor_options:
+ binds = SpecialArgName.possibly_redundant_memory_format
+ else:
+ binds = a.name
+ default: Optional[str] = None
+ if a.name not in cpp_no_default_args and a.default is not None:
+ default = default_expr(a.default, a.type, symint=symint)
+ return [
+ Binding(
+ nctype=argument_type(a, binds=binds, symint=symint),
+ name=a.name,
+ default=default,
+ argument=a,
+ )
+ ]
+ elif isinstance(a, TensorOptionsArguments):
+ if faithful:
+ return (
+ sub_argument(a.dtype)
+ + sub_argument(a.layout)
+ + sub_argument(a.device)
+ + sub_argument(a.pin_memory)
+ )
+ else:
+ default = None
+ # Enforced by NativeFunction.__post_init__
+ assert "options" not in cpp_no_default_args
+ if all(x.default == "None" for x in a.all()):
+ default = "{}"
+ elif a.dtype.default == "long":
+ default = "at::kLong" # TODO: this is wrong
+ return [
+ Binding(
+ nctype=NamedCType("options", BaseCType(tensorOptionsT)),
+ name="options",
+ default=default,
+ argument=a,
+ )
+ ]
+ elif isinstance(a, SelfArgument):
+ if method:
+ # Caller is responsible for installing implicit this in context!
+ return []
+ else:
+ return sub_argument(a.argument)
+ else:
+ assert_never(a)
+
+
+def arguments(
+ arguments: Arguments,
+ *,
+ faithful: bool,
+ symint: bool = False,
+ method: bool,
+ cpp_no_default_args: Set[str],
+) -> List[Binding]:
+ args: List[Union[Argument, TensorOptionsArguments, SelfArgument]] = []
+ if faithful:
+ args.extend(arguments.non_out)
+ args.extend(arguments.out)
+ else:
+ args.extend(arguments.out)
+ args.extend(arguments.non_out)
+ return [
+ r.no_default() if faithful else r
+ for a in args
+ for r in argument(
+ a,
+ faithful=faithful,
+ symint=symint,
+ method=method,
+ has_tensor_options=arguments.tensor_options is not None,
+ cpp_no_default_args=cpp_no_default_args,
+ )
+ ]
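+
+
+# Illustrative sketch (not part of the original module): what `arguments` does
+# with ordering and TensorOptions for a hypothetical factory function that also
+# has an out variant.
+#
+#   - faithful=False (user-facing C++ API): out arguments come first, the four
+#     TensorOptions pieces are packed into a single `options` binding, and C++
+#     defaults are kept.
+#   - faithful=True: arguments stay in schema (non-out then out) order, dtype /
+#     layout / device / pin_memory are emitted as separate bindings, and all
+#     defaults are stripped via Binding.no_default().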
diff --git a/MLPY/Lib/site-packages/torchgen/api/dispatcher.py b/MLPY/Lib/site-packages/torchgen/api/dispatcher.py
new file mode 100644
index 0000000000000000000000000000000000000000..15f059732893d92d633d58c84fb5c2a5282028a8
--- /dev/null
+++ b/MLPY/Lib/site-packages/torchgen/api/dispatcher.py
@@ -0,0 +1,118 @@
+import itertools
+from typing import List, Sequence, Union
+
+from torchgen.api import cpp
+
+from torchgen.api.types import ArgName, Binding, CType, NamedCType
+from torchgen.model import (
+ Argument,
+ FunctionSchema,
+ Return,
+ SelfArgument,
+ TensorOptionsArguments,
+ Type,
+)
+from torchgen.utils import assert_never, concatMap
+
+# This file describes the translation of JIT schema to the dispatcher
+# API, the *unboxed* calling convention by which invocations through
+# the dispatcher are made. Historically, the dispatcher API matched
+# the C++ API, but with the establishment of the boxed API, we've
+# made changes to the dispatcher API so that the unboxed API
+# better aligns with the boxed API. The dispatcher API hooks heavily
+# into our template based boxing/unboxing machinery, so changes
+# to this convention will usually need template updates too.
+#
+# Prominent characteristics of the dispatcher API:
+#
+# - dtype, layout, device and pin_memory are represented as separate
+# arguments.
+#
+
+
+def name(func: FunctionSchema) -> str:
+ return cpp.name(func)
+
+
+def argumenttype_type(
+ t: Type,
+ *,
+ mutable: bool,
+ binds: ArgName,
+ remove_non_owning_ref_types: bool = False,
+ symint: bool = True,
+) -> NamedCType:
+ # This is a faux ami. If it makes sense in the future to add
+ # more special cases here, or invert things so cpp.argument_type
+ # calls this, or just completely inline the function, please do
+ # it.
+ return cpp.argumenttype_type(
+ t,
+ mutable=mutable,
+ binds=binds,
+ symint=symint,
+ remove_non_owning_ref_types=remove_non_owning_ref_types,
+ )
+
+
+def argument_type(
+ a: Argument,
+ *,
+ binds: ArgName,
+ remove_non_owning_ref_types: bool = False,
+ symint: bool = True,
+) -> NamedCType:
+ return argumenttype_type(
+ a.type,
+ mutable=a.is_write,
+ binds=binds,
+ remove_non_owning_ref_types=remove_non_owning_ref_types,
+ symint=symint,
+ )
+
+
+def returns_type(rs: Sequence[Return], *, symint: bool = True) -> CType:
+ # At present, there is no difference. But there could be!
+ return cpp.returns_type(rs, symint=symint)
+
+
+def jit_arguments(func: FunctionSchema) -> List[Argument]:
+ def to_argument(
+ a: Union[Argument, TensorOptionsArguments, SelfArgument]
+ ) -> List[Argument]:
+ if isinstance(a, Argument):
+ return [a]
+ elif isinstance(a, SelfArgument):
+ return [a.argument]
+ elif isinstance(a, TensorOptionsArguments):
+ return [a.dtype, a.layout, a.device, a.pin_memory]
+ else:
+ assert_never(a)
+
+ return list(
+ concatMap(
+ to_argument,
+ itertools.chain(
+ func.arguments.positional, func.arguments.kwarg_only, func.arguments.out
+ ),
+ )
+ )
+
+
+def argument(
+ a: Argument, *, remove_non_owning_ref_types: bool = False, symint: bool = True
+) -> Binding:
+ return Binding(
+ nctype=argument_type(
+ a,
+ binds=a.name,
+ remove_non_owning_ref_types=remove_non_owning_ref_types,
+ symint=symint,
+ ),
+ name=a.name,
+ argument=a,
+ )
+
+
+def arguments(func: FunctionSchema, *, symint: bool = True) -> List[Binding]:
+ return [argument(a, symint=symint) for a in jit_arguments(func)]
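+
+
+# Illustrative sketch (not part of the original module): the dispatcher API keeps
+# one binding per JIT argument, scattering TensorOptions. For a hypothetical
+# factory schema
+#   ones(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None,
+#        Device? device=None, bool? pin_memory=None) -> Tensor
+# `jit_arguments` yields [size, dtype, layout, device, pin_memory], and
+# `arguments` produces one Binding for each of them (with symint=True by
+# default, so `size` binds to c10::SymIntArrayRef).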
diff --git a/MLPY/Lib/site-packages/torchgen/api/functionalization.py b/MLPY/Lib/site-packages/torchgen/api/functionalization.py
new file mode 100644
index 0000000000000000000000000000000000000000..8a31e99ea2e6596e01721e2fa0b63866648ac310
--- /dev/null
+++ b/MLPY/Lib/site-packages/torchgen/api/functionalization.py
@@ -0,0 +1,199 @@
+from typing import List, Optional
+
+from torchgen.api import dispatcher
+from torchgen.api.types import (
+ BaseCppType,
+ BaseCType,
+ Binding,
+ boolT,
+ ConstRefCType,
+ CType,
+ longT,
+ NamedCType,
+ tensorT,
+)
+from torchgen.model import (
+ Argument,
+ BaseTy,
+ BaseType,
+ FunctionSchema,
+ NativeFunction,
+ NativeFunctionsViewGroup,
+)
+
+
+# This file describes the translation of JIT schema to APIs used
+# when creating view lambdas that are used by the functionalization pass.
+# There are two types of lambdas: forward lambdas and reverse lambdas.
+# These APIs mostly follow the dispatcher API, with a few quirks:
+# - The lambda capture has to convert reference types to value types
+# - While the forward lambda just directly calls into the at::_ops API
+# (following the dispatcher convention), the logic here for the reverse lambda
+# is responsible for generating both the call-site, and the declarations
+# (which are implemented manually in the at::functionalization::impl namespace).
+
+# The lambdas generated for each view op in the functionalization pass are of the form
+# [capture_arguments](outer_arguments) -> returns_type {
+# return name(inner_arguments);
+# }
+
+# Define some specific lambda input arguments.
+base_binding = Binding(
+ name="base",
+ nctype=NamedCType(name="base", type=ConstRefCType(BaseCType(tensorT))),
+ argument=Argument(
+ name="base", type=BaseType(BaseTy.Tensor), default=None, annotation=None
+ ),
+ default=None,
+)
+mutated_view_binding = Binding(
+ name="mutated_view",
+ nctype=NamedCType(name="mutated_view", type=ConstRefCType(BaseCType(tensorT))),
+ argument=Argument(
+ name="base", type=BaseType(BaseTy.Tensor), default=None, annotation=None
+ ),
+ default=None,
+)
+mutated_view_idx_binding = Binding(
+ name="mutated_view_idx",
+ nctype=NamedCType(name="mutated_view_idx", type=BaseCType(longT)),
+ argument=Argument(
+ name="base", type=BaseType(BaseTy.Tensor), default=None, annotation=None
+ ),
+ default=None,
+)
+reapply_views_binding = Binding(
+ name="reapply_views",
+ nctype=NamedCType(name="reapply_views", type=BaseCType(boolT)),
+ argument=Argument(
+ name="reapply_views", type=BaseType(BaseTy.bool), default=None, annotation=None
+ ),
+ default=None,
+)
+
+InverseReturnModeT = BaseCppType("at::functionalization", "InverseReturnMode")
+inverse_return_mode_binding = Binding(
+ name="inverse_return_mode",
+ nctype=NamedCType(name="inverse_return_mode", type=BaseCType(InverseReturnModeT)),
+ argument=Argument(
+ name="inverse_return_mode",
+ # NB: not actually a bool but it doesn't matter because this isn't used
+ type=BaseType(BaseTy.bool),
+ default=None,
+ annotation=None,
+ ),
+ default=None,
+)
+
+
+# The lambda capture itself doesn't have a name.
+# The name returned here corresponds to the name of the inner function called by the lambda.
+def name(
+ g: NativeFunctionsViewGroup,
+ *,
+ is_reverse: bool,
+ include_namespace: bool,
+ reapply_views: Optional[bool] = None,
+) -> str:
+ if reapply_views is None:
+ # reapply_views is only important for the fwd lambda,
+ # since we always plumb the runtime "reapply_views" argument into the reverse function.
+ assert is_reverse
+ if is_reverse:
+ return reverse_name(g.view, include_namespace)
+ # in the forward case, we just directly call into the at::_ops API (so we always need the namespace)
+ assert include_namespace
+ assert g.view_copy is not None
+ api_name = (
+ g.view.func.name.unambiguous_name()
+ if reapply_views
+ else g.view_copy.func.name.unambiguous_name()
+ )
+ return f"at::_ops::{api_name}::call"
+
+
+def reverse_name(f: NativeFunction, include_namespace: bool) -> str:
+ # for the reverse: we plumb the "reapply_views" flag into that function and support
+ # both copy and non-copy variants. (We could avoid doing that, but that would require
+ # writing out twice as many view inverse functions).
+ api_name = f.func.name.unambiguous_name()
+ # in the reverse case, we codegen both the call-sites (which need the full namespace) and the declarations (which don't)
+ if include_namespace:
+ return f"at::functionalization::FunctionalInverses::{api_name}_inverse"
+ else:
+ return f"{api_name}_inverse"
+
+
+def capture_arguments(func: FunctionSchema, *, is_reverse: bool) -> List[Binding]:
+ # capture arguments include all arguments except `self`.
+ # Importantly, they don't include any C++ reference types (or else we'd get a dangling reference in the capture),
+ # so any reference types (IntArrayRef) need to be converted to value types (vector).
+ args = func.arguments.flat_all
+ assert args[0].type == BaseType(BaseTy.Tensor)
+ non_self_args = args[1:]
+ non_self_value_bindings = [
+ dispatcher.argument(a, remove_non_owning_ref_types=True) for a in non_self_args
+ ]
+
+ all_bindings = [
+ inverse_return_mode_binding if is_reverse else reapply_views_binding
+ ]
+ all_bindings.extend(non_self_value_bindings)
+ return all_bindings
+
+
+def returns_type(func: FunctionSchema) -> CType:
+ # Assertion: all view ops return tensor-like outputs
+ assert len(func.returns) >= 1
+ for ret in func.returns:
+ assert ret.type.is_tensor_like()
+ # However, the return type of the lambda is always an individual tensor.
+ # For multi-tensor outputs, each tensor needs to be tracked individually.
+ return BaseCType(tensorT)
+
+
+def outer_arguments(*, is_reverse: bool) -> List[Binding]:
+ if is_reverse:
+ return [base_binding, mutated_view_binding, mutated_view_idx_binding]
+ else:
+ return [base_binding, mutated_view_idx_binding]
+
+
+def inner_call_index(func: FunctionSchema) -> Optional[Binding]:
+ # For view ops that return multiple tensors (like `split`), we generate a separate lambda for each output.
+ # When we replay a view op that returns multiple tensors, we need to index into the output appropriately
+ if len(func.returns) > 1 or (
+ len(func.returns) == 1 and func.returns[0].type.is_list_like()
+ ):
+ return mutated_view_idx_binding
+ return None
+
+
+def inner_arguments(func: FunctionSchema, is_reverse: bool) -> List[Binding]:
+ args = func.arguments.flat_all
+ assert args[0].type == BaseType(BaseTy.Tensor)
+ non_self_args = args[1:]
+ # The forward lambda calls the at::_ops API, while the reverse lambda calls the view inverse API.
+ # Both of these follow the dispatcher API.
+ non_self_bindings = [dispatcher.argument(a) for a in non_self_args]
+ if not is_reverse:
+ # the forward lambda swaps out the original tensor argument with the lambda arg "base"
+ return [base_binding] + non_self_bindings
+ else:
+ # the reverse lambda does the same, but with an additional "mutated_view" arg
+ # additionally, we have a calling convention: for view ops that return multiple tensor outputs
+ # their corresponding view_inverse function takes in an additional index argument.
+ index_binding = inner_call_index(func)
+ if index_binding is not None:
+ return [
+ base_binding,
+ mutated_view_binding,
+ inverse_return_mode_binding,
+ index_binding,
+ ] + non_self_bindings
+ else:
+ return [
+ base_binding,
+ mutated_view_binding,
+ inverse_return_mode_binding,
+ ] + non_self_bindings
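+
+
+# Illustrative sketch (not part of the original module): argument layout for a
+# hypothetical multi-output view op such as
+#   split.Tensor(Tensor(a -> *) self, SymInt split_size, int dim=0) -> Tensor(a)[]
+#
+#   outer_arguments(is_reverse=False)  -> [base, mutated_view_idx]
+#   outer_arguments(is_reverse=True)   -> [base, mutated_view, mutated_view_idx]
+#   capture_arguments(func, is_reverse=False)
+#       -> [reapply_views, split_size, dim]   (self is dropped, refs become owning)
+#   inner_arguments(func, is_reverse=True)
+#       -> [base, mutated_view, inverse_return_mode, mutated_view_idx, split_size, dim]
+#          (the index is present because the op returns a list of tensors)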
diff --git a/MLPY/Lib/site-packages/torchgen/api/lazy.py b/MLPY/Lib/site-packages/torchgen/api/lazy.py
new file mode 100644
index 0000000000000000000000000000000000000000..5a5ab81faade0655cd3edeae912822910709ef4f
--- /dev/null
+++ b/MLPY/Lib/site-packages/torchgen/api/lazy.py
@@ -0,0 +1,464 @@
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+from torchgen.api.types import (
+ BaseCppType,
+ BaseCType,
+ boolT,
+ CType,
+ deviceT,
+ doubleT,
+ generatorT,
+ layoutT,
+ ListCType,
+ longT,
+ memoryFormatT,
+ NamedCType,
+ OptionalCType,
+ scalarT,
+ scalarTypeT,
+ stringT,
+ SymIntT,
+ VectorCType,
+)
+
+from torchgen.model import (
+ Argument,
+ BaseTy,
+ BaseType,
+ FunctionSchema,
+ ListType,
+ OperatorName,
+ OptionalType,
+ Return,
+ TensorOptionsArguments,
+ Type,
+)
+
+
+_valueT: Optional[BaseCppType] = None
+
+
+# A ValueT is an IR type which represents the computation of a Tensor. In other
+# words, a PyTorch user will do operations on lazy tensors, and each output lazy
+# tensor internally tracks a ValueT representing the IR node that would have
+# actually produced the value of this tensor for real.
+#
+# This is configurable because different lazy tensor backends (LTC vs XLA) will
+# have different IR representations. (Though, arguably, after unification they
+# shouldn't!)
+def getValueT() -> BaseCppType:
+ global _valueT
+ if not _valueT:
+ raise NotImplementedError(
+ "The value type needs to be set with setValueT() in run_gen_lazy_tensor()"
+ )
+
+ return _valueT
+
+
+def setValueT(val: BaseCppType) -> None:
+ global _valueT
+ _valueT = val
+
+
+# this is a bad hack. I need to refactor the data model to represent each arg in the schema as an object,
+# making it easier to represent special properties of an arg.
+tensorListValueT = BaseCppType("torch::lazy", "Value")
+
+
+def process_ir_type(
+ typ: Type, properties: "LazyIrProperties", *, symint: bool
+) -> Union[BaseCType, VectorCType, OptionalCType, ListCType]:
+ """
+ This function takes a type from NativeFunctions and converts it for use with
+ lazy tensor codegen.
+
+ Type conversion for lazy currently consists of
+ (1) changing at::Tensors into lazy::Values
+ (2) wrapping everything in a BaseCType
+ (3) making cpp-reference types into cpp-value types (e.g. vector instead of IntArrayRef)
+
+ (1) converts at::Tensors to lazy::Values (which wrap lazy::Nodes, with which Lazy IR represents tensors).
+ There is special handling for Optional[Tensor] and List[Tensor], etc. - hence 'tensor-like'.
+
+ This is incomplete: there are assertions in places because it is expected that more
+ types will need to be added as the codegen is used with more operators.
+ """
+ if isinstance(typ, BaseType):
+ if typ.name == BaseTy.Tensor:
+ return BaseCType(getValueT())
+ elif typ.name == BaseTy.Scalar:
+ if properties.TreatScalarsAsConstants:
+ return BaseCType(scalarT)
+ # at::scalar has special handling,
+ # and is wrapped in a lazy::Value just like at::Tensor
+ return BaseCType(getValueT())
+ elif typ.name == BaseTy.ScalarType:
+ return BaseCType(scalarTypeT)
+ elif typ.name == BaseTy.int:
+ return BaseCType(longT)
+ elif typ.name == BaseTy.SymInt:
+ if symint:
+ return BaseCType(getValueT())
+ else:
+ return BaseCType(longT)
+ elif typ.name == BaseTy.bool:
+ return BaseCType(boolT)
+ elif typ.name == BaseTy.float:
+ return BaseCType(doubleT)
+ elif typ.name == BaseTy.str:
+ return BaseCType(stringT)
+ elif typ.name == BaseTy.Device:
+ return BaseCType(deviceT)
+ elif typ.name == BaseTy.Generator:
+ return BaseCType(generatorT)
+ elif typ.name == BaseTy.Layout:
+ return BaseCType(layoutT)
+ elif typ.name == BaseTy.MemoryFormat:
+ return BaseCType(memoryFormatT)
+ else:
+ raise AssertionError(f"TODO add support for type {repr(typ)}")
+ elif isinstance(typ, OptionalType):
+ return OptionalCType(process_ir_type(typ.elem, properties, symint=symint))
+ elif isinstance(typ, ListType):
+ if str(typ.elem) == "Tensor?":
+ # TODO(whc) is this actually correct? or should it use a Vector like above
+ return ListCType(OptionalCType(BaseCType(getValueT())))
+ elif str(typ.elem) == "Tensor":
+ # this is a TensorList which comes in from GetTensorList as a Value
+ return BaseCType(tensorListValueT)
+ elif typ.elem == BaseType(BaseTy.SymInt):
+ # TODO: return a value type. The problem here is analogous to
+ # the problem with tensorListValueT: if you have SymInt[] you
+ # cannot conveniently save the list of Value directly, as nodes
+ # expect to save values as a vector for ALL arguments. So you
+ # need a separate IR node that represents all of the size nodes
+ # assembled into a list. I'm not an LTC dev so I don't want to
+ # figure it out right now. Y'all figure it out...
+ return VectorCType(BaseCType(longT))
+
+ else:
+ return VectorCType(process_ir_type(typ.elem, properties, symint=symint))
+ else:
+ raise AssertionError(f"unrecognized type {repr(typ)}")
+
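+# Illustrative sketch (not part of the original module): representative lazy IR
+# type translations, assuming setValueT() has already been called by the backend
+# (so getValueT() returns e.g. torch::lazy::Value).
+#
+#   Tensor    -> BaseCType(getValueT())
+#   Tensor?   -> OptionalCType(BaseCType(getValueT()))
+#   Tensor[]  -> BaseCType(tensorListValueT)
+#   Scalar    -> BaseCType(getValueT())   (BaseCType(scalarT) if TreatScalarsAsConstants)
+#   int       -> BaseCType(longT)
+#   SymInt    -> BaseCType(getValueT()) if symint else BaseCType(longT)
+#   SymInt[]  -> VectorCType(BaseCType(longT))   (not yet a value type, see TODO above)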
+
+# TODO: Determining this based off of CType is bad; this should be computed
+# from Type directly; then the same logic as process_ir_type can be used
+#
+# Invariant: passed typ should be an *owning* CType (e.g., we will report
+# that ArrayRef is NOT a value type)
+def isValueType(typ: CType, properties: "Optional[LazyIrProperties]" = None) -> bool:
+ """
+ Given a type, determine if it is a Value-like type. This is equivalent to
+ being Tensor-like, but assumes the type has already been transformed.
+ """
+ if isinstance(typ, BaseCType):
+ # I am regretting my naming conventions, but now we are wrapping at::scalar in
+ # lazy value, while preserving other 'scalar' types as scalars in the IR
+ treat_scalars_as_constants = properties and properties.TreatScalarsAsConstants
+ return (
+ typ.type == getValueT()
+ or (typ.type == scalarT and not treat_scalars_as_constants)
+ or typ.type == SymIntT
+ )
+ elif typ == VectorCType(BaseCType(SymIntT)):
+ # TODO: report True for this
+ return False
+ elif isinstance(typ, (OptionalCType, ListCType, VectorCType)):
+ return isValueType(typ.elem, properties)
+ return False
+
+
+def isSymIntType(typ: Type) -> bool:
+ return isinstance(typ, BaseType) and typ.name == BaseTy.SymInt
+
+
+def isWrappedScalarType(typ: Type) -> bool:
+ """
+ Given a type, determine if it is a c10::Scalar which we will wrap in a lazy Value.
+ Since we literally change the type from scalarT to valueT, information is lost.
+ This function helps build a list of wrapped scalars to save that information.
+ """
+ if isinstance(typ, BaseType):
+ # I am regretting my naming conventions, but now we are wrapping at::scalar in
+ # lazy value, while preserving other 'scalar' types as scalars in the IR
+ return typ.name == BaseTy.Scalar
+ elif isinstance(typ, (OptionalType, ListType)):
+ return isWrappedScalarType(typ.elem)
+ return False
+
+
+# TODO: dedupe with Type.is_generator_like
+def isGeneratorType(typ: Type) -> bool:
+ if isinstance(typ, BaseType):
+ return typ.name == BaseTy.Generator
+ elif isinstance(typ, (OptionalType)):
+ return isGeneratorType(typ.elem)
+ return False
+
+
+# This class caches a few derived properties computed from an Argument
+# and LazyIrProperties
+class LazyArgument:
+ name: str
+ orig_type: Type
+ lazy_type_: Optional[CType]
+ is_wrapped_scalar: bool
+ is_generator: bool
+ # TODO: this is a lie; it is false for symint lists
+ is_symint_or_list: bool
+
+ # Whether or not we are treating this as symint or not
+ symint: bool
+
+ # true if this argument is or contains a lazy IR value
+ is_lazy_value: bool
+
+ def __init__(self, arg: Argument, properties: "LazyIrProperties", *, symint: bool):
+ self.name = arg.name
+ self.orig_type = arg.type
+ self.symint = symint
+ self.is_optional = isinstance(arg.type, OptionalType)
+ self.is_generator = isGeneratorType(arg.type)
+ self.lazy_type_ = process_ir_type(arg.type, properties, symint=symint)
+ self.is_wrapped_scalar = isWrappedScalarType(arg.type)
+ self.is_symint_or_list = symint and (
+ isSymIntType(arg.type)
+ or (isinstance(arg.type, OptionalType) and isSymIntType(arg.type.elem))
+ # TODO: lists of symints are not currently treated as value types
+ # or (isinstance(arg.type, ListType) and isSymIntType(arg.type.elem))
+ )
+
+ self.is_lazy_value = isValueType(self.lazy_type, properties)
+
+ @property
+ def lazy_type(self) -> CType:
+ assert (
+ self.lazy_type_ is not None
+ ), f"Attempted to access lazy_type for invalid argument {self.name}"
+ return self.lazy_type_
+
+
+class LazyIrProperties:
+ """Collection of properties for an IR node
+
+ The property groups are listed below. Each group is mutually
+ exclusive, meaning that only one property from each group can be True
+ at any one time. The properties can be accessed as if they were normal
+ attributes. The mutual exclusivity is automatically handled.
+ """
+
+ Properties: Tuple[Tuple[str, ...], ...] = (
+ (
+ "ShapePrecompute", # Assume shape has been precomputed
+ "ShapeCompute", # Need to compute the shape on construction
+ "ShapeCache", # Utilize the shape cache to defer computation
+ ),
+ (
+ "Lower", # Codegen full lower function
+ "LowerDeclOnly", # Codegen only lower function declaration
+ ),
+ (
+ "CanBeReused", # Codegen full reuse function
+ "CanBeReusedDeclOnly", # Codegen only reuse function declaration
+ ),
+ (
+ "CreateFn", # Codegen full create function
+ "CreateFnDeclOnly", # Codegen only create function declaration
+ ),
+ (
+ "TreatScalarsAsConstants", # Treat Scalars as constants instead of handling like values
+ ),
+ )
+
+ def __init__(self, *default_properties: str):
+ properties: Dict[Tuple[str, ...], Optional[str]] = dict.fromkeys(
+ LazyIrProperties.Properties
+ )
+ self.__dict__["properties"] = properties
+ for p in default_properties:
+ setattr(self, p, True)
+
+ def __getattr__(self, key: str) -> Any:
+ properties = self.__dict__["properties"]
+ for values in LazyIrProperties.Properties:
+ if key in values:
+ return properties[values] == key
+
+ return self.__getattribute__(key)
+
+ def __setattr__(self, key: str, value: Any) -> Any:
+ properties = self.__dict__["properties"]
+ for values in LazyIrProperties.Properties:
+ if key in values:
+ properties[values] = key if value else None
+ return value
+
+ raise KeyError(f"Invalid property: {key}")
+
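+# Illustrative sketch (not part of the original module): properties within one
+# group are mutually exclusive, so enabling one disables the others in its group.
+#
+#   >>> props = LazyIrProperties("ShapePrecompute", "Lower")
+#   >>> props.ShapePrecompute          # True
+#   >>> props.ShapeCompute = True      # switches the shape group to ShapeCompute
+#   >>> props.ShapePrecompute          # now False
+#   >>> props.ShapeCompute             # True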
+
+# Inspired by a FunctionSchema object, a LazyIrSchema holds the schema of a Lazy IR node.
+# Unlike a FunctionSchema, it has no round-trippable string form (relating to the YAML),
+# but carries type information from a native FunctionSchema modified for use with IR nodes,
+# and preserving original argument names.
+#
+# TODO: This is not idiomatic with how other torchgen APIs transform on schema.
+class LazyIrSchema:
+ # The name of the operator this function schema describes.
+ name: "OperatorName"
+
+ positional_args: Tuple[LazyArgument, ...]
+ keyword_args: Tuple[LazyArgument, ...]
+
+ # TODO: Need to handle collisions with argument names at some point
+ returns: Tuple["Return", ...]
+
+ # if this schema has a Generator arg, list its orig ctype/name but don't
+ # build a LazyArgument since lazy IR doesn't support it
+ generator_arg: Optional[NamedCType] = None
+
+ # original function schema
+ func: FunctionSchema
+
+ # Whether or not we are code-genning for SymInt or not
+ symint: bool
+
+ properties: LazyIrProperties = LazyIrProperties(
+ # default properties
+ "ShapePrecompute",
+ "Lower",
+ "CanBeReused",
+ )
+ opkind: Optional[str] = None
+
+ def __init__(
+ self,
+ func: FunctionSchema,
+ properties: Optional[LazyIrProperties] = None,
+ *,
+ symint: bool,
+ ):
+ if properties:
+ self.properties = properties
+
+ self.func = func
+ self.symint = symint
+ positional_args: List[LazyArgument] = []
+ for arg_field in ["pre_self_positional", "self_arg", "post_self_positional"]:
+ if arg_field == "self_arg" and func.arguments.self_arg is not None:
+ arg = func.arguments.self_arg.argument
+ positional_args.append(
+ LazyArgument(arg, self.properties, symint=symint)
+ )
+ elif getattr(func.arguments, arg_field) is not None:
+ positional_args.extend(
+ LazyArgument(arg, self.properties, symint=symint)
+ for arg in getattr(func.arguments, arg_field)
+ )
+ self.positional_args = tuple(positional_args)
+
+ keyword_args: List[LazyArgument] = []
+ for arg_field in [
+ "pre_tensor_options_kwarg_only",
+ "tensor_options",
+ "post_tensor_options_kwarg_only",
+ "out",
+ ]:
+ curr_args = getattr(func.arguments, arg_field)
+ if curr_args is not None:
+ if isinstance(curr_args, TensorOptionsArguments):
+ curr_args = curr_args.all()
+ for arg in curr_args:
+ if isGeneratorType(arg.type):
+ assert (
+ self.generator_arg is None
+ ), "We expect there is only one generator arg"
+ self.generator_arg = NamedCType(
+ arg.name, arg.type # type:ignore[arg-type]
+ )
+ keyword_args.extend(
+ LazyArgument(arg, self.properties, symint=symint)
+ for arg in curr_args
+ )
+ self.keyword_args = tuple(keyword_args)
+ self.name = func.name
+ self.returns = func.returns
+
+ @property
+ def node_name(self) -> str:
+ """
+ Return camel-case version of op in node.
+
+ Note: This function also appends any `overload_name` in the operation.
+ For example, if the op is `bitwise_and.Tensor`, the returned name
+ will be `BitwiseAndTensor`.
+ """
+ op_name = f"{self.name.name}_{self.name.overload_name}".lower()
+ return "".join(word.capitalize() or "" for word in op_name.split("_"))
+
+ @property
+ def aten_name(self) -> str:
+ return str(self.name.name)
+
+ @property
+ def base_name(self) -> str:
+ return f"{self.name.name.base}"
+
+ def filtered_args(
+ self,
+ positional: bool = True,
+ keyword: bool = True,
+ values: bool = True,
+ scalars: bool = True,
+ generator: bool = True,
+ ) -> List[LazyArgument]:
+ # This function maintains the sorted order of arguments but provides different filtered views.
+ # Some parts of the code care about kwargs vs args (TS lowerings),
+ # other parts care about whether they need to wrap the arg in a lazy value or leave it alone.
+ # Generators are special cased, as they are needed for fallback/shape-inference but not supported
+ # in TS lowerings and therefore also omitted from lazy IR.
+ args: List[LazyArgument] = []
+ if positional:
+ args.extend(self.positional_args)
+ if keyword:
+ args.extend(self.keyword_args)
+
+ if values and scalars and generator:
+ return args
+ elif values and scalars:
+ return [a for a in args if not a.is_generator]
+ elif values:
+ return [a for a in args if a.is_lazy_value]
+ elif scalars:
+ return [
+ a
+ for a in args
+ if not a.is_lazy_value and (generator or not a.is_generator)
+ ]
+
+ return []
+
+ @property
+ def positional_values(self) -> List[LazyArgument]:
+ return self.filtered_args(
+ positional=True, keyword=False, values=True, scalars=False
+ )
+
+ @property
+ def positional_scalars(self) -> List[LazyArgument]:
+ return self.filtered_args(
+ positional=True, keyword=False, values=False, scalars=True
+ )
+
+ @property
+ def keyword_values(self) -> List[LazyArgument]:
+ return self.filtered_args(
+ positional=False, keyword=True, values=True, scalars=False
+ )
+
+ @property
+ def keyword_scalars(self) -> List[LazyArgument]:
+ return self.filtered_args(
+ positional=False, keyword=True, values=False, scalars=True
+ )
diff --git a/MLPY/Lib/site-packages/torchgen/api/meta.py b/MLPY/Lib/site-packages/torchgen/api/meta.py
new file mode 100644
index 0000000000000000000000000000000000000000..40792a04e9c72397e70118730e55be5e9815723e
--- /dev/null
+++ b/MLPY/Lib/site-packages/torchgen/api/meta.py
@@ -0,0 +1,12 @@
+from torchgen.model import NativeFunctionsGroup
+
+# Follows dispatcher calling convention, but:
+# - Mutable arguments not allowed. Meta functions are always
+# written in functional form. Look at FunctionSchema.signature()
+# - No tensor returns; instead we return a TensorMeta describing
+# the tensor in question
+
+
+def name(g: NativeFunctionsGroup) -> str:
+ # use the overload name from the functional version
+ return str(g.functional.func.name).replace(".", "_")
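+
+
+# Illustrative sketch (not part of the original module): the meta name is just the
+# functional overload name with '.' replaced, e.g. a hypothetical group whose
+# functional op is aten::add.Tensor would yield "add_Tensor".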
diff --git a/MLPY/Lib/site-packages/torchgen/api/native.py b/MLPY/Lib/site-packages/torchgen/api/native.py
new file mode 100644
index 0000000000000000000000000000000000000000..1138cb19329b34cf1b781d06d16e856e7e2ec1ef
--- /dev/null
+++ b/MLPY/Lib/site-packages/torchgen/api/native.py
@@ -0,0 +1,153 @@
+from typing import List, Optional, Sequence, Union
+
+from torchgen import local
+from torchgen.api import cpp
+
+from torchgen.api.types import (
+ ArgName,
+ BaseCType,
+ Binding,
+ boolT,
+ ConstRefCType,
+ CType,
+ deviceT,
+ layoutT,
+ ListCType,
+ MutRefCType,
+ NamedCType,
+ OptionalCType,
+ scalarT,
+ scalarTypeT,
+ tensorT,
+)
+from torchgen.model import (
+ Argument,
+ FunctionSchema,
+ Return,
+ SelfArgument,
+ TensorOptionsArguments,
+ Type,
+)
+from torchgen.utils import assert_never
+
+# This file describes the translation of JIT schema to the native functions API.
+# This looks a lot like the C++ API (which makes historical sense, because the
+# idea was you wrote native functions to implement functions in the C++ API),
+# but over time we have evolved the C++ API without actually changing our
+# native:: kernels. The intention is to make native API and dispatcher API
+# line up as closely as possible, since this results in the least overhead
+# (no translation is needed from dispatcher API to native API).
+#
+# NB: this is symint aware; you will get the non-SymInt variant for some
+# dispatch entries and SymInt for others.
+
+
+def name(func: FunctionSchema) -> str:
+ name = str(func.name.name)
+ # TODO: delete this!
+ if func.is_out_fn():
+ name += "_out"
+ if func.name.overload_name:
+ name += f"_{func.name.overload_name}"
+ return name
+
+
+def argumenttype_type(
+ t: Type, *, mutable: bool, binds: ArgName, symint: bool
+) -> NamedCType:
+ if str(t) == "Tensor?":
+ tensor_type: OptionalCType = OptionalCType(BaseCType(tensorT))
+ if mutable and not local.use_const_ref_for_mutable_tensors():
+ return NamedCType(binds, MutRefCType(tensor_type))
+ else:
+ return NamedCType(binds, ConstRefCType(tensor_type))
+ elif str(t) == "Tensor?[]":
+ return NamedCType(
+ binds, ConstRefCType(ListCType(OptionalCType(BaseCType(tensorT))))
+ )
+ elif str(t) == "Scalar":
+ return NamedCType(binds, ConstRefCType(BaseCType(scalarT)))
+ elif str(t) == "Scalar?":
+ return NamedCType(binds, ConstRefCType(OptionalCType(BaseCType(scalarT))))
+ return cpp.argumenttype_type(t, mutable=mutable, binds=binds, symint=symint)
+
+
+def returns_type(rs: Sequence[Return], *, symint: bool) -> CType:
+ return cpp.returns_type(rs, symint=symint)
+
+
+def argument_type(a: Argument, *, binds: ArgName, symint: bool) -> NamedCType:
+ return argumenttype_type(a.type, mutable=a.is_write, binds=binds, symint=symint)
+
+
+def argument(
+ a: Union[Argument, SelfArgument, TensorOptionsArguments],
+ *,
+ is_out: bool,
+ symint: bool,
+) -> List[Binding]:
+ # Ideally, we NEVER default native functions. However, there are a number
+ # of functions that call native:: directly and rely on the defaulting
+ # existing. So for BC, we generate defaults for non-out variants (but not
+ # for out variants, where it is impossible to generate an appropriate
+ # default)
+ should_default = not is_out
+ if isinstance(a, Argument):
+ default: Optional[str] = None
+ if should_default and a.default is not None:
+ default = cpp.default_expr(a.default, a.type, symint=symint)
+ return [
+ Binding(
+ nctype=argument_type(a, binds=a.name, symint=symint),
+ name=a.name,
+ default=default,
+ argument=a,
+ )
+ ]
+ elif isinstance(a, SelfArgument):
+ # Erase SelfArgument from the distinction
+ return argument(a.argument, is_out=is_out, symint=symint)
+ elif isinstance(a, TensorOptionsArguments):
+ default = None
+ if should_default:
+ default = "{}"
+ # TODO: Not sure why the arguments assigned here are for
+ # TensorOptionsArguments and not the constituent pieces. It seems
+ # to matter
+ return [
+ Binding(
+ nctype=NamedCType("dtype", OptionalCType(BaseCType(scalarTypeT))),
+ name="dtype",
+ default=default,
+ argument=a,
+ ),
+ Binding(
+ nctype=NamedCType("layout", OptionalCType(BaseCType(layoutT))),
+ name="layout",
+ default=default,
+ argument=a,
+ ),
+ Binding(
+ nctype=NamedCType("device", OptionalCType(BaseCType(deviceT))),
+ name="device",
+ default=default,
+ argument=a,
+ ),
+ Binding(
+ nctype=NamedCType("pin_memory", OptionalCType(BaseCType(boolT))),
+ name="pin_memory",
+ default=default,
+ argument=a,
+ ),
+ ]
+ else:
+ assert_never(a)
+
+
+def arguments(func: FunctionSchema, *, symint: bool) -> List[Binding]:
+ args: List[Union[Argument, TensorOptionsArguments, SelfArgument]] = []
+ args.extend(func.arguments.non_out)
+ args.extend(func.arguments.out)
+ return [
+ r for arg in args for r in argument(arg, symint=symint, is_out=func.is_out_fn())
+ ]
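+
+
+# Illustrative sketch (not part of the original module): for a hypothetical factory
+# function with TensorOptions, the native API scatters the options into four
+# optional bindings rather than packing them (C++ spellings are approximate):
+#
+#   dtype: c10::optional<ScalarType> = {}
+#   layout: c10::optional<Layout> = {}
+#   device: c10::optional<Device> = {}
+#   pin_memory: c10::optional<bool> = {}
+#
+# Out variants get no defaults, and out arguments are appended after the non-out ones.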
diff --git a/MLPY/Lib/site-packages/torchgen/api/python.py b/MLPY/Lib/site-packages/torchgen/api/python.py
new file mode 100644
index 0000000000000000000000000000000000000000..26a6a1f2587ecd7a6ac700940da21bdaf1c7ee70
--- /dev/null
+++ b/MLPY/Lib/site-packages/torchgen/api/python.py
@@ -0,0 +1,1509 @@
+from dataclasses import dataclass
+from typing import Dict, List, Optional, Sequence, Set, Tuple, Union
+
+from torchgen.api import cpp
+
+from torchgen.api.types import Binding, CppSignature, CppSignatureGroup
+from torchgen.gen import pythonify_default
+from torchgen.model import (
+ Argument,
+ BaseTy,
+ BaseType,
+ FunctionSchema,
+ ListType,
+ NativeFunction,
+ OptionalType,
+ Return,
+ Type,
+ Variant,
+)
+
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
+#
+# Data Models
+#
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
+#
+# [Notes] python binding codegen
+#
+# The Python binding codegen produces code that takes the input list of
+# PyObjects, finds the matching ATen C++ function using PythonArgParser,
+# converts the PyObjects into C++ types and calls the ATen C++ function:
+#
+# +--------+ parsing +------------------------+ binding +-----------------------+
+# | PyObjs | ---------> | PythonArgParser Output | ---------> | Cpp Function Dispatch |
+# +--------+ +------------------------+ +-----------------------+
+#
+# The following examples demonstrate the data models the Python binding
+# codegen needs to deal with and the tasks it needs to accomplish. It
+# helps understand the purpose of the new data types we introduced below.
+#
+# - Function Schema (source of truth)
+#
+# aten::empty.names(int[] size, *, Dimname[]? names,
+# ScalarType? dtype=None, Layout? layout=None,
+# Device? device=None, bool? pin_memory=None,
+# MemoryFormat? memory_format=None) -> Tensor
+#
+# - Python Signature
+#
+# It's used to generate input schema string for PythonArgParser.
+# Note: TensorOptions fields are reordered and the additional
+# 'requires_grad' field is added:
+#
+# empty(IntArrayRef size, *, DimnameList? names,
+# MemoryFormat? memory_format=None, ScalarType dtype=None,
+# Layout layout=torch.strided, Device device=None,
+# bool pin_memory=False, bool requires_grad=False)
+#
+# - C++ Signature
+#
+# It's used to generate C++ lambda formals & dispatch call.
+# Note: the scattered TensorOptions fields are packed into 'options'.
+#
+# auto dispatch_empty =
+# [](IntArrayRef size, c10::optional<DimnameList> names,
+# const TensorOptions & options,
+# c10::optional<MemoryFormat> memory_format) -> Tensor {
+# pybind11::gil_scoped_release no_gil;
+# return torch::empty(size, names, options, memory_format);
+# };
+#
+# - Binding between Python Arguments and C++ Arguments
+#
+# Given a set of Python Arguments in scope, we need produce the
+# binding expressions that translate the Python API into C++ API:
+#
+# Python Args Cpp Args Binding Exprs
+# -----------------------------------------------------------------
+# 0: size size '_r.intlist(0)'
+# 1: names names 'names' [special init]
+# 2: memory_format -------+
+# 3: dtype -----+-|--> options 'options' [special packing]
+# 4: layout / |
+# 5: device / +--> memory_format '_r.memoryformatOptional(2)'
+# 6: pin_memory /
+# 7: requires_grad -+
+#
+# So the full dispatch expression would look like:
+#
+# dispatch_empty(_r.intlist(0), names, options,
+# _r.memoryformatOptional(2))
+#
+# Where does 'names' come from? It involves special local init:
+#
+# auto __names = _r.toDimnameListOptional(1);
+# c10::optional<DimnameList> names =
+# __names ? c10::make_optional(DimnameList(__names.value()))
+# : c10::nullopt;
+#
+# Where does 'options' come from? It involves special local init
+# for TensorOptions. Note that Python side has the additional
+# 'requires_grad' field:
+#
+# const auto options = TensorOptions()
+# .dtype(_r.scalartype(3))
+# .device(_r.device(5))
+# .layout(_r.layoutOptional(4))
+# .requires_grad(_r.toBool(7))
+# .pinned_memory(_r.toBool(6));
+#
+# In some other cases one Python Argument can map to multiple C++
+# Arguments. For example:
+#
+# aten::max.names_dim(Tensor self, Dimname dim, bool keepdim=False)
+# -> (Tensor values, Tensor indices)
+#
+# Python Args Cpp Args Binding Exprs
+# ---------------------------------------------------------------------
+# +----> max 'out[0]'
+# /-----> max_values 'out[1]'
+# 0: input / self '_r.tensor(0)'
+# 1: dim / dim '_r.dimname(1)'
+# 2: keepdim / keepdim '_r.toBool(2)'
+# 3: out -----+ [local init] out '_r.tensorlist_n<2>(3)'
+#
+# As demonstrated above, the binding can involve reordering,
+# packing, unpacking and special local inits.
+#
+#
+# Let's look at a concrete example:
+#
+# static PythonArgParser parser({
+# "abs(Tensor input, *, Tensor out=None)",
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+# ^
+# +--- Python Schema, represented by PythonSignature and PythonArgument
+#
+# }, /*traceable=*/true);
+#
+# ParsedArgs<2> parsed_args;
+# auto _r = parser.parse(nullptr, args, kwargs, parsed_args);
+#
+# ...
+#
+# if (_r.isNone(1)) {
+# ~~~~~~~~~~~~ <--- Scattered PythonArgParser output (arg name = 'out')
+# represented by PythonArgParserOutputExpr
+#
+# // aten::abs(Tensor self) -> Tensor
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+# ^
+# +--- NativeFunction schema, base version
+#
+# auto dispatch_abs = [](const Tensor & self) -> Tensor {
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+# ^
+# +--- dispatch_lambda_args / dispatch_lambda_return_str
+# generated from NativeFunction / CppSignature
+# (deprecated PythonSignature is special)
+# arguments are represented by DispatchLambdaArgument
+#
+# pybind11::gil_scoped_release no_gil;
+# return self.abs();
+# ~~~~~~~~~~~ <--- cpp_dispatch_target / cpp_dispatch_exprs
+# generated from NativeFunction / CppSignature
+# };
+# return wrap(dispatch_abs(_r.tensor(0)));
+# ~~~~~~~~~~~~~
+# ^
+# +--- dispatch_lambda_exprs
+# binding PythonArgParserOutputExpr (python args)
+# and DispatchLambdaArgument (c++ args)
+#
+# } else {
+# // aten::abs.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+# ^
+# +--- NativeFunction schema, out-variant
+#
+# auto dispatch_abs_out = [](Tensor out, const Tensor & self) -> Tensor {
+# pybind11::gil_scoped_release no_gil;
+# return at::abs_out(out, self);
+# };
+# return wrap(dispatch_abs_out(_r.tensor(1), _r.tensor(0)));
+# }
+#
+#
+# [Notes] python interface codegen
+# The python dataclasses below are used to generate both python binding code
+# and pyi type hint signatures.
+# In theory these two should look very similar, but there are a number of differences
+# in how pyi signatures vs. python_arg_parser signatures are generated.
+# These differences have been encapsulated in signature_str() vs. signature_str_pyi()
+# to display the full signatures, and argument_str() vs argument_str_pyi() to display arguments.
+# For example, only pyi signatures include return types.
+
+
+@dataclass(frozen=True)
+class PythonReturns:
+ returns: Tuple[Return, ...]
+
+
+@dataclass(frozen=True)
+class PythonArgument:
+ name: str
+ type: Type
+ default: Optional[str]
+
+ # Used to generate the default init expr for some PythonArgParser outputs, e.g.:
+ #
+ # _r.layoutWithDefault(3, layout_from_backend(self.options().backend())))
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+ # ^
+ # +--- default_init str
+ default_init: Optional[str]
+
+ # Compute argument formal for python argument parsing.
+ # Needs to be consistent with torch/csrc/utils/python_arg_parser.h.
+ def argument_str(self, *, method: bool = False, symint: bool = True) -> str:
+ type_str = (
+ argument_type_str(self.type, symint=symint)
+ .replace("const ", "")
+ .replace(" &", "")
+ )
+
+ name = self.name
+ # s/self/input/ outside method bindings
+ # [old codegen] TODO: remove this? doesn't rename in codegen, it's just
+ # for the parse string
+ if name == "self" and type_str in ["Tensor", "Number"] and not method:
+ name = "input"
+
+ # add default
+ if self.default is not None:
+ default = {
+ "nullptr": "None",
+ "c10::nullopt": "None",
+ "{}": "None",
+ }.get(self.default, self.default)
+ return f"{type_str} {name}={default}"
+ else:
+ return f"{type_str} {name}"
+
+ def argument_str_pyi(
+ self, *, method: bool = False, deprecated: bool = False
+ ) -> str:
+ type_str = argument_type_str_pyi(self.type)
+
+ name = self.name
+ # s/self/input/ outside method bindings
+ # [old codegen] TODO: remove this? doesn't rename in codegen, it's just
+ # for the parse string
+ if name == "self" and type_str == "Tensor" and not method and not deprecated:
+ name = "input"
+
+ if name == "from": # from is a Python keyword...
+ name += "_"
+
+ # pyi merges the _out and functional variants into the same signature, with an optional out arg
+ if name == "out" and type_str == "Tensor" and not deprecated:
+ type_str = "Optional[" + type_str + "]"
+
+ # pyi deprecated signatures don't get defaults for their out arg
+ treat_as_no_default = (
+ deprecated
+ and isinstance(self, PythonOutArgument)
+ and self.default == "None"
+ )
+
+ # add default
+ if self.default is not None and not treat_as_no_default:
+ if (
+ isinstance(self.type, ListType)
+ and self.type.elem == BaseType(BaseTy.int)
+ and self.default.startswith("{")
+ and self.default.endswith("}")
+ ):
+ default = "(" + self.default[1:-1] + ")"
+ else:
+ default = {
+ "nullptr": "None",
+ "c10::nullopt": "None",
+ "{}": "None",
+ "MemoryFormat::Contiguous": "contiguous_format",
+ "QScheme::PER_TENSOR_AFFINE": "per_tensor_affine",
+ }.get(self.default, self.default)
+ return f"{name}: {type_str} = {default}"
+ else:
+ return f"{name}: {type_str}"
+
+
+@dataclass(frozen=True)
+class PythonOutArgument(PythonArgument):
+ # In the Python signature, multiple output fields are packed into one 'out' argument.
+ # When binding to C++, it's first bound to a local 'out' variable:
+ # 'auto out = _r.tensorlist_n<2>(2);',
+ # then bound to scattered C++ output arguments as 'out[0]', 'out[1]', etc.
+ # TODO: maybe we don't need to keep scattered out fields for the python signature?
+ outputs: Tuple[PythonArgument, ...]
+
+ @staticmethod
+ def from_outputs(
+ outputs: Tuple[PythonArgument, ...]
+ ) -> Optional["PythonOutArgument"]:
+ if not outputs:
+ return None
+
+ size = len(outputs)
+ if size == 1:
+ return PythonOutArgument(
+ name=outputs[0].name,
+ type=outputs[0].type,
+ default="None",
+ default_init=None,
+ outputs=outputs,
+ )
+ elif size > 1:
+ if any(not a.type.is_tensor_like() for a in outputs):
+ raise RuntimeError(f"Unsupported output type: {outputs}")
+ return PythonOutArgument(
+ name="out",
+ # TODO: shouldn't this be OptionalType[ListType[...]], since it defaults to None?
+ type=ListType(BaseType(BaseTy.Tensor), size),
+ default="None",
+ default_init=None,
+ outputs=outputs,
+ )
+ raise AssertionError(r"Unexpected PythonOutArgument size")
+
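+# Illustrative sketch (not part of the original module): for an op with two out
+# tensors, such as the max/max_values pair in the binding diagram near the top of
+# this file, from_outputs packs both into a single python-level argument named
+# "out" of type ListType(Tensor, 2) with default None; a single out tensor keeps
+# its own name.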
+
+@dataclass(frozen=True)
+class PythonSignature:
+ # Base operator name, without inplace/outplace suffix.
+ name: str
+
+ # Positional arguments.
+ # TODO: create a dedicated SelfArgument type for 'self'?
+ input_args: Tuple[PythonArgument, ...]
+
+ # Keyword arguments excluding the 'out' argument and scattered kwargs belonging
+ # to TensorOptions (dtype, layout, device, pin_memory, requires_grad, etc).
+ input_kwargs: Tuple[PythonArgument, ...]
+
+ output_args: Optional[PythonOutArgument]
+
+ # Return types, which are only used by pyi
+ returns: PythonReturns
+
+ # These are scattered kwargs arguments belonging to TensorOptions.
+ # When binding to C++, they are packed into a TensorOptions object 'options'.
+ # It's possible that the C++ signature doesn't take TensorOptions object (e.g.
+ # for out variant), in which case they will be used as scattered fields without
+ # being packed into 'options'.
+ # TODO: maybe create a PythonTensorOptionsArgument?
+ tensor_options_args: Tuple[PythonArgument, ...]
+
+ # method or function signature?
+ method: bool
+
+ @property
+ def deprecated(self) -> bool:
+ return False
+
+ def arguments(
+ self, *, skip_outputs: bool = False, skip_tensor_options: bool = False
+ ) -> Tuple[Union[PythonArgument, PythonOutArgument], ...]:
+ result: List[Union[PythonArgument, PythonOutArgument]] = []
+ result.extend(self.input_args)
+ result.extend(self.input_kwargs)
+ if self.output_args is not None and not skip_outputs:
+ result.append(self.output_args)
+ if not skip_tensor_options:
+ result.extend(self.tensor_options_args)
+ return tuple(result)
+
+ def arguments_count(self) -> int:
+ return len(self.arguments())
+
+ def output_idx(self) -> int:
+ return len(self.input_args) + len(self.input_kwargs)
+
+ # [old codegen] Compute the Python function signature for argument parsing,
+ # as specified in torch/csrc/utils/python_arg_parser.h. WARNING:
+ # this is NOT the same type signature as specified by PEP 484
+ # as understood by mypy; our format was independently developed
+ # and has some quirks to make it more suitable specifically
+ # for error parsing.
+ #
+ # For a translation to mypy-valid type signatures, see
+ # signature_str_pyi().
+ def signature_str(self, *, skip_outputs: bool = False, symint: bool = True) -> str:
+ args = self.arguments(skip_outputs=skip_outputs)
+ schema_formals: List[str] = [
+ a.argument_str(method=self.method, symint=symint) for a in args
+ ]
+ positional_argc = len(self.input_args)
+ if len(schema_formals) > positional_argc:
+ schema_formals.insert(positional_argc, "*")
+
+ return f'{self.name}({", ".join(schema_formals)})'
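+
+ # Example for signature_str() (editor's illustration, not emitted by this file): for a
+ # hypothetical function-variant schema foo(Tensor self, Scalar alpha=1, *, Tensor out=None),
+ # signature_str() would produce a parser string roughly like
+ # 'foo(Tensor input, Scalar alpha=1, *, Tensor out=None)'
+ # ('self' renamed to 'input', '*' inserted before the keyword-only arguments).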
+
+ def signature_str_pyi(self, *, skip_outputs: bool = False) -> str:
+ args = self.arguments(skip_outputs=skip_outputs)
+ schema_formals: List[str] = [
+ a.argument_str_pyi(method=self.method) for a in args
+ ]
+ positional_argc = len(self.input_args)
+ if len(schema_formals) > positional_argc:
+ schema_formals.insert(positional_argc, "*")
+
+ # only pyi signatures include returns
+ returns_str = returns_str_pyi(self)
+ # pyi also includes self (with no typing/defaults) for methods
+ if self.method:
+ schema_formals.insert(0, "self")
+ return f'def {self.name}({", ".join(schema_formals)}) -> {returns_str}: ...'
+
+ def signature_str_pyi_vararg(self, *, skip_outputs: bool = False) -> Optional[str]:
+ # only pyi uses vararg signatures
+ args = self.arguments(skip_outputs=skip_outputs)
+ schema_formals: List[str] = [
+ a.argument_str_pyi(method=self.method) for a in args
+ ]
+ # vararg only applies to pyi signatures. vararg variants are not generated for all signatures
+ num_args = self.arguments_count()
+ num_positionalargs = len(self.input_args)
+
+ have_vararg_version = False
+ if num_args > 0:
+ vararg_type = args[0].type
+ if (
+ isinstance(vararg_type, ListType)
+ and str(vararg_type.elem) in ["int", "SymInt"]
+ and num_positionalargs == 1
+ ):
+ have_vararg_version = True
+
+ if not have_vararg_version:
+ return None
+ # Below are the major changes in vararg vs. regular pyi signatures
+ # vararg signatures also omit the asterisk
+ schema_formals[0] = "*" + args[0].name + ": _int"
+
+ returns_str = returns_str_pyi(self)
+ # pyi also includes self (with no typing/defaults) for methods
+ if self.method:
+ schema_formals.insert(0, "self")
+ return f'def {self.name}({", ".join(schema_formals)}) -> {returns_str}: ...'
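+
+ # Example (editor's illustration): for a hypothetical method whose only positional
+ # argument is an int list, e.g. view_like(int[] size), the two emitters above would
+ # produce overloads along the lines of:
+ # def view_like(self, size: _size) -> Tensor: ...
+ # def view_like(self, *size: _int) -> Tensor: ...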
+
+
+# The deprecated python signature involves some special logic, so create a
+# dedicated data model to store these extra properties.
+@dataclass(frozen=True)
+class PythonSignatureDeprecated(PythonSignature):
+ # Schema for the deprecated function
+ deprecated_schema: FunctionSchema
+
+ # The deprecated signature might be missing some arguments that the corresponding
+ # C++ signature expects. We need to store the constant default values to pass in.
+ # For example:
+ # [deprecated signature]: addmm(Scalar beta, Tensor self, Tensor mat1, Tensor mat2)
+ # [func schema]: aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor
+ # [func call]: self.addmm(mat1, mat2, beta, 1)
+ # We store ['self', 'mat1', 'mat2', 'beta', '1'] in this case.
+ deprecated_args_exprs: Tuple[str, ...]
+
+ @property
+ def deprecated(self) -> bool:
+ return True
+
+ def signature_str(self, *, skip_outputs: bool = False, symint: bool = True) -> str:
+ return (
+ PythonSignature.signature_str(
+ self, skip_outputs=skip_outputs, symint=symint
+ )
+ + "|deprecated"
+ )
+
+ def signature_str_pyi(self, *, skip_outputs: bool = False) -> str:
+ args = self.arguments(skip_outputs=skip_outputs)
+ schema_formals: List[str] = [
+ a.argument_str_pyi(method=self.method, deprecated=True) for a in args
+ ]
+ positional_argc = len(self.input_args)
+ if len(schema_formals) > positional_argc:
+ schema_formals.insert(positional_argc, "*")
+
+ returns_str = returns_str_pyi(self)
+ return f'def {self.name}({", ".join(schema_formals)}) -> {returns_str}: ...'
+
+ def signature_str_pyi_vararg(self, *, skip_outputs: bool = False) -> Optional[str]:
+ # the codegen doesn't include vararg variants for deprecated signatures
+ return None
+
+
+# This struct is used to hold the PythonSignature and its corresponding
+# NativeFunction BEFORE grouping base and out-variant functions.
+# Why not store NativeFunction in PythonSignature or construct PythonSignature
+# from NativeFunction? Because they are not 1-1 mapped.
+# One native function could have both deprecated and non-deprecated python
+# signatures - NativeFunction doesn't contain information to construct the
+# deprecated python signature.
+# One python signature is used to handle both the base and the out-variant
+# function - see 'PythonSignatureGroup'.
+@dataclass(frozen=True)
+class PythonSignatureNativeFunctionPair:
+ signature: PythonSignature
+ function: NativeFunction
+
+
+# We merge pairs of functions with signatures that are equivalent mod
+# output arguments, and use a single entry in the python_arg_parser sig
+# list for both (output arguments become optional).
+@dataclass(frozen=True)
+class PythonSignatureGroup:
+ # The signature used for Python argument parsing. The outplace signature
+ # is preferred if it exists, because it can be used to parse inputs for both
+ # the out-place variant and the base version (with output omitted).
+ signature: PythonSignature
+
+ # The regular ATen declaration (e.g. conv2d)
+ base: NativeFunction
+
+ # The out variant (e.g. conv2d_out)
+ outplace: Optional[NativeFunction]
+
+ @classmethod
+ def from_pairs(
+ cls,
+ functional: PythonSignatureNativeFunctionPair,
+ out: Optional[PythonSignatureNativeFunctionPair],
+ ) -> "PythonSignatureGroup":
+ if out is None:
+ return PythonSignatureGroup(
+ signature=functional.signature,
+ base=functional.function,
+ outplace=None,
+ )
+
+ # prefer the signature with optional out=... arguments because it's the
+ # superset that can be used to parse input for both base and outplace.
+ signature_kwargs = out.signature.__dict__.copy()
+
+ # Out overloads in C++ don't have TensorOptions arguments,
+ # so take these from the functional variant
+ signature_kwargs[
+ "tensor_options_args"
+ ] = functional.signature.tensor_options_args
+
+ return PythonSignatureGroup(
+ signature=type(out.signature)(**signature_kwargs),
+ base=functional.function,
+ outplace=out.function,
+ )
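+
+ # Example (editor's illustration): the base declaration aten::add.Tensor and its out
+ # variant aten::add.out would form one PythonSignatureGroup; its 'signature' comes
+ # from the out variant (so it carries the optional 'out' argument), while the
+ # tensor_options_args are copied from the functional variant as shown above.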
+
+
+# C++ function dispatch is wrapped in a lambda function. The lambda function
+# has almost the same signature as the C++ function, only with some small
+# variants - see details below.
+# This data model is used to represent arguments of the lambda function
+# signature.
+@dataclass(frozen=True)
+class DispatchLambdaArgument:
+ name: str
+ type_str: str
+ is_out_arg: bool
+
+
+ # To pass PyObject arguments to the C++ function (via the lambda wrapper),
+ # we first need to convert the PyObjects into simple C++ objects. This work
+# is done by PythonArgParser.
+# This data model is used to represent the output of PythonArgParser.
+# It has 1-1 mapping with PythonArgument in PythonSignature.
+@dataclass(frozen=True)
+class PythonArgParserOutputExpr:
+ # argument name
+ name: str
+
+ # RHS expression to reference PythonArgParser output.
+ expr: str
+
+ # In some special cases we need to create a different expr, e.g.:
+ # '_r.isNone(1)' instead of '_r.tensor(1)'.
+ index: int
+
+ # The python argument it maps to.
+ argument: PythonArgument
+
+ @property
+ def is_none_expr(self) -> str:
+ return f"_r.isNone({self.index})"
+
+
+ # To pass PythonArgParser output to the lambda wrapper, we need to bind
+ # PythonArgParserOutputExpr to DispatchLambdaArgument.
+ # They are not always 1-1 mapped, e.g. scattered TensorOptions fields
+ # need to be packed into a TensorOptions object, which is the argument
+# that the lambda function wrapper takes.
+@dataclass(frozen=True)
+class DispatchLambdaArgumentExprs:
+ # The exprs that provide the binding for lambda arguments, e.g.:
+ #
+ # 'self' -> '_r.tensor(0)'
+ # 'min' -> 'out[0]' / 'min_indices' -> 'out[1]'
+ # 'options' -> 'options'
+ #
+ # It has 1-1 mapping with DispatchLambdaArgument.
+ exprs: Sequence[str]
+
+ # Special local inits, which might introduce new variables that
+ # the 'exprs' above reference, e.g.:
+ #
+ # 'auto out = _r.tensorlist_n<2>(2);'
+ #
+ inits: Sequence[str]
+
+
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
+#
+# Helper Functions
+#
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
+
+
+def _cpp_signature(f: NativeFunction, *, method: bool = False) -> CppSignature:
+ return CppSignatureGroup.from_native_function(f, method=method).signature
+
+
+def has_tensor_options(f: NativeFunction) -> bool:
+ return f.func.arguments.tensor_options is not None
+
+
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
+#
+# Python Signature
+#
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
+
+
+# 'simple_type' was introduced by the old codegen, which is slightly
+# different from the python schema type, e.g.: doesn't have '?' suffix
+# for optional Tensor/TensorList; doesn't have '[size]' suffix for list type.
+def argument_type_str(
+ t: Type, *, simple_type: bool = False, symint: bool = True
+) -> str:
+ if isinstance(t, BaseType):
+ if t.name == BaseTy.Tensor:
+ return "Tensor"
+ elif t.name == BaseTy.int:
+ return "int64_t"
+ elif t.name == BaseTy.float:
+ return "double"
+ elif t.name == BaseTy.str:
+ return "c10::string_view"
+ elif t.name in [
+ BaseTy.bool,
+ BaseTy.QScheme,
+ BaseTy.Scalar,
+ BaseTy.ScalarType,
+ BaseTy.Generator,
+ BaseTy.Storage,
+ BaseTy.Layout,
+ BaseTy.Device,
+ BaseTy.DeviceIndex,
+ BaseTy.MemoryFormat,
+ BaseTy.Dimname,
+ BaseTy.Stream,
+ BaseTy.ConstQuantizerPtr,
+ BaseTy.SymInt,
+ ]:
+ # These python schema type names line up with their function schema names
+ return t.name.name
+
+ elif isinstance(t, OptionalType):
+ if str(t.elem) == "Tensor":
+ # Is it desired to keep '?' for simple_type with new style dispatcher?
+ return "Tensor?"
+ elem = argument_type_str(t.elem, simple_type=simple_type, symint=symint)
+ return f"{elem}?"
+ elif isinstance(t, ListType):
+ size = t.size if not simple_type else None
+ if str(t.elem) == "bool":
+ assert t.size is not None
+ return f"::std::array"
+ elif str(t.elem) == "int":
+ return f"IntArrayRef[{size}]" if size is not None else "IntArrayRef"
+ elif str(t.elem) == "SymInt":
+ if symint:
+ return (
+ f"SymIntArrayRef[{size}]" if size is not None else "SymIntArrayRef"
+ )
+ else:
+ return f"IntArrayRef[{size}]" if size is not None else "IntArrayRef"
+ elif str(t.elem) == "Tensor":
+ return f"TensorList[{size}]" if size is not None else "TensorList"
+ elif str(t.elem) == "Scalar":
+ return f"ScalarList[{size}]" if size is not None else "ScalarList"
+ elif str(t.elem) == "Tensor?":
+ if simple_type:
+ return "c10::List>"
+ else:
+ return "const c10::List> &"
+ elif str(t.elem) == "Dimname":
+ return f"DimnameList[{size}]" if size is not None else "DimnameList"
+ elem = argument_type_str(t.elem, simple_type=simple_type, symint=symint)
+ return f"ArrayRef<{elem}>"
+
+ raise RuntimeError(f"unrecognized type {repr(t)}")
+
+
+def argument_type_size(t: Type) -> Optional[int]:
+ l = t.is_list_like()
+ if l is not None and str(l.elem) != "bool":
+ return l.size
+ else:
+ return None
+
+
+def argument(a: Argument) -> PythonArgument:
+ return PythonArgument(
+ name=a.name,
+ type=a.type,
+ # TODO: directly translate a.default to python default
+ default=str(
+ pythonify_default(cpp.default_expr(a.default, a.type, symint=False))
+ )
+ if a.default is not None
+ else None,
+ default_init=None,
+ )
+
+
+# Generates a PythonSignature that can be used for either .pyi or PythonArgParser codegen
+def signature(
+ f: NativeFunction, *, method: bool = False, pyi: bool = False
+) -> PythonSignature:
+ return signature_from_schema(
+ f.func, category_override=f.category_override, method=method, pyi=pyi
+ )
+
+
+def signature_from_schema(
+ func: FunctionSchema,
+ *,
+ category_override: Optional[str],
+ method: bool = False,
+ pyi: bool = False,
+) -> PythonSignature:
+ args: List[Argument] = []
+ args.extend(func.arguments.pre_self_positional)
+ # Skip SelfArgument if this is method.
+ if not method and func.arguments.self_arg is not None:
+ args.append(func.arguments.self_arg.argument)
+ args.extend(func.arguments.post_self_positional)
+ args.extend(func.arguments.pre_tensor_options_kwarg_only)
+ # Skip TensorOptionsArguments. Python side TensorOptions
+ # arguments are created based on different rules - see below.
+ args.extend(func.arguments.post_tensor_options_kwarg_only)
+ args.extend(func.arguments.out)
+
+ input_arg_set = {a.name for a in func.arguments.flat_positional}
+ kwarg_only_set = {a.name for a in func.arguments.flat_kwarg_only}
+ out_arg_set = {a.name for a in func.arguments.out}
+
+ input_args = tuple(map(argument, filter(lambda a: a.name in input_arg_set, args)))
+ input_kwargs = tuple(
+ map(argument, filter(lambda a: a.name in kwarg_only_set, args))
+ )
+ outputs = tuple(map(argument, filter(lambda a: a.name in out_arg_set, args)))
+
+ # Reintroduce the scattered fields of TensorOptions for Python.
+ # Compared to the cpp counterpart, the python arguments have a new property
+ # (default_init) and a new argument 'requires_grad', which require some
+ # special handling.
+ # [old codegen] TODO: because these aren't guaranteed to be 100% faithful
+ # to the original versions in the yaml, this recreation is a potential
+ # source of drift between eager and JIT. Pull this logic out to a shared place.
+
+ has_tensor_input_arg = any(
+ a.type.is_tensor_like() for a in func.arguments.flat_non_out
+ )
+ if any(a.name == "requires_grad" for a in func.schema_order_arguments()):
+ raise ValueError(
+ "argument named requires_grad is reserved, should not explicitly add it in the schema"
+ )
+
+ # [old codegen] this probably won't work if one of the returns is not a tensor,
+ # but it will produce a compile-time error that is obvious.
+ has_tensor_return = any(r.type.is_tensor_like() for r in func.returns)
+
+ name: str = cpp.name(func)
+ is_factory_function = category_override == "factory" or (
+ has_tensor_return and not has_tensor_input_arg
+ )
+ is_like_or_new_function = (
+ category_override in ("new", "like")
+ or name.startswith("new_")
+ or name.endswith("_like")
+ )
+ is_dummy_function = category_override == "dummy"
+
+ tensor_options_args: List[PythonArgument] = []
+ if (is_factory_function or is_like_or_new_function) and not is_dummy_function:
+
+ def topt_default_init(name: str) -> Optional[str]:
+ topt_args = func.arguments.tensor_options
+ if topt_args is None:
+ return None
+ a = getattr(topt_args, name)
+ if a.default is None or a.default == "None":
+ return None
+ return cpp.default_expr(a.default, a.type, symint=False)
+
+ tensor_options_args.append(
+ PythonArgument(
+ name="dtype",
+ type=OptionalType(BaseType(BaseTy.ScalarType)),
+ default="None",
+ default_init=(
+ None if is_like_or_new_function else topt_default_init("dtype")
+ ),
+ )
+ )
+ tensor_options_args.append(
+ PythonArgument(
+ name="layout",
+ type=OptionalType(BaseType(BaseTy.Layout)),
+ default="None",
+ default_init=(
+ None if is_like_or_new_function else topt_default_init("layout")
+ ),
+ )
+ )
+ tensor_options_args.append(
+ PythonArgument(
+ name="device",
+ type=OptionalType(BaseType(BaseTy.Device)),
+ default="None",
+ default_init=(
+ None
+ if is_like_or_new_function
+ else (
+ topt_default_init("device")
+ or "torch::tensors::get_default_device()"
+ )
+ ),
+ )
+ )
+ tensor_options_args.append(
+ PythonArgument(
+ name="pin_memory",
+ type=OptionalType(BaseType(BaseTy.bool)),
+ default="False",
+ default_init=None,
+ )
+ )
+ tensor_options_args.append(
+ PythonArgument(
+ name="requires_grad",
+ type=OptionalType(BaseType(BaseTy.bool)),
+ default="False",
+ default_init=None,
+ )
+ )
+
+ returns = PythonReturns(returns=func.returns)
+
+ return PythonSignature(
+ name=str(func.name.name),
+ input_args=input_args,
+ input_kwargs=input_kwargs,
+ output_args=PythonOutArgument.from_outputs(outputs),
+ tensor_options_args=tuple(tensor_options_args),
+ returns=returns,
+ method=method,
+ )
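+
+ # Example (editor's illustration): for a factory-style schema (tensor return, no
+ # tensor inputs), the resulting PythonSignature additionally carries the scattered
+ # TensorOptions kwargs built above: dtype=None, layout=None, device=None,
+ # pin_memory=False and requires_grad=False.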
+
+
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
+#
+# Python Interface
+#
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
+
+
+def structseq_fieldnames(returns: Tuple[Return, ...]) -> List[str]:
+ if len(returns) <= 1 or all(r.name is None for r in returns):
+ return []
+ else:
+ if any(r.name is None for r in returns):
+ # When building on Windows, `PyStructSequence_UnnamedField` could not be
+ # resolved by the linker for some reason, which causes an error when building:
+ #
+ # python_nn_functions.cpp.obj : error LNK2001: unresolved external symbol
+ # PyStructSequence_UnnamedField
+ #
+ # Thus, at this point in time, we do not support unnamed
+ # fields in structseq; you must either name all fields,
+ # or none of them.
+ raise ValueError("Unnamed field is not supported by codegen")
+
+ return [str(r.name) for r in returns]
+
+
+def argument_type_str_pyi(t: Type) -> str:
+ add_optional = False
+ if isinstance(t, OptionalType):
+ t = t.elem
+ add_optional = True
+
+ if isinstance(t, BaseType):
+ if t.name in [BaseTy.int, BaseTy.DeviceIndex]:
+ ret = "_int"
+ if t.name == BaseTy.SymInt:
+ ret = "Union[_int, SymInt]"
+ elif t.name == BaseTy.float:
+ ret = "_float"
+ elif t.name == BaseTy.str:
+ ret = "str"
+ elif t.name == BaseTy.Scalar:
+ ret = "Union[Number, _complex]"
+ elif t.name == BaseTy.ScalarType:
+ ret = "_dtype"
+ elif t.name == BaseTy.bool:
+ ret = "_bool"
+ elif t.name == BaseTy.QScheme:
+ ret = "_qscheme"
+ elif t.name == BaseTy.Layout:
+ ret = "_layout"
+ elif t.name == BaseTy.Device:
+ ret = "Optional[DeviceLikeType]"
+ elif t.name == BaseTy.MemoryFormat:
+ ret = "memory_format"
+ elif t.name == BaseTy.Dimname:
+ ret = "Union[str, ellipsis, None]"
+ elif t.name == BaseTy.Storage:
+ ret = "Union[Storage, UntypedStorage]"
+ elif t.name in [BaseTy.Tensor, BaseTy.Generator, BaseTy.Stream]:
+ # These python schema type names line up with their function schema names
+ ret = t.name.name
+
+ elif isinstance(t, ListType):
+ if str(t.elem) == "int":
+ ret = "Union[_int, _size]" if t.size is not None else "_size"
+ elif t.is_tensor_like():
+ # TODO: this doesn't seem right...
+ # Tensor?[] currently translates to Optional[Union[Tuple[Tensor, ...], List[Tensor]]]
+ # It should probably translate to Union[Tuple[Optional[Tensor], ...], List[Optional[Tensor]]]
+ if isinstance(t.elem, OptionalType):
+ add_optional = True
+ ret = (
+ "Union[Tensor, Tuple[Tensor, ...], List[Tensor]]"
+ if t.size is not None
+ else "Union[Tuple[Tensor, ...], List[Tensor]]"
+ )
+ elif str(t.elem) == "float":
+ ret = "Sequence[_float]"
+ elif str(t.elem) == "SymInt" and t.size is not None:
+ elem = argument_type_str_pyi(t.elem)
+ ret = f"Union[{elem}, Sequence[{elem}]]"
+ else:
+ elem = argument_type_str_pyi(t.elem)
+ ret = f"Sequence[{elem}]"
+
+ else:
+ raise RuntimeError(f"unrecognized type {repr(t)}")
+
+ if add_optional:
+ ret = "Optional[" + ret + "]"
+
+ return ret
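+
+ # Example mappings (editor's illustration, derived from the branches above):
+ # 'Tensor?' -> 'Optional[Tensor]'
+ # 'int[]' -> '_size'
+ # 'Scalar' -> 'Union[Number, _complex]'
+ # 'float[]' -> 'Sequence[_float]'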
+
+
+def return_type_str_pyi(t: Type) -> str:
+ # Where arguments are open to accepting Union, return types should return
+ # concrete types
+
+ if isinstance(t, OptionalType):
+ inner = return_type_str_pyi(t.elem)
+ return f"Optional[{inner}]"
+
+ if isinstance(t, BaseType):
+ if t.name == BaseTy.Device:
+ return "_device"
+ elif t.name == BaseTy.Dimname:
+ ret = "Optional[str]"
+ else:
+ return argument_type_str_pyi(t)
+
+ if isinstance(t, ListType):
+ inner = return_type_str_pyi(t.elem)
+ return f"Tuple[{inner}, ...]"
+
+ return argument_type_str_pyi(t)
+
+
+def returns_structseq_pyi(signature: PythonSignature) -> Optional[Tuple[str, str]]:
+ python_returns = [return_type_str_pyi(r.type) for r in signature.returns.returns]
+ structseq_name = signature.name
+ field_names = structseq_fieldnames(signature.returns.returns)
+ if field_names:
+ # These types are structseq objects which act like named tuples, but
+ # the constructor acts like the constructor of tuple. Using typing.NamedTuple
+ # does not allow us to override __init__.
+ field_names_str = ", ".join(repr(name) for name in field_names)
+ seq_type = f"Tuple[{', '.join(python_returns)}]"
+ structseq_def_lines = [
+ f"class {structseq_name}({seq_type}):",
+ ]
+ for name, typ in zip(field_names, python_returns):
+ structseq_def_lines.extend(
+ [
+ " @property",
+ f" def {name}(self) -> {typ}: ...",
+ ]
+ )
+ structseq_def_lines.extend(
+ [
+ f" def __new__(cls, sequence: {seq_type}): ...",
+ f" n_fields: _int = {len(field_names)}",
+ f" n_sequeunce_fields: _int = {len(field_names)}",
+ " n_unnamed_fields: _int = 0",
+ " def __init_subclass__(cls) -> NoReturn: ... # prohibit subclassing",
+ "", # add an extra newline
+ ]
+ )
+ structseq_def = "\n".join(structseq_def_lines)
+ # Example:
+ # structseq_def = (
+ # "class max(Tuple[Tensor, Tensor]):\n"
+ # " @property\n"
+ # " def values(self) -> Tensor: ...\n"
+ # " @property\n"
+ # " def indices(self) -> Tensor: ...\n"
+ # " def __new__(cls, sequence: Tuple[Tensor, Tensor]): ...\n"
+ # " n_fields: _int = 2",
+ # " n_sequeunce_fields: _int = 2",
+ # " n_unnamed_fields: _int = 0",
+ # " def __init_subclass__(cls) -> NoReturn: ... # prohibit subclassing",
+ # )
+ return structseq_name, structseq_def
+ return None
+
+
+def returns_str_pyi(signature: PythonSignature) -> str:
+ field_names = structseq_fieldnames(signature.returns.returns)
+ if field_names:
+ return f"torch.return_types.{signature.name}"
+
+ python_returns = [return_type_str_pyi(r.type) for r in signature.returns.returns]
+ if len(python_returns) > 1:
+ return "Tuple[" + ", ".join(python_returns) + "]"
+ if len(python_returns) == 1:
+ return python_returns[0]
+ return "None"
+
+
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
+#
+# C++ Function Dispatch
+#
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
+# This section provides APIs to generate the code that does C++ function
+# dispatch. The C++ function call is wrapped by a lambda function.
+# For example:
+#
+# // aten::selu_(Tensor(a!) self) -> Tensor(a!)
+# auto dispatch_selu_ = [](Tensor self) -> Tensor {
+# pybind11::gil_scoped_release no_gil;
+# return at::selu_(self);
+# };
+#
+# The lambda function's signature follows the C++ signature in common
+# cases, e.g.:
+#
+# // aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
+# [](const Tensor & self, const Tensor & other, Scalar alpha) -> Tensor
+#
+ # For the out variant, the 'out' argument's type is changed from 'Tensor &'
+ # to 'Tensor'. This is because the lambda is called with the PythonArgParser
+ # output '_r.tensor(3)', which is a stack-allocated object and needs to be
+ # passed by value. Also see comments in 'dispatch_lambda_return_str()'.
+#
+# // aten::add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
+# [](Tensor out, const Tensor & self, const Tensor & other, Scalar alpha) -> Tensor
+#
+ # For the multi-output case it can keep using reference types because the
+ # PythonArgParser output has been unpacked into local variables, e.g.:
+#
+# // aten::max.names_dim_max(Tensor self, Dimname dim, bool keepdim=False, *,
+# // Tensor(a!) max, Tensor(b!) max_values) -> (Tensor(a!) values, Tensor(b!) indices)
+ # [](Tensor & max, Tensor & max_values, const Tensor & self, Dimname dim, bool keepdim) -> std::tuple<Tensor, Tensor>
+#
+ # For deprecated python signatures, the lambda should follow the deprecated python arg order.
+ # TODO: This is to keep the same byte-for-byte result as the old codegen - maybe unnecessary?
+
+
+def dispatch_lambda_args(
+ ps: PythonSignature, f: NativeFunction, symint: bool = True
+) -> Tuple[DispatchLambdaArgument, ...]:
+ if isinstance(ps, PythonSignatureDeprecated):
+ schema = ps.deprecated_schema
+ else:
+ schema = f.func
+
+ # Start with cpp arguments - the dispatch lambda signature always includes 'self'
+ cpp_args = cpp.arguments(
+ arguments=schema.arguments,
+ faithful=False,
+ symint=symint,
+ method=False,
+ cpp_no_default_args=f.cpp_no_default_args,
+ )
+ out_args: Set[str] = {a.name for a in schema.arguments.out}
+
+ # Convert from cpp argument to lambda argument
+ def dispatch_lambda_arg(cpp_arg: Binding) -> DispatchLambdaArgument:
+ type_str = cpp_arg.type
+ is_out_arg = cpp_arg.name in out_args
+ if ps.method and cpp_arg.name == "self":
+ # For method's 'self', we can use 'const Tensor &' and simply ignore mutability!
+ type_str = "const at::Tensor &"
+ else:
+ # For other cases we need to prevent dangling refs to temps (unless it's an
+ # unpacked scattered output).
+ # The reason is explained in the comments above and in 'dispatch_lambda_return_str()'.
+ # TODO: avoid this special handling?
+ ensure_temp_safe = len(out_args) <= 1 or not is_out_arg
+ if ensure_temp_safe:
+ type_str = {
+ "at::Tensor &": "at::Tensor",
+ }.get(type_str, type_str)
+ return DispatchLambdaArgument(
+ name=cpp_arg.name,
+ type_str=type_str,
+ is_out_arg=is_out_arg,
+ )
+
+ return tuple(map(dispatch_lambda_arg, cpp_args))
+
+
+# [old codegen] XXX: if you got here because of an assertion failure, it doesn't mean
+# it's enough to just extend the list here. Before you do this, make sure
+# to add an appropriate wrap() overload in torch/csrc/autograd/utils/wrap_outputs.h.
+SUPPORTED_RETURN_TYPES = {
+ "at::Tensor",
+ "::std::tuple",
+ "::std::tuple",
+ "::std::tuple",
+ "::std::tuple",
+ "::std::tuple",
+ "::std::tuple",
+ "::std::tuple",
+ "::std::tuple",
+ "::std::tuple",
+ "::std::tuple",
+ "::std::tuple>",
+ "::std::vector",
+ # Needed for flash attention forw/backward
+ "::std::tuple",
+ "at::Scalar",
+ "bool",
+ "int64_t",
+ "void*",
+ "void",
+ "at::QScheme",
+ "double",
+ "at::IntArrayRef",
+ "at::ScalarType",
+ "at::Stream",
+}
+
+
+def dispatch_lambda_return_str(f: NativeFunction) -> str:
+ # [old codegen] Remove type annotation (e.g. 'Tensor' rather than 'Tensor &')
+ # because the dispatch lambdas take mutable arguments *by value*, not
+ # by reference. If you then return a reference to such an argument, you
+ # will now have a pointer to a dangling stack entry. Not good.
+ #
+ # You want:
+ #
+ # auto dispatch_selu_ = [](Tensor self) -> Tensor { ...; return at::selu_(self); };
+ # ^^^^^^
+ #
+ # *not*
+ #
+ # auto dispatch_selu_ = [](Tensor self) -> Tensor& { ...; return at::selu_(self); };
+ # ^^^^^^^
+ #
+ # (NB: We can't make dispatch_selu_ take Tensor&, because the enclosing
+ # codegen looks like dispatch_selu_(_r.tensor(0)), and you can't take a
+ # mutable reference to temporary. Maybe we could assign it to a
+ # variable itself.)
+ returns_without_annotation = tuple(
+ Return(r.name, r.type, None) for r in f.func.returns
+ )
+ return_str = cpp.returns_type(returns_without_annotation, symint=True).cpp_type()
+ if return_str not in SUPPORTED_RETURN_TYPES:
+ raise RuntimeError(f"{f.func.name} returns unsupported type {return_str}")
+ return return_str
+
+
+def cpp_dispatch_target(f: NativeFunction) -> str:
+ symint = f.func.has_symint()
+ name = cpp.name(f.func, symint_overload=symint)
+ if Variant.method in f.variants:
+ return f"self.{name}"
+ if Variant.function in f.variants:
+ if has_tensor_options(f) or f.func.name.name.base.endswith("_like"):
+ namespace = "torch"
+ else:
+ namespace = "at"
+ return f"{namespace}::{name}"
+ raise RuntimeError(f"could not dispatch, neither function nor method: {f.func}")
+
+
+def cpp_dispatch_exprs(
+ f: NativeFunction,
+ *,
+ python_signature: Optional[PythonSignature] = None,
+) -> Tuple[str, ...]:
+ cpp_args: Sequence[Binding] = _cpp_signature(f, method=False).arguments()
+
+ exprs: Tuple[str, ...] = tuple()
+ if not isinstance(python_signature, PythonSignatureDeprecated):
+ # By default the exprs are consistent with the C++ signature.
+ exprs = tuple(a.name for a in cpp_args)
+ else:
+ # For deprecated python signatures we may need to fill in some constants.
+ exprs = tuple(
+ filter(
+ lambda n: n != "out" or f.func.is_out_fn(),
+ python_signature.deprecated_args_exprs,
+ )
+ )
+
+ if Variant.method in f.variants:
+ exprs = tuple(filter("self".__ne__, exprs))
+
+ return exprs
+
+
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
+#
+# Python / C++ Args Binding
+#
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
+
+
+# We explicitly enumerate the PythonArgParser unpacking methods for all
+# supported types. This might be more verbose than necessary, partially
+# because of the irregularity of unpacking method naming, partially
+# because we want to mimic the old codegen behavior - to reject
+# unexpected and/or unsupported cases which the old codegen rejects.
+# For certain cases it is intentionally more restrictive than necessary,
+ # e.g.: it doesn't accept a doublelist with a definite size.
+def arg_parser_unpack_method(
+ t: Type, default: Optional[str], default_init: Optional[str], *, symint: bool = True
+) -> str:
+ has_default_init = default_init is not None
+ if has_default_init and str(t) not in (
+ "ScalarType?",
+ "ScalarType",
+ "Device",
+ "Device?",
+ "Layout",
+ "Layout?",
+ "bool",
+ "bool?",
+ ):
+ raise RuntimeError(f"type '{t}' does not supported unpacking with default")
+
+ if isinstance(t, BaseType):
+ if t.name in [
+ BaseTy.Tensor,
+ BaseTy.Stream,
+ BaseTy.Storage,
+ BaseTy.Scalar,
+ BaseTy.Dimname,
+ ]:
+ # These unpack methods line up with their schema names
+ return t.name.name.lower()
+ elif t.name == BaseTy.ScalarType:
+ return "scalartypeWithDefault" if has_default_init else "scalartype"
+ elif t.name == BaseTy.Device:
+ return "deviceWithDefault" if has_default_init else "device"
+ elif t.name == BaseTy.DeviceIndex:
+ return "toInt64"
+ elif t.name == BaseTy.int:
+ return "toInt64"
+ elif t.name == BaseTy.SymInt:
+ return "toSymInt" if symint else "toInt64"
+ elif t.name == BaseTy.bool:
+ return "toBoolWithDefault" if has_default_init else "toBool"
+ elif t.name == BaseTy.float:
+ return "toDouble"
+ elif t.name == BaseTy.str:
+ return "stringView"
+ elif t.name == BaseTy.Layout:
+ return "layoutWithDefault" if has_default_init else "layout"
+ elif t.name == BaseTy.MemoryFormat:
+ return "memoryformat"
+
+ elif isinstance(t, OptionalType):
+ if str(t.elem) == "Tensor":
+ return "optionalTensor"
+ elif str(t.elem) == "Generator":
+ return "generator"
+ elif str(t.elem) == "Dimname[]":
+ return "toDimnameListOptional"
+ elif not has_default_init and default in (None, "None", "c10::nullopt"):
+ # If default is None: append 'Optional' to elem's unpacking method
+ return (
+ arg_parser_unpack_method(t.elem, None, None, symint=symint) + "Optional"
+ )
+ else:
+ # Otherwise, load as underlying type with default
+ return arg_parser_unpack_method(
+ t.elem, default, default_init, symint=symint
+ )
+
+ elif isinstance(t, ListType):
+ if str(t.elem) == "Tensor":
+ # accept and use definite size
+ return f"tensorlist_n<{t.size}>" if t.size is not None else "tensorlist"
+ elif str(t.elem) == "Tensor?":
+ return "list_of_optional_tensors"
+ elif str(t.elem) == "Dimname":
+ # accept definite size
+ return "dimnamelist"
+ elif str(t.elem) == "int":
+ # accept definite size
+ return "intlist"
+ elif str(t.elem) == "float":
+ return "doublelist"
+ elif str(t.elem) == "SymInt":
+ # accept definite size
+ return "symintlist" if symint else "intlist"
+ elif str(t.elem) == "Scalar":
+ return "scalarlist"
+ raise RuntimeError(f"type '{t}' is not supported by PythonArgParser")
+
+
+# Return RHS expression for python argument using PythonArgParser output.
+# e.g. for arg name 'foo', arg type 'bool', arg_index = 2, returns '_r.toBool(2)'
+def arg_parser_output_expr(
+ arg_index: int, a: PythonArgument, *, symint: bool = True
+) -> PythonArgParserOutputExpr:
+ has_default = a.default_init is not None
+ unpack_method = arg_parser_unpack_method(
+ t=a.type, default=a.default, default_init=a.default_init, symint=symint
+ )
+ default = f", {a.default_init}" if has_default else ""
+ expr = f"_r.{unpack_method}({arg_index}{default})"
+
+ return PythonArgParserOutputExpr(
+ name=a.name,
+ expr=expr,
+ index=arg_index,
+ argument=a,
+ )
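+
+ # Example (editor's illustration): for a hypothetical 'bool?' argument at index 3
+ # whose default_init is 'true', the generated expr is '_r.toBoolWithDefault(3, true)';
+ # for a 'bool?' argument with default 'None' and no default_init it is
+ # '_r.toBoolOptional(3)'.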
+
+
+# Returns a map with key = arg_name and value = PythonArgParserOutputExpr.
+def arg_parser_output_exprs(
+ ps: PythonSignature, f: NativeFunction, *, symint: bool = True
+) -> Dict[str, PythonArgParserOutputExpr]:
+ return {
+ e.name: e
+ for i, a in enumerate(ps.arguments())
+ for e in (arg_parser_output_expr(i, a, symint=symint),)
+ }
+
+
+# argument name to type for scattered tensor options fields
+TENSOR_OPTIONS_FIELDS = {
+ "dtype": "ScalarType?",
+ "device": "Device?",
+ "layout": "Layout?",
+ "pin_memory": "bool?",
+ "requires_grad": "bool?",
+}
+
+
+# bind arg parser outputs (python args) with dispatch lambda arguments (c++ args).
+def dispatch_lambda_exprs(
+ ps: PythonSignature, f: NativeFunction, *, symint: bool = True
+) -> DispatchLambdaArgumentExprs:
+ # This method is to bind 'arg_parser_outputs' and 'lambda_args' by producing
+ # 'inits' and 'lambda_args_exprs' for each lambda argument using arg parser
+ # outputs.
+ arg_parser_outputs = arg_parser_output_exprs(ps, f, symint=symint)
+ lambda_args = dispatch_lambda_args(ps, f, symint=symint)
+ inits: List[str] = []
+ lambda_args_exprs: Dict[str, str] = {}
+
+ has_toptions = has_tensor_options(f)
+
+ # 1. special inits/unpacking to provide binding exprs for lambda arguments.
+ for a in ps.arguments(skip_tensor_options=True):
+ name = a.name
+ arg_parser_expr = arg_parser_outputs[a.name].expr
+
+ if has_toptions and name == "self":
+ # TODO: why does this need to be a special case?
+ inits.extend(
+ [
+ f"auto self = {arg_parser_expr};",
+ ]
+ )
+ lambda_args_exprs[name] = name
+ elif (
+ isinstance(a, PythonOutArgument)
+ and len(a.outputs) > 1
+ and f.func.is_out_fn()
+ ):
+ inits.extend(
+ [
+ f"auto out = {arg_parser_expr};",
+ ]
+ )
+ for i, out_arg in enumerate(a.outputs):
+ lambda_args_exprs[out_arg.name] = f"out[{i}]"
+ elif str(a.type) == "Dimname[]?":
+ # [old codegen]
+ # TODO: make this part of something more general, or get rid of it.
+ # optional<ArrayRef<T>> are special. The PythonArgParser returns an
+ # optional<vector<T>>, which cannot be implicitly converted to
+ # optional<ArrayRef<T>>. One needs to unwrap the optional and rewrap.
+ inits.extend(
+ [
+ f"auto __{name} = {arg_parser_expr};",
+ f"c10::optional {name} = __{name} ? c10::make_optional(DimnameList(__{name}.value())) : c10::nullopt;", # noqa: B950
+ ]
+ )
+ lambda_args_exprs[name] = name
+ else:
+ # default case - directly using PythonArgParser output expr
+ lambda_args_exprs[name] = arg_parser_expr
+
+ # method's self is passed directly to python binding, rather than parsed
+ if ps.method:
+ lambda_args_exprs["self"] = "self"
+
+ # 2. special packing/checking for TensorOptions.
+ tensor_options_args_names = [a.name for a in ps.tensor_options_args]
+ if has_toptions:
+ if f.func.is_out_fn():
+ raise RuntimeError(f"{f.func}: tensor options with output arg")
+ for a in ps.tensor_options_args:
+ if a.name not in TENSOR_OPTIONS_FIELDS:
+ raise RuntimeError(
+ f"{f.func}: unrecognized tensor options field '{a.name}' in python binding arguments"
+ )
+ if str(a.type) != TENSOR_OPTIONS_FIELDS.get(a.name):
+ raise RuntimeError(
+ f"{f.func}: unrecognized type '{str(a.type)}' for tensor options field '{a.name}'"
+ )
+ if not all(
+ a in tensor_options_args_names for a in TENSOR_OPTIONS_FIELDS.keys()
+ ):
+ raise RuntimeError(
+ f"{f.func}: incomplete tensor options args: {tensor_options_args_names}"
+ )
+
+ inits.append(
+ f"""\
+const auto options = TensorOptions()
+ .dtype({arg_parser_outputs['dtype'].expr})
+ .device({arg_parser_outputs['device'].expr})
+ .layout({arg_parser_outputs['layout'].expr})
+ .requires_grad({arg_parser_outputs['requires_grad'].expr})
+ .pinned_memory({arg_parser_outputs['pin_memory'].expr});
+torch::utils::maybe_initialize_device(options);
+"""
+ )
+ lambda_args_exprs["options"] = "options"
+
+ # 3. special case - access scattered TensorOptions fields without packing
+ # TODO: maybe move to the generator side as it's not related to binding.
+ if not has_toptions and tensor_options_args_names:
+ if "dtype" in tensor_options_args_names:
+ # we're an output-arg variant, check these args against output tensor
+ if not f.func.is_out_fn():
+ raise RuntimeError(
+ f"{f.func}: dtype in tensor_options_args without output arg"
+ )
+ if not all(a in tensor_options_args_names for a in ("layout", "device")):
+ raise RuntimeError(
+ f"{f.func}: incomplete tensor options for output check"
+ )
+
+ inits.append(
+ f"""\
+check_out_type_matches({arg_parser_outputs['out'].expr}, {arg_parser_outputs['dtype'].expr},
+ {arg_parser_outputs['dtype'].is_none_expr}, {arg_parser_outputs['layout'].expr},
+ {arg_parser_outputs['device'].expr}, {arg_parser_outputs['device'].is_none_expr});
+"""
+ )
+ # we'll set requires_grad on outgoing tensor
+ if "requires_grad" not in tensor_options_args_names:
+ raise RuntimeError(
+ f'{f.func}: expected "requires_grad" in tensor_options_args absent, but found [{tensor_options_args_names}]'
+ )
+
+ return DispatchLambdaArgumentExprs(
+ exprs=tuple(lambda_args_exprs[a.name] for a in lambda_args),
+ inits=inits,
+ )
diff --git a/MLPY/Lib/site-packages/torchgen/api/structured.py b/MLPY/Lib/site-packages/torchgen/api/structured.py
new file mode 100644
index 0000000000000000000000000000000000000000..10a83b65d9dcbe9211c4f5cbd2b16f4f3f1506ba
--- /dev/null
+++ b/MLPY/Lib/site-packages/torchgen/api/structured.py
@@ -0,0 +1,157 @@
+from typing import List, Union
+
+from torchgen.api import cpp
+
+from torchgen.api.types import (
+ ArgName,
+ ArrayRefCType,
+ BaseCType,
+ Binding,
+ ConstRefCType,
+ dimnameListT,
+ intArrayRefT,
+ iOptTensorListRefT,
+ iTensorListRefT,
+ NamedCType,
+ OptionalCType,
+ optionalIntArrayRefT,
+ optionalScalarRefT,
+ optionalTensorRefT,
+ scalarT,
+ tensorT,
+)
+from torchgen.model import (
+ Argument,
+ BaseTy,
+ BaseType,
+ ListType,
+ NativeFunctionsGroup,
+ OptionalType,
+ SelfArgument,
+ TensorOptionsArguments,
+ Type,
+)
+from torchgen.utils import assert_never
+
+# This file describes the translation of JIT schema to the structured functions API.
+ # This is similar to the native API, but a number of historical problems with the
+ # native API have been fixed.
+
+
+# Translation of types occurring in JIT arguments to a C++ argument type.
+# NB: For now, mutable doesn't do anything; but it could if we make
+# some more nominal types
+def argumenttype_type(t: Type, *, mutable: bool, binds: ArgName) -> NamedCType:
+ # If it's a value type, do the value type translation
+ # NB: structured kernels ALWAYS have symint off, since they involve actual
+ # kernels that require real ints. The one exception is the
+ # CompositeExplicitAutograd and the meta function (which could
+ # hypothetically be SymInt), but for simplicity we plan for these to just
+ # be handled in Python
+ r = cpp.valuetype_type(t, symint=False, binds=binds)
+ if r is not None:
+ return r
+
+ if isinstance(t, BaseType):
+ if t.name == BaseTy.Tensor:
+ return NamedCType(binds, ConstRefCType(BaseCType(tensorT)))
+ elif t.name == BaseTy.Scalar:
+ return NamedCType(binds, ConstRefCType(BaseCType(scalarT)))
+ else:
+ raise AssertionError(f"base type should have been value type {t}")
+ elif isinstance(t, OptionalType):
+ if t.elem == BaseType(BaseTy.Tensor):
+ return NamedCType(binds, BaseCType(optionalTensorRefT))
+ elif t.elem == BaseType(BaseTy.Scalar):
+ return NamedCType(binds, BaseCType(optionalScalarRefT))
+ elif isinstance(t.elem, ListType) and str(t.elem.elem) == "int":
+ return NamedCType(binds, BaseCType(optionalIntArrayRefT))
+ elem = argumenttype_type(t.elem, mutable=mutable, binds=binds)
+ return NamedCType(binds, OptionalCType(elem.type))
+ elif isinstance(t, ListType):
+ if t.elem == BaseType(BaseTy.Tensor):
+ return NamedCType(binds, ConstRefCType(BaseCType(iTensorListRefT)))
+ elif t.elem == OptionalType(BaseType(BaseTy.Tensor)):
+ return NamedCType(binds, BaseCType(iOptTensorListRefT))
+ # TODO: delete these special cases; see torchgen.api.cpp--these
+ # must be changed in tandem, but there are problems; see
+ # https://github.com/pytorch/pytorch/pull/51485
+ elif str(t.elem) == "int":
+ return NamedCType(binds, BaseCType(intArrayRefT))
+ elif str(t.elem) == "Dimname":
+ return NamedCType(binds, BaseCType(dimnameListT))
+ elem = argumenttype_type(t.elem, mutable=mutable, binds=binds)
+ return NamedCType(binds, ArrayRefCType(elem.type))
+ else:
+ raise AssertionError(f"unrecognized type {repr(t)}")
+
+
+def argument_type(a: Argument, *, binds: ArgName) -> NamedCType:
+ return argumenttype_type(a.type, mutable=a.is_write, binds=binds)
+
+
+# returns_type intentionally omitted, because structured kernels never "return";
+# instead, they always indirectly report their outputs (in the case of a meta
+# function, by calling set_output; in the case of an impl function, by writing
+# directly into the provided out argument).
+
+
+# Structured kernels are never defaulted
+def argument(a: Union[Argument, SelfArgument, TensorOptionsArguments]) -> List[Binding]:
+ if isinstance(a, Argument):
+ return [
+ Binding(
+ nctype=argument_type(a, binds=a.name),
+ name=a.name,
+ default=None,
+ argument=a,
+ )
+ ]
+ elif isinstance(a, SelfArgument):
+ return argument(a.argument)
+ elif isinstance(a, TensorOptionsArguments):
+ raise AssertionError("structured kernels don't support TensorOptions yet")
+ else:
+ assert_never(a)
+
+
+def impl_arguments(g: NativeFunctionsGroup) -> List[Binding]:
+ args: List[Union[Argument, TensorOptionsArguments, SelfArgument]] = []
+
+ if g.out.precomputed:
+ # A list of parameters for the impl function with
+ # certain parameters replaced with precomputed counterparts
+ # as specified in native_functions.yaml.
+ non_out_args_replaced: List[
+ Union[Argument, TensorOptionsArguments, SelfArgument]
+ ] = []
+ for a in g.out.func.arguments.non_out:
+ if isinstance(a, Argument) and a.name in g.out.precomputed.replace:
+ # If a is in precompute.replace, append the parameters
+ # that should replace it onto non_out_args_replaced.
+ non_out_args_replaced.extend(g.out.precomputed.replace[a.name])
+ else:
+ # If not, push a as it is.
+ non_out_args_replaced.append(a)
+
+ args.extend(non_out_args_replaced)
+ # g.out.precomputed.add is the list of parameters that are added
+ # without replacement after the non out args and just before the out args
+ args.extend(g.out.precomputed.add)
+ else:
+ args.extend(g.out.func.arguments.non_out)
+
+ args.extend(g.out.func.arguments.out)
+ return [r for arg in args for r in argument(arg)]
+
+
+def meta_arguments(g: NativeFunctionsGroup) -> List[Binding]:
+ args: List[Union[Argument, TensorOptionsArguments, SelfArgument]] = []
+ args.extend(g.functional.func.arguments.non_out)
+ return [r for arg in args for r in argument(arg)]
+
+
+def out_arguments(g: NativeFunctionsGroup) -> List[Binding]:
+ args: List[Union[Argument, TensorOptionsArguments, SelfArgument]] = []
+ args.extend(g.out.func.arguments.out)
+ return [r for arg in args for r in argument(arg)]
diff --git a/MLPY/Lib/site-packages/torchgen/api/translate.py b/MLPY/Lib/site-packages/torchgen/api/translate.py
new file mode 100644
index 0000000000000000000000000000000000000000..7824446f4b6018f0a6eb707438553dc453d43e54
--- /dev/null
+++ b/MLPY/Lib/site-packages/torchgen/api/translate.py
@@ -0,0 +1,430 @@
+from typing import Dict, List, NoReturn, Sequence, Union
+
+from torchgen.api.types import (
+ ArrayRefCType,
+ BaseCType,
+ Binding,
+ boolT,
+ ConstRefCType,
+ deviceT,
+ Expr,
+ intArrayRefT,
+ iOptTensorListRefT,
+ layoutT,
+ ListCType,
+ longT,
+ memoryFormatT,
+ MutRefCType,
+ NamedCType,
+ opmath_t,
+ OptionalCType,
+ optionalIntArrayRefT,
+ optionalScalarRefT,
+ optionalSymIntArrayRefT,
+ optionalTensorRefT,
+ scalar_t,
+ scalarT,
+ scalarTypeT,
+ SpecialArgName,
+ symIntArrayRefT,
+ SymIntT,
+ tensorOptionsT,
+ tensorT,
+ VectorCType,
+)
+
+# This file implements a small program synthesis engine that implements
+# conversions between one API to another.
+#
+ # The key data type in this file is NamedCType, short for Named C++ semantic type. A NamedCType
+# represents a C++ type, plus semantic information about what it represents.
+# For example, consider the argument "bool pin_memory"; its normal C++ type is
+# "bool", but its C++ semantic type also keeps track that this represents a
+# "pin_memory"; you can't just use a random other boolean in a context where you
+# need a "pin_memory"!
+#
+# The translator takes a list of needed NamedCTypes, and then figures out how
+# to construct expressions with these NamedCTypes from the given bindings. Many
+# of these expressions are trivial (I need a Tensor other; there's a Tensor
+ # other in scope); others are more nontrivial and may require packing/unpacking.
+# Some examples of non-trivial action:
+#
+# - Need the "dtype" binding? Well, maybe "dtype" isn't available
+# in the context, instead, "options" is, and you need to extract
+# it from there. (Gather)
+#
+# - Need the "context" binding? Well, maybe "context" isn't available
+# in the context, and you need to construct it from "dtype", "device",
+# etc. (Scatter)
+#
+# - Need the "memory_format" binding? Well, actually, it's available
+# from both "memory_format" and "options", so you had better make sure
+# they are consistent. (Join)
+
+options_ctype = NamedCType("options", ConstRefCType(BaseCType(tensorOptionsT)))
+
+out_tensor_ctype = NamedCType("out", ConstRefCType(BaseCType(tensorT)))
+
+longVec_ctype = VectorCType(BaseCType(longT))
+longSymVec_ctype = VectorCType(BaseCType(SymIntT))
+optionalLongVec_ctype = OptionalCType(VectorCType(BaseCType(longT)))
+optionalScalar_ctype = OptionalCType(BaseCType(scalarT))
+optionalTensor_ctype = OptionalCType(BaseCType(tensorT))
+
+
+class UnsatError(RuntimeError):
+ pass
+
+
+# Given a set of in-scope bindings and a set of target bindings, synthesize
+# a list of expressions that uses only the in-scope bindings (bindings) that
+# have all of the types of goals. You may want to use this function if
+# you're generating code for a function like:
+#
+# void f({args}) {
+# g({exprs}); // g is a different API
+# }
+#
+# and you need to generate "exprs".
+#
+# Typically, a list of Bindings is convenient to get (you usually call something
+# like arguments() to get them); but technically you only need less information:
+# for 'bindings' an (un-ordered) list of Exprs is sufficient; similarly, for
+# 'goals', an (ordered) list of NamedCType goals is sufficient. If you are doing
+# something more complicated, e.g., tracking the set of bindings in a context,
+# you may find using these smaller types more convenient.
+def translate(
+ bindings: Sequence[Union[Expr, Binding]],
+ goals: Sequence[Union[NamedCType, Binding]],
+ *,
+ method: bool = False,
+ allow_expensive_conversions: bool = False,
+) -> List[Expr]:
+ binding_exprs: List[Expr] = []
+ for b in bindings:
+ if isinstance(b, Binding):
+ binding_exprs.append(
+ Expr(
+ expr=b.name,
+ type=b.nctype,
+ )
+ )
+ else:
+ binding_exprs.append(b)
+
+ goal_ctypes: List[NamedCType] = []
+ for g in goals:
+ if isinstance(g, Binding):
+ goal_ctypes.append(g.nctype)
+ else:
+ goal_ctypes.append(g)
+
+ # Add all the bindings to the context
+ ctx: Dict[NamedCType, str] = {}
+ for b in binding_exprs:
+ ctx[b.type] = b.expr
+
+ # While we're at it, do some simple forward inference, looking through
+ # constructors.
+ #
+ # NB: When should you do forward inference versus backward inference?
+ # The general idea:
+ #
+ # - Backward inference WHEN the goal gets smaller
+ # - Forward inference WHEN the hypothesis gets smaller
+ #
+ # This helps ensure termination: backward inference starts with a goal
+ # and tries to make it simpler and simpler until it's trivial; if the
+ # goal can grow in size, we blow up to a really huge goal size.
+ # Similarly, with forward inference we take hypotheses and decompose
+ # them into simpler hypotheses; if hypotheses could expand in size,
+ # we also have potential nontermination. (In the code below, forward
+ # inference is only ever carried out at a single step, but you could
+ # imagine repeated application of forward inference being profitable.)
+ #
+ # A good starting point in the literature for exploring more about proof
+ # search are these lecture notes
+ # https://www.cs.cmu.edu/~fp/courses/oregon-m10/04-focusing.pdf
+ #
+ # TODO: My kingdom for a pattern matcher
+ # https://www.python.org/dev/peps/pep-0634/
+ #
+ # TODO: This could get us in recomputation trouble if b.expr is nontrivial.
+ # Fix this by implementing some sort of sharing so that if multiple
+ # goals share the same expression, we only compute it once. This seems
+ # to matter in practice as compiler is often unwilling to CSE nontrivial
+ # expressions like scalar.to<scalar_t>()
+ t = b.type
+ if (
+ isinstance(t, ConstRefCType)
+ and isinstance(t.elem, OptionalCType)
+ and isinstance(t.elem.elem, BaseCType)
+ and str(t.elem.elem.type) == "at::Tensor"
+ ):
+ ctx[
+ NamedCType(t.elem.elem.name, ConstRefCType(BaseCType(tensorT)))
+ ] = f"({b.expr}.has_value() ? *{b.expr} : at::Tensor())"
+
+ if t.type == ConstRefCType(OptionalCType(BaseCType(tensorT))):
+ ctx[
+ NamedCType(t.name, BaseCType(optionalTensorRefT))
+ ] = f"(({b.expr}.has_value() && (*{b.expr}).defined()) ? at::OptionalTensorRef(*{b.expr}) : at::OptionalTensorRef())"
+
+ if t.type == ConstRefCType(BaseCType(scalarT)):
+ ctx[NamedCType(t.name, BaseCType(opmath_t))] = f"({b.expr}).to()"
+
+ if t.type == ConstRefCType(OptionalCType(BaseCType(scalarT))):
+ ctx[
+ NamedCType(t.name, BaseCType(optionalScalarRefT))
+ ] = f"({b.expr}.has_value() ? at::OptionalScalarRef(&({b.expr}.value())) : at::OptionalScalarRef())"
+
+ if t.type == BaseCType(scalar_t):
+ ctx[
+ NamedCType(t.name, BaseCType(opmath_t))
+ ] = f"static_cast({b.expr})"
+
+ # [Note: IOptTensorListRef]
+ if t.type == ConstRefCType(ListCType(OptionalCType(BaseCType(tensorT)))):
+ ctx[
+ NamedCType(t.name, BaseCType(iOptTensorListRefT))
+ ] = f"at::IOptTensorListRef({b.expr})"
+
+ # Add implicit bindings if the generated code is inside a Tensor method
+ if method:
+ ctx[
+ NamedCType("self", MutRefCType(BaseCType(tensorT)))
+ ] = "const_cast(*this)"
+ ctx[
+ NamedCType("self", ConstRefCType(BaseCType(tensorT)))
+ ] = "const_cast(*this)"
+ # This is better! Byte-for-byte compat
+ # ctx[NamedCType("self", ConstRefCType(BaseCType(tensorT)))] = "*this"
+
+ def unsat(goal: NamedCType) -> NoReturn:
+ ctx_desc = "\n".join(
+ f" {t.cpp_type()} {t.name}; // {e}" for t, e in ctx.items()
+ )
+ raise UnsatError(
+ f"""
+Failed to synthesize the expression "{goal.cpp_type()} {goal.name}".
+When I failed, the following bindings were available in the context:
+
+{ctx_desc}
+
+This probably means there is a missing rule in the rules of torchgen.api.translate.
+Check this module for more information.
+"""
+ )
+
+ # A shitty backtracking search implementation. It's shitty because it
+ # does backtracking via stack (bad idea!) and for the most part tries to
+ # avoid backtracking. In particular, if
+ # direct=True, we won't try to do any fancy synthesis, just trivial
+ # conversions (e.g., "T a" is OK for "const T& a"). So all of the
+ # existing rules in this function simply try to solve immediately,
+ # and bail if things don't work out.
+ def solve(goal: NamedCType, *, direct: bool) -> str:
+ def direct_solve(goal: NamedCType) -> str:
+ return solve(goal, direct=True)
+
+ if goal in ctx:
+ # Trivial
+ return ctx[goal]
+
+ # const & is satisfied with mutable &
+ if isinstance(goal.type, ConstRefCType):
+ try:
+ # WARNING: not strictly decreasing; be careful not
+ # to add a direct conversion that satisfies
+ # mutable& with const&
+ return solve(
+ NamedCType(goal.name, MutRefCType(goal.type.elem)), direct=direct
+ )
+ except UnsatError:
+ pass
+
+ # mutable & is satisfied with value
+ if isinstance(goal.type, MutRefCType):
+ try:
+ return solve(NamedCType(goal.name, goal.type.elem), direct=direct)
+ except UnsatError:
+ pass
+
+ # TODO: These are referentially equal, shouldn't have to do this;
+ # ensuring we don't use type synonym IntArrayRef in codegen would
+ # help
+ if goal.type == ArrayRefCType(BaseCType(longT)):
+ return solve(NamedCType(goal.name, BaseCType(intArrayRefT)), direct=direct)
+
+ if direct:
+ unsat(goal)
+
+ # For now, all of these rules are mutually exclusive.
+ if goal == NamedCType("memory_format", OptionalCType(BaseCType(memoryFormatT))):
+ memory_format = direct_solve(
+ NamedCType(
+ SpecialArgName.possibly_redundant_memory_format,
+ OptionalCType(BaseCType(memoryFormatT)),
+ )
+ )
+ # No need to join "memory_format" and "options" if the target API takes "options" directly.
+ # Otherwise it will cause the redundant memory_format error.
+ if options_ctype in goal_ctypes:
+ return memory_format
+ try:
+ options = direct_solve(options_ctype)
+ return f"c10::impl::check_tensor_options_and_extract_memory_format({options}, {memory_format})"
+ except UnsatError:
+ return memory_format
+ elif goal == NamedCType("options", BaseCType(tensorOptionsT)):
+ dtype = direct_solve(
+ NamedCType("dtype", OptionalCType(BaseCType(scalarTypeT)))
+ )
+ pin_memory = direct_solve(
+ NamedCType("pin_memory", OptionalCType(BaseCType(boolT)))
+ )
+ device = direct_solve(
+ NamedCType("device", OptionalCType(BaseCType(deviceT)))
+ )
+ layout = direct_solve(
+ NamedCType("layout", OptionalCType(BaseCType(layoutT)))
+ )
+ return f"TensorOptions().dtype({dtype}).layout({layout}).device({device}).pinned_memory({pin_memory})"
+
+ elif goal == NamedCType("dtype", OptionalCType(BaseCType(scalarTypeT))):
+ try:
+ options = direct_solve(options_ctype)
+ return f"c10::optTypeMetaToScalarType({options}.dtype_opt())"
+ except UnsatError:
+ out_tensor = direct_solve(out_tensor_ctype)
+ return f"{out_tensor}.scalar_type()"
+
+ elif goal == NamedCType("layout", OptionalCType(BaseCType(layoutT))):
+ try:
+ options = direct_solve(options_ctype)
+ return f"{options}.layout_opt()"
+ except UnsatError:
+ out_tensor = direct_solve(out_tensor_ctype)
+ return f"{out_tensor}.layout()"
+
+ elif goal == NamedCType("device", OptionalCType(BaseCType(deviceT))):
+ try:
+ options = direct_solve(options_ctype)
+ return f"{options}.device_opt()"
+ except UnsatError:
+ out_tensor = direct_solve(out_tensor_ctype)
+ return f"{out_tensor}.device()"
+
+ elif goal == NamedCType("pin_memory", OptionalCType(BaseCType(boolT))):
+ try:
+ options = direct_solve(options_ctype)
+ return f"{options}.pinned_memory_opt()"
+ except UnsatError:
+ # If we're calling a factory op from its out= variant,
+ # We don't actually care about the value of pin_memory.
+ out_tensor = direct_solve(out_tensor_ctype)
+ return "c10::nullopt"
+
+ # We can always do translations from value types to reference types, like vector -> IntArrayRef
+ elif goal.type == BaseCType(intArrayRefT):
+ try:
+ return direct_solve(NamedCType(goal.name, longVec_ctype))
+ except UnsatError:
+ # We can also go SymIntArrayRef -> IntArrayRef
+ symIntArrayRef_type = direct_solve(
+ NamedCType(goal.name, BaseCType(symIntArrayRefT))
+ )
+ return f"C10_AS_INTARRAYREF_SLOW({symIntArrayRef_type})"
+ elif goal.type == BaseCType(symIntArrayRefT):
+ try:
+ r = direct_solve(NamedCType(goal.name, BaseCType(intArrayRefT)))
+ return f"c10::fromIntArrayRefSlow({r})"
+ except UnsatError:
+ return direct_solve(NamedCType(goal.name, longSymVec_ctype))
+ elif goal.type == BaseCType(SymIntT):
+ return direct_solve(NamedCType(goal.name, BaseCType(longT)))
+ elif goal.type == OptionalCType(BaseCType(SymIntT)):
+ argname = direct_solve(
+ NamedCType(goal.name, OptionalCType(BaseCType(longT)))
+ )
+ return f"{argname}.has_value() ? c10::make_optional(c10::SymInt(*{argname})) : c10::nullopt"
+ elif goal.type == BaseCType(longT):
+ symInt_type = direct_solve(NamedCType(goal.name, BaseCType(SymIntT)))
+ return f"{symInt_type}.guard_int(__FILE__, __LINE__)"
+ elif goal.type == OptionalCType(BaseCType(longT)):
+ argname = direct_solve(
+ NamedCType(goal.name, OptionalCType(BaseCType(SymIntT)))
+ )
+ return f"{argname}.has_value() ? c10::make_optional({argname}->guard_int(__FILE__, __LINE__)) : c10::nullopt"
+ elif goal.type == BaseCType(optionalIntArrayRefT):
+ try:
+ return direct_solve(NamedCType(goal.name, optionalLongVec_ctype))
+ except UnsatError:
+ argname = direct_solve(
+ NamedCType(goal.name, BaseCType(optionalSymIntArrayRefT))
+ )
+ return f"{argname}.has_value() ? c10::make_optional(C10_AS_INTARRAYREF_SLOW(*{argname})) : c10::nullopt"
+ elif goal.type == BaseCType(optionalSymIntArrayRefT):
+ # TODO: You might also want to solve this from longSymVec_ctype or
+ # an optional version of it
+ argname = direct_solve(
+ NamedCType(goal.name, BaseCType(optionalIntArrayRefT))
+ )
+ return f"{argname}.has_value() ? c10::make_optional(c10::fromIntArrayRefSlow(*{argname})) : c10::nullopt"
+ elif goal.type == BaseCType(optionalScalarRefT):
+ return direct_solve(NamedCType(goal.name, optionalScalar_ctype))
+ elif goal.type == BaseCType(optionalTensorRefT):
+ return direct_solve(NamedCType(goal.name, optionalTensor_ctype))
+
+ # Note [translation from C++ reference to value types]
+ # The below cases are all for when we have an argument with a reference type,
+ # and a corresponding goal with a value type.
+ # These are needed when we populate the inputs to a lambda capture and we need
+ # to guarantee the lifetime of each captured argument.
+ # We guard it with an explicit kwarg because converting to a value type is expensive
+ # (e.g., O(n) to convert an IntArrayRef to a vector),
+ # so the caller of translate() should be explicit that they need it.
+ if allow_expensive_conversions:
+ if goal.type == VectorCType(BaseCType(longT)):
+ intArrayRef_ctype = NamedCType(goal.name, BaseCType(intArrayRefT))
+ argname = direct_solve(intArrayRef_ctype)
+ return f"{argname}.vec()"
+ if goal.type == VectorCType(BaseCType(SymIntT)):
+ symIntArrayRef_ctype = NamedCType(goal.name, BaseCType(symIntArrayRefT))
+ argname = direct_solve(symIntArrayRef_ctype)
+ return f"{argname}.vec()"
+ elif goal.type == OptionalCType(VectorCType(BaseCType(longT))):
+ optionalIntArrayRef_ctype = NamedCType(
+ goal.name, BaseCType(optionalIntArrayRefT)
+ )
+ argname = direct_solve(optionalIntArrayRef_ctype)
+ return f"{argname}.has_value() ? c10::make_optional({argname}->vec()) : c10::nullopt"
+ elif goal.type == OptionalCType(BaseCType(scalarT)):
+ optionalScalarRef_ctype = NamedCType(
+ goal.name, BaseCType(optionalScalarRefT)
+ )
+ argname = direct_solve(optionalScalarRef_ctype)
+ return f"{argname}.has_value() ? c10::make_optional({argname}) : c10::nullopt"
+ elif goal.type == OptionalCType(BaseCType(tensorT)):
+ optionalTensorRef_ctype = NamedCType(
+ goal.name, BaseCType(optionalTensorRefT)
+ )
+ argname = direct_solve(optionalTensorRef_ctype)
+ return f"{argname}.has_value() ? c10::make_optional({argname}) : c10::nullopt"
+ # Technically, we also need to handle cases of C++ containers holding reference types.
+ # But there currently aren't any ops that require lambda capture codegen
+ # with arguments like std::vector.
+ # If that changes, we'll have to add the translation here.
+
+ # We allow const casting on tensors, since const-correctness is a bit broken for at::Tensor.
+ # We could probably generalize this to non-tensor types too.
+ if goal.type == MutRefCType(BaseCType(tensorT)):
+ const_ref_tensor_ctype = NamedCType(
+ goal.name, ConstRefCType(BaseCType(tensorT))
+ )
+ argname = direct_solve(const_ref_tensor_ctype)
+ return f"const_cast({argname})"
+
+ unsat(goal)
+
+ return [Expr(solve(g, direct=False), g) for g in goal_ctypes]
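The `solve` routine above is goal-directed: each goal `NamedCType` is either satisfied directly from a binding already in context, or synthesized from other solvable goals (e.g., rebuilding `TensorOptions` from scattered `dtype`/`layout`/`device`/`pin_memory` arguments, or vice versa). The snippet below is a minimal, self-contained sketch of that pattern for illustration only; the names (`solve`, `ctx`, `UnsatError`) are stand-ins, not torchgen APIs.

```python
# Minimal sketch of the goal-directed solving pattern used by translate():
# a goal is either found directly in the context or synthesized from other
# goals that can themselves be solved.

class UnsatError(RuntimeError):
    pass

def solve(goal: str, ctx: dict) -> str:
    """Return a C++ expression that produces `goal` from the bindings in `ctx`."""
    def direct_solve(g: str) -> str:
        if g in ctx:
            return ctx[g]            # an expression already bound in the context
        raise UnsatError(g)

    try:
        return direct_solve(goal)
    except UnsatError:
        pass

    # Fall back to synthesis rules, mirroring e.g. the "options" case above.
    if goal == "options":
        dtype = direct_solve("dtype")
        layout = direct_solve("layout")
        device = direct_solve("device")
        pin_memory = direct_solve("pin_memory")
        return (f"TensorOptions().dtype({dtype}).layout({layout})"
                f".device({device}).pinned_memory({pin_memory})")
    if goal == "dtype":
        # e.g., recover the dtype back out of a packed options argument
        options = direct_solve("options")
        return f"c10::optTypeMetaToScalarType({options}.dtype_opt())"
    raise UnsatError(goal)

# Example: the caller has scattered scalar arguments and wants "options".
ctx = {"dtype": "dtype", "layout": "layout", "device": "device", "pin_memory": "pin_memory"}
print(solve("options", ctx))
```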
diff --git a/MLPY/Lib/site-packages/torchgen/api/types/__init__.py b/MLPY/Lib/site-packages/torchgen/api/types/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..19ba90c31f9e12d0ceee8850ed2003f3c87b4e1b
--- /dev/null
+++ b/MLPY/Lib/site-packages/torchgen/api/types/__init__.py
@@ -0,0 +1,3 @@
+from .types import *
+from .types_base import *
+from .signatures import * # isort:skip
diff --git a/MLPY/Lib/site-packages/torchgen/api/types/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torchgen/api/types/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8dc9af1fdc571c17112d7854c2f22b7bafa2b061
Binary files /dev/null and b/MLPY/Lib/site-packages/torchgen/api/types/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torchgen/api/types/__pycache__/signatures.cpython-39.pyc b/MLPY/Lib/site-packages/torchgen/api/types/__pycache__/signatures.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..551b7594110e252b4792d255222d4a0c63b2ca87
Binary files /dev/null and b/MLPY/Lib/site-packages/torchgen/api/types/__pycache__/signatures.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torchgen/api/types/__pycache__/types.cpython-39.pyc b/MLPY/Lib/site-packages/torchgen/api/types/__pycache__/types.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d3b56056eb98c774131b25dbb18c658e6380ad2c
Binary files /dev/null and b/MLPY/Lib/site-packages/torchgen/api/types/__pycache__/types.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torchgen/api/types/__pycache__/types_base.cpython-39.pyc b/MLPY/Lib/site-packages/torchgen/api/types/__pycache__/types_base.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2623b96b2371f683296cb5f5106fbe932cb01eb6
Binary files /dev/null and b/MLPY/Lib/site-packages/torchgen/api/types/__pycache__/types_base.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torchgen/api/types/signatures.py b/MLPY/Lib/site-packages/torchgen/api/types/signatures.py
new file mode 100644
index 0000000000000000000000000000000000000000..b5716fea645da0799d3e994899be69a6086b28ab
--- /dev/null
+++ b/MLPY/Lib/site-packages/torchgen/api/types/signatures.py
@@ -0,0 +1,423 @@
+from dataclasses import dataclass
+
+from typing import Iterator, List, Optional, Sequence, Set, Tuple, Union
+
+from torchgen.model import (
+ BackendIndex,
+ FunctionSchema,
+ NativeFunction,
+ NativeFunctionsGroup,
+ NativeFunctionsViewGroup,
+)
+
+from .types_base import Binding, CType, Expr
+
+
+@dataclass(frozen=True)
+class CppSignature:
+ """
+ A CppSignature represents a single overload in the C++ API. For
+ any given function schema, there may be multiple CppSignatures
+ corresponding to it, based on how we desugar to C++. See also
+ CppSignatureGroup.
+ """
+
+ # The schema this signature is derived from
+ func: FunctionSchema
+
+ # Is this a C++ signature for a method, i.e. Tensor::my_op(...)?
+ method: bool
+
+ # Is this a faithful C++ signature (i.e. following the JIT schema) or a convenience API
+ # (i.e. with a potential TensorOptions argument and out arguments in the front)
+ faithful: bool
+
+ # Is this a symint C++ signature? For BC reasons, functions that take
+ # SymInts are still presented as int64_t in C++, and the SymInt variant is
+ # offered at a different overload name
+ #
+ # NB: If a function RETURNS a SymInt, this is ALWAYS false
+ symint: bool
+
+ # The set of C++ arguments which should not have defaults applied to them
+ cpp_no_default_args: Set[str]
+
+ # Is this a fallback C++ binding? Fallback bindings are enabled by
+ # manual_cpp_binding: True and are alternate, non-public API that
+ # lets manual C++ binding implementors access the binding that would
+ # have been automatically generated
+ fallback_binding: bool = False
+
+ # Return the unpacked argument structure of this signature,
+ # discarding information about which arguments are semantically
+ # related to each other.
+ def arguments(self) -> Sequence[Binding]:
+ return cpp.arguments(
+ self.func.arguments,
+ faithful=self.faithful,
+ symint=self.symint,
+ method=self.method,
+ cpp_no_default_args=self.cpp_no_default_args,
+ )
+
+ def name(self, *, suppress_symint_suffix: bool = False) -> str:
+ n = cpp.name(
+ self.func,
+ faithful_name_for_out_overloads=self.faithful,
+ symint_overload=False if suppress_symint_suffix else self.symint,
+ )
+ if self.fallback_binding:
+ n = f"__dispatch_{n}"
+ return n
+
+ # Render the C++ declaration for this signature
+ def decl(
+ self,
+ *,
+ name: Optional[str] = None,
+ prefix: str = "",
+ is_redispatching_fn: bool = False,
+ suppress_symint_suffix: bool = False,
+ ) -> str:
+ returns_type = cpp.returns_type(
+ self.func.returns, symint=self.symint
+ ).cpp_type()
+ cpp_args = [a.decl() for a in self.arguments()]
+ if is_redispatching_fn:
+ cpp_args = ["c10::DispatchKeySet dispatchKeySet"] + cpp_args
+ cpp_args_str = ", ".join(cpp_args)
+ if name is None:
+ name = prefix + self.name(suppress_symint_suffix=suppress_symint_suffix)
+ return f"{returns_type} {name}({cpp_args_str})"
+
+ # Render the C++ definition for this signature, not including
+ # the body (with curly braces)
+ def defn(
+ self,
+ *,
+ name: Optional[str] = None,
+ prefix: str = "",
+ is_redispatching_fn: bool = False,
+ ) -> str:
+ returns_type = cpp.returns_type(
+ self.func.returns, symint=self.symint
+ ).cpp_type()
+ cpp_args = [a.defn() for a in self.arguments()]
+ if is_redispatching_fn:
+ cpp_args = ["c10::DispatchKeySet dispatchKeySet"] + cpp_args
+ cpp_args_str = ", ".join(cpp_args)
+ if name is None:
+ name = prefix + self.name()
+ return f"{returns_type} {name}({cpp_args_str})"
+
+ def ptr_type(self) -> str:
+ args_types_str = ", ".join(a.type for a in self.arguments())
+ return f"{cpp.returns_type(self.func.returns, symint=self.symint).cpp_type()} (*)({args_types_str})"
+
+ # Return the C++ function type, e.g., something like int(bool)
+ def type(self) -> str:
+ args_types_str = ", ".join(a.type for a in self.arguments())
+ return f"{cpp.returns_type(self.func.returns, symint=self.symint).cpp_type()} ({args_types_str})"
+
+
+# Represents group of all CppSignatures associated with a
+# FunctionSchema. Right now, that's the regular, user-visible
+# signature, as well as a "faithful" signature which doesn't
+# have grouping.
+@dataclass(frozen=True)
+class CppSignatureGroup:
+ func: FunctionSchema
+ signature: CppSignature
+ faithful_signature: Optional[CppSignature]
+ symint_signature: Optional[CppSignature]
+ symint_faithful_signature: Optional[CppSignature]
+
+ def most_faithful_signature(self) -> CppSignature:
+ if self.faithful_signature:
+ return self.faithful_signature
+ else:
+ return self.signature
+
+ def signatures(self, *, symint: bool = True) -> Iterator[CppSignature]:
+ yield self.signature
+ if self.faithful_signature:
+ yield self.faithful_signature
+ if symint:
+ if self.symint_signature:
+ yield self.symint_signature
+ if self.symint_faithful_signature:
+ yield self.symint_faithful_signature
+
+ @staticmethod
+ def from_native_function(
+ f: NativeFunction, *, method: bool, fallback_binding: bool = False
+ ) -> "CppSignatureGroup":
+ func = f.func
+
+ def make_sig(*, faithful: bool, symint: bool) -> CppSignature:
+ return CppSignature(
+ func=func,
+ faithful=faithful,
+ symint=symint,
+ method=method,
+ fallback_binding=fallback_binding,
+ cpp_no_default_args=f.cpp_no_default_args,
+ )
+
+ def make_sigs(*, symint: bool) -> Tuple[CppSignature, Optional[CppSignature]]:
+ faithful_signature: Optional[CppSignature] = None
+ if func.arguments.tensor_options is not None or len(func.arguments.out) > 0:
+ faithful_signature = make_sig(faithful=True, symint=symint)
+ signature = make_sig(faithful=False, symint=symint)
+ return signature, faithful_signature
+
+ signature, faithful_signature = make_sigs(symint=False)
+ symint_signature: Optional[CppSignature] = None
+ symint_faithful_signature: Optional[CppSignature] = None
+ if func.has_symint():
+ symint_signature, symint_faithful_signature = make_sigs(symint=True)
+
+ return CppSignatureGroup(
+ func=func,
+ signature=signature,
+ faithful_signature=faithful_signature,
+ symint_signature=symint_signature,
+ symint_faithful_signature=symint_faithful_signature,
+ )
+
+
+@dataclass(frozen=True)
+class DispatcherSignature:
+ # The schema this signature is derived from
+ func: FunctionSchema
+
+ # Allows you to prepend an arbitrary prefix to the signature name.
+ # This is useful for parts of the codegen that generate wrappers around kernels,
+ # and need to avoid naming collisions.
+ prefix: str = ""
+
+ symint: bool = True
+
+ def arguments(self) -> List[Binding]:
+ return dispatcher.arguments(self.func, symint=self.symint)
+
+ def name(self) -> str:
+ return self.prefix + dispatcher.name(self.func)
+
+ def decl(self, name: Optional[str] = None) -> str:
+ args_str = ", ".join(a.decl() for a in self.arguments())
+ if name is None:
+ name = self.name()
+ return f"{self.returns_type().cpp_type()} {name}({args_str})"
+
+ def defn(
+ self, name: Optional[str] = None, *, is_redispatching_fn: bool = False
+ ) -> str:
+ args = [a.defn() for a in self.arguments()]
+ if is_redispatching_fn:
+ args = ["c10::DispatchKeySet dispatchKeySet"] + args
+ args_str = ", ".join(args)
+ if name is None:
+ name = self.name()
+ return f"{self.returns_type().cpp_type()} {name}({args_str})"
+
+ def exprs(self) -> List[Expr]:
+ return [Expr(a.name, a.nctype) for a in self.arguments()]
+
+ def returns_type(self) -> CType:
+ return dispatcher.returns_type(self.func.returns, symint=self.symint)
+
+ def ptr_type(self) -> str:
+ dispatcher_args_types_str = ", ".join(a.type for a in self.arguments())
+ return f"{self.returns_type().cpp_type()} (*)({dispatcher_args_types_str})"
+
+ # Return the C++ function type, e.g., something like int(bool)
+ def type(self) -> str:
+ dispatcher_args_types_str = ", ".join(a.type for a in self.arguments())
+ return f"{self.returns_type().cpp_type()} ({dispatcher_args_types_str})"
+
+ @staticmethod
+ def from_schema(
+ func: FunctionSchema, *, prefix: str = "", symint: bool = True
+ ) -> "DispatcherSignature":
+ return DispatcherSignature(func, prefix, symint)
+
+
+@dataclass(frozen=True)
+class NativeSignature:
+ # The schema this signature is derived from
+ func: FunctionSchema
+
+ symint: bool
+
+ prefix: str = ""
+
+ def name(self) -> str:
+ return self.prefix + native.name(self.func)
+
+ def decl(self, name: Optional[str] = None) -> str:
+ args_str = ", ".join(a.decl() for a in self.arguments())
+ if name is None:
+ name = self.name()
+ return f"{native.returns_type(self.func.returns, symint=self.symint).cpp_type()} {name}({args_str})"
+
+ def defn(self, name: Optional[str] = None) -> str:
+ args_str = ", ".join(a.defn() for a in self.arguments())
+ if name is None:
+ name = self.name()
+ return f"{native.returns_type(self.func.returns, symint=self.symint).cpp_type()} {name}({args_str})"
+
+ def ptr_type(self) -> str:
+ # don't include defaults in type signature!
+ args_str = ", ".join(a.defn() for a in self.arguments())
+ return f"{native.returns_type(self.func.returns, symint=self.symint).cpp_type()} (*)({args_str})"
+
+ def arguments(self) -> List[Binding]:
+ return native.arguments(self.func, symint=self.symint)
+
+ def returns_type(self) -> CType:
+ return native.returns_type(self.func.returns, symint=self.symint)
+
+ def dispatcher_exprs(self) -> List[Expr]:
+ return translate.translate(
+ self.arguments(), dispatcher.arguments(self.func), method=False
+ )
+
+
+@dataclass(frozen=True)
+class ViewInverseSignature:
+ g: NativeFunctionsViewGroup
+
+ def name(self) -> str:
+ return functionalization.reverse_name(self.g.view, include_namespace=False)
+
+ def decl(self) -> str:
+ return_type = functionalization.returns_type(self.g.view.func)
+ decls = [
+ a.decl()
+ for a in functionalization.inner_arguments(
+ self.g.view.func, is_reverse=True
+ )
+ ]
+ return f"static {return_type.cpp_type()} {self.name()}({', '.join(decls)});"
+
+
+@dataclass(frozen=True)
+class FunctionalizationLambda:
+ g: NativeFunctionsViewGroup
+
+ # are we generating the forward lambda or the reverse lambda?
+ is_reverse: bool
+
+ def captures(self) -> List[Expr]:
+ # The lambda lives inside of a kernel following the dispatcher API, so its outer context is the dispatcher arguments
+ # We also need to read the "reapply views" TLS at the time that the functionalization kernel was executed,
+ # and plumb it into the lambda.
+ outer_ctx = dispatcher.arguments(self.g.view.func) + [
+ functionalization.reapply_views_binding,
+ functionalization.inverse_return_mode_binding,
+ ]
+ capture_bindings = functionalization.capture_arguments(
+ self.g.view.func, is_reverse=self.is_reverse
+ )
+ # allow_expensive_conversions is set because we want to convert
+ # some reference types (IntArrayRef) to value types (vector).
+ capture_exprs = translate.translate(
+ outer_ctx, capture_bindings, method=False, allow_expensive_conversions=True
+ )
+ return capture_exprs
+
+ def decl(self) -> str:
+ return_type = functionalization.returns_type(self.g.view.func)
+ capture_str = ", ".join(
+ f"{val.type.name} = {val.expr}" for val in self.captures()
+ )
+ decls = [
+ a.decl()
+ for a in functionalization.outer_arguments(is_reverse=self.is_reverse)
+ ]
+ return f"[{capture_str}]({', '.join(decls)}) -> {return_type.cpp_type()}"
+
+ def inner_call(self, *, reapply_views: Optional[bool] = None) -> str:
+ inner_call_name = functionalization.name(
+ self.g,
+ is_reverse=self.is_reverse,
+ include_namespace=True,
+ reapply_views=reapply_views,
+ )
+
+ arg_ctx = functionalization.outer_arguments(is_reverse=self.is_reverse)
+ capture_ctx = functionalization.capture_arguments(
+ self.g.view.func, is_reverse=self.is_reverse
+ )
+ full_ctx = arg_ctx + capture_ctx
+
+ assert self.g.view_copy is not None
+ call_bindings = functionalization.inner_arguments(
+ self.g.view_copy.func, is_reverse=self.is_reverse
+ )
+ maybe_index = functionalization.inner_call_index(self.g.view_copy.func)
+ call_exprs = [
+ e.expr for e in translate.translate(full_ctx, call_bindings, method=False)
+ ]
+ if not self.is_reverse and maybe_index is not None:
+ return f'{inner_call_name}({", ".join(call_exprs)})[{maybe_index.name}];'
+ else:
+ return f'{inner_call_name}({", ".join(call_exprs)});'
+
+ @staticmethod
+ def from_func(
+ g: NativeFunctionsViewGroup, *, is_reverse: bool
+ ) -> "FunctionalizationLambda":
+ return FunctionalizationLambda(g, is_reverse)
+
+
+@dataclass(frozen=True)
+class StructuredImplSignature:
+ g: NativeFunctionsGroup
+ name: str
+
+ def defn(self, name: Optional[str] = None) -> str:
+ args_str = ", ".join(a.defn() for a in self.arguments())
+ return f"TORCH_IMPL_FUNC({self.name})({args_str})"
+
+ def arguments(self) -> List[Binding]:
+ return structured.impl_arguments(self.g)
+
+
+# Helper functions
+
+
+def kernel_signature(
+ f: NativeFunction, backend_index: BackendIndex, *, prefix: str = ""
+) -> Union["NativeSignature", "DispatcherSignature"]:
+ # Note [External Backends Follow Dispatcher API]
+ # Kernel signatures for in-tree backends follow the "native" API,
+ # while kernels for out-of-tree backends follow the dispatcher API.
+ # See the comments in `native.py` for details, but historically there have been
+ # some small differences in schema convention between them and the Dispatcher API.
+ # Any differences that require translating between the two will result in a runtime cost,
+ # so we'd like to keep the differences as small as possible.
+ # With external backends, we'd like to enforce that they write their kernels with schemas
+ # that match the Dispatcher API directly, if they can.
+ meta = backend_index.get_kernel(f)
+ symint = meta is not None and meta.supports_symint()
+ if symint:
+ assert (
+ f.func.has_symint()
+ ), f"attempted to define symint kernel for {backend_index.dispatch_key} without SymInt in schema"
+ if backend_index.external:
+ return DispatcherSignature.from_schema(f.func, prefix=prefix, symint=symint)
+ else:
+ return NativeSignature(f.func, prefix=prefix, symint=symint)
+
+
+# Functions only, no types
+from torchgen.api import (
+ cpp,
+ dispatcher,
+ functionalization,
+ native,
+ structured,
+ translate,
+)
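As a rough usage sketch of the signature classes defined above: given a parsed `NativeFunction` `f` (obtained elsewhere in torchgen from native_functions.yaml), the group of C++ signatures and the dispatcher signature can be rendered like this. Only calls that appear in signatures.py itself are used; the wrapper name is illustrative.

```python
# Sketch: rendering C++ declarations for a NativeFunction `f`.
from torchgen.api.types import CppSignatureGroup, DispatcherSignature

def render_decls(f) -> None:  # f: NativeFunction
    sig_group = CppSignatureGroup.from_native_function(f, method=False)
    for sig in sig_group.signatures():
        # e.g. "at::Tensor add(const at::Tensor & self, const at::Tensor & other, ...)"
        print(sig.decl())
    disp = DispatcherSignature.from_schema(f.func)
    # Definition line for a hypothetical wrapper around the dispatcher call.
    print(disp.defn(name=f"wrapper_{disp.name()}"))
```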
diff --git a/MLPY/Lib/site-packages/torchgen/api/types/types.py b/MLPY/Lib/site-packages/torchgen/api/types/types.py
new file mode 100644
index 0000000000000000000000000000000000000000..4d686ce5418dc8cd8c637add5586555a985eddf5
--- /dev/null
+++ b/MLPY/Lib/site-packages/torchgen/api/types/types.py
@@ -0,0 +1,190 @@
+"""
+Where should I add a new type? `types_base.py` vs `types.py`
+
+This file defines data model classes for the torchgen typing system, as well as some base types such as int32_t.
+
+`types.py` defines ATen Tensor type and some c10 types, along with signatures that use these types.
+
+The difference between these two files is that `types_base.py` should be implementation-agnostic, meaning it shouldn't
+contain any type definition that is tied to a specific C++ library (e.g., ATen), so that it can be easily reused
+if we want to generate code for another C++ library.
+
+Add new types to `types.py` if these types are ATen/c10 related.
+Add new types to `types_base.py` if they are basic and not attached to ATen/c10.
+"""
+from dataclasses import dataclass
+from typing import Dict
+
+from torchgen.model import BaseTy, ScalarType
+
+from .types_base import (
+ BaseCppType,
+ BaseCType,
+ boolT,
+ byteT,
+ charT,
+ CType,
+ doubleT,
+ floatT,
+ int32T,
+ longT,
+ shortT,
+)
+
+
+TENSOR_LIST_LIKE_CTYPES = [
+ "at::TensorList",
+ "const c10::List> &",
+ "const at::ITensorListRef &",
+]
+
+
+halfT = BaseCppType("at", "Half")
+complexHalfT = BaseCppType(
+ "c10", "complex<c10::Half>"
+) # stuffing template param here is an abuse
+complexFloatT = BaseCppType("c10", "complex<float>")
+complexDoubleT = BaseCppType("c10", "complex<double>")
+bfloat16T = BaseCppType("at", "BFloat16")
+float8_e5m2T = BaseCppType("at", "Float8_e5m2")
+float8_e5m2fnuzT = BaseCppType("at", "Float8_e5m2fnuz")
+float8_e4m3fnT = BaseCppType("at", "Float8_e4m3fn")
+float8_e4m3fnuzT = BaseCppType("at", "Float8_e4m3fnuz")
+stringT = BaseCppType("c10", "string_view")
+generatorT = BaseCppType("at", "Generator")
+scalarTypeT = BaseCppType("at", "ScalarType")
+tensorT = BaseCppType("at", "Tensor")
+optionalTensorRefT = BaseCppType("at", "OptionalTensorRef")
+tensorListT = BaseCppType("at", "TensorList")
+iTensorListRefT = BaseCppType("at", "ITensorListRef")
+iOptTensorListRefT = BaseCppType("at", "IOptTensorListRef")
+dimnameT = BaseCppType("at", "Dimname")
+dimnameListT = BaseCppType("at", "DimnameList")
+dimVectorT = BaseCppType("at", "DimVector")
+layoutT = BaseCppType("at", "Layout")
+deviceT = BaseCppType("at", "Device")
+deviceIndexT = BaseCppType("at", "DeviceIndex")
+scalarT = BaseCppType("at", "Scalar")
+optionalScalarRefT = BaseCppType("at", "OptionalScalarRef")
+memoryFormatT = BaseCppType("at", "MemoryFormat")
+qschemeT = BaseCppType("at", "QScheme")
+storageT = BaseCppType("at", "Storage")
+streamT = BaseCppType("at", "Stream")
+intArrayRefT = BaseCppType("at", "IntArrayRef")
+optionalIntArrayRefT = BaseCppType("at", "OptionalIntArrayRef")
+optionalSymIntArrayRefT = BaseCppType("at", "OptionalSymIntArrayRef")
+tensorOptionsT = BaseCppType("at", "TensorOptions")
+typeAndSizeT = BaseCppType("torch::autograd::generated", "TypeAndSize")
+tensorGeometryT = BaseCppType("at", "TensorGeometry")
+SymIntT = BaseCppType("c10", "SymInt")
+symIntArrayRefT = BaseCppType("c10", "SymIntArrayRef")
+
+# Types representing template parameters. Technically, we probably shouldn't
+# represent them this way in codegen, but it was pretty convenient.
+scalar_t = BaseCppType("", "scalar_t")
+opmath_t = BaseCppType("", "opmath_t")
+
+ScalarTypeToCppMapping: Dict[ScalarType, BaseCppType] = {
+ ScalarType.Byte: byteT,
+ ScalarType.Char: charT,
+ ScalarType.Short: shortT,
+ ScalarType.Int: int32T,
+ ScalarType.Long: longT,
+ ScalarType.Half: halfT,
+ ScalarType.Float: floatT,
+ ScalarType.Double: doubleT,
+ ScalarType.ComplexHalf: complexHalfT,
+ ScalarType.ComplexFloat: complexFloatT,
+ ScalarType.ComplexDouble: complexDoubleT,
+ ScalarType.Bool: boolT,
+ ScalarType.Float8_e5m2: float8_e5m2T,
+ ScalarType.Float8_e5m2fnuz: float8_e5m2fnuzT,
+ ScalarType.Float8_e4m3fn: float8_e4m3fnT,
+ ScalarType.Float8_e4m3fnuz: float8_e4m3fnuzT,
+}
+
+BaseTypeToCppMapping: Dict[BaseTy, BaseCppType] = {
+ BaseTy.int: longT,
+ BaseTy.float: doubleT,
+ BaseTy.bool: boolT,
+ BaseTy.str: stringT,
+ BaseTy.Generator: generatorT,
+ BaseTy.ScalarType: scalarTypeT,
+ BaseTy.Tensor: tensorT,
+ BaseTy.Dimname: dimnameT,
+ BaseTy.DimVector: dimVectorT,
+ BaseTy.Layout: layoutT,
+ BaseTy.Device: deviceT,
+ BaseTy.DeviceIndex: deviceIndexT,
+ BaseTy.Scalar: scalarT,
+ BaseTy.MemoryFormat: memoryFormatT,
+ BaseTy.QScheme: qschemeT,
+ BaseTy.Storage: storageT,
+ BaseTy.Stream: streamT,
+ BaseTy.SymInt: SymIntT,
+}
+
+# CTypes encode C++ type structure as needed for translation.
+
+
+@dataclass(frozen=True)
+class OptionalCType(CType):
+ elem: "CType"
+
+ def cpp_type(self, *, strip_ref: bool = False) -> str:
+ # Do not pass `strip_ref` recursively.
+ return f"c10::optional<{self.elem.cpp_type()}>"
+
+ def cpp_type_registration_declarations(self) -> str:
+ return f"c10::optional<{self.elem.cpp_type_registration_declarations()}>"
+
+ def remove_const_ref(self) -> "CType":
+ return OptionalCType(self.elem.remove_const_ref())
+
+
+@dataclass(frozen=True)
+class ListCType(CType):
+ elem: "CType"
+
+ def cpp_type(self, *, strip_ref: bool = False) -> str:
+ # Do not pass `strip_ref` recursively.
+ return f"c10::List<{self.elem.cpp_type()}>"
+
+ def cpp_type_registration_declarations(self) -> str:
+ return f"c10::List<{self.elem.cpp_type_registration_declarations()}>"
+
+ def remove_const_ref(self) -> "CType":
+ return ListCType(self.elem.remove_const_ref())
+
+
+@dataclass(frozen=True)
+class ArrayRefCType(CType):
+ elem: "CType"
+
+ def cpp_type(self, *, strip_ref: bool = False) -> str:
+ # Do not pass `strip_ref` recursively.
+ return f"at::ArrayRef<{self.elem.cpp_type()}>"
+
+ def cpp_type_registration_declarations(self) -> str:
+ return f"ArrayRef<{self.elem.cpp_type_registration_declarations()}>"
+
+ def remove_const_ref(self) -> "CType":
+ return ArrayRefCType(self.elem.remove_const_ref())
+
+
+@dataclass(frozen=True)
+class VectorizedCType(CType):
+ # This template is explicitly specialized, so the only valid
+ # elems are those we have specializations for (e.g., float, double, ...)
+ # scalar_t is also a common argument here (when we are codegen in
+ # a templated context)
+ elem: BaseCType
+
+ def cpp_type(self, *, strip_ref: bool = False) -> str:
+ return f"at::vec::Vectorized<{self.elem.cpp_type()}>"
+
+ def cpp_type_registration_declarations(self) -> str:
+ raise NotImplementedError
+
+ def remove_const_ref(self) -> "CType":
+ return self
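The `CType` combinators above compose structurally. A quick illustration of how they render C++ type strings, using only classes and singletons defined in `types.py`/`types_base.py` (expected output shown in the trailing comments):

```python
# Sketch: composing CTypes into C++ type strings.
from torchgen.api.types import (
    BaseCType, ConstRefCType, OptionalCType, VectorCType, ListCType,
    tensorT, longT, scalarT,
)

print(ConstRefCType(BaseCType(tensorT)).cpp_type())             # const at::Tensor &
print(OptionalCType(BaseCType(longT)).cpp_type())               # c10::optional<int64_t>
print(VectorCType(BaseCType(longT)).cpp_type())                 # ::std::vector<int64_t>
print(ListCType(OptionalCType(BaseCType(tensorT))).cpp_type())  # c10::List<c10::optional<at::Tensor>>
# remove_const_ref() peels the reference wrapper back off:
print(ConstRefCType(BaseCType(scalarT)).remove_const_ref().cpp_type())  # at::Scalar
```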
diff --git a/MLPY/Lib/site-packages/torchgen/api/types/types_base.py b/MLPY/Lib/site-packages/torchgen/api/types/types_base.py
new file mode 100644
index 0000000000000000000000000000000000000000..a53015f3a7f2778aad39394892f790b2cc7e2620
--- /dev/null
+++ b/MLPY/Lib/site-packages/torchgen/api/types/types_base.py
@@ -0,0 +1,270 @@
+"""
+Where should I add a new type? `types_base.py` vs `types.py`
+
+This file defines data model classes for the torchgen typing system, as well as some base types such as int32_t.
+
+`types.py` defines ATen Tensor type and some c10 types, along with signatures that use these types.
+
+The difference between these two files is that `types_base.py` should be implementation-agnostic, meaning it shouldn't
+contain any type definition that is tied to a specific C++ library (e.g., ATen), so that it can be easily reused
+if we want to generate code for another C++ library.
+
+Add new types to `types.py` if these types are ATen/c10 related.
+Add new types to `types_base.py` if they are basic and not attached to ATen/c10.
+"""
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from enum import auto, Enum
+from typing import List, Optional, Union
+
+from torchgen.model import Argument, SelfArgument, TensorOptionsArguments
+
+# An ArgName is just the str name of the argument in schema;
+# but in some special circumstances, we may add a little extra
+# context. The Enum SpecialArgName covers all of these cases;
+# grep for their construction sites to see when they can occur.
+
+
+class SpecialArgName(Enum):
+ possibly_redundant_memory_format = auto()
+
+
+ArgName = Union[str, SpecialArgName]
+
+
+# This class shouldn't be created directly; instead, use/create one of the singletons below.
+@dataclass(frozen=True)
+class BaseCppType:
+ ns: Optional[str]
+ name: str
+
+ def __str__(self) -> str:
+ if self.ns is None or self.ns == "":
+ return self.name
+ return f"{self.ns}::{self.name}"
+
+
+# The set of all non-templated, valid, fully-qualified names of C++ types that are used in the codegen.
+# Templated types get their own dataclass, mainly to make namespace parsing easier.
+byteT = BaseCppType("", "uint8_t")
+charT = BaseCppType("", "int8_t")
+shortT = BaseCppType("", "int16_t")
+# It would be more symmetric for this to be called intT, but it is easy to mix
+# this up with JIT int (which is int64_t in C++), so we intentionally don't
+# define intT to make it obvious when you've stuffed it up
+int32T = BaseCppType("", "int32_t")
+longT = BaseCppType("", "int64_t")
+doubleT = BaseCppType("", "double")
+floatT = BaseCppType("", "float")
+boolT = BaseCppType("", "bool")
+voidT = BaseCppType("", "void")
+
+
+class CType(ABC):
+ @abstractmethod
+ def cpp_type(self, *, strip_ref: bool = False) -> str:
+ raise NotImplementedError
+
+ @abstractmethod
+ def cpp_type_registration_declarations(self) -> str:
+ raise NotImplementedError
+
+ @abstractmethod
+ def remove_const_ref(self) -> "CType":
+ return self
+
+
+@dataclass(frozen=True)
+class BaseCType(CType):
+ type: BaseCppType
+
+ def cpp_type(self, *, strip_ref: bool = False) -> str:
+ return str(self.type)
+
+ # For BC reasons, we don't want to introduce at:: namespaces to RegistrationDeclarations.yaml
+ # TODO: Kill this when we eventually remove it!
+ def cpp_type_registration_declarations(self) -> str:
+ return str(self.type).replace("at::", "")
+
+ def remove_const_ref(self) -> "CType":
+ return self
+
+
+@dataclass(frozen=True)
+class ConstRefCType(CType):
+ elem: "CType"
+
+ def cpp_type(self, *, strip_ref: bool = False) -> str:
+ if strip_ref:
+ return self.elem.cpp_type(strip_ref=strip_ref)
+ return f"const {self.elem.cpp_type()} &"
+
+ def cpp_type_registration_declarations(self) -> str:
+ return f"const {self.elem.cpp_type_registration_declarations()} &"
+
+ def remove_const_ref(self) -> "CType":
+ return self.elem.remove_const_ref()
+
+
+@dataclass(frozen=True)
+class VectorCType(CType):
+ elem: "CType"
+
+ def cpp_type(self, *, strip_ref: bool = False) -> str:
+ # Do not pass `strip_ref` recursively.
+ return f"::std::vector<{self.elem.cpp_type()}>"
+
+ def cpp_type_registration_declarations(self) -> str:
+ return f"::std::vector<{self.elem.cpp_type_registration_declarations()}>"
+
+ def remove_const_ref(self) -> "CType":
+ return VectorCType(self.elem.remove_const_ref())
+
+
+@dataclass(frozen=True)
+class ArrayCType(CType):
+ elem: "CType"
+ size: int
+
+ def cpp_type(self, *, strip_ref: bool = False) -> str:
+ # Do not pass `strip_ref` recursively.
+ return f"::std::array<{self.elem.cpp_type()},{self.size}>"
+
+ def cpp_type_registration_declarations(self) -> str:
+ return f"::std::array<{self.elem.cpp_type_registration_declarations()},{self.size}>"
+
+ def remove_const_ref(self) -> "CType":
+ return ArrayCType(self.elem.remove_const_ref(), self.size)
+
+
+@dataclass(frozen=True)
+class TupleCType(CType):
+ elems: List["CType"]
+
+ def cpp_type(self, *, strip_ref: bool = False) -> str:
+ # Do not pass `strip_ref` recursively.
+ return f'::std::tuple<{",".join([e.cpp_type() for e in self.elems])}>'
+
+ def cpp_type_registration_declarations(self) -> str:
+ return f'::std::tuple<{",".join([e.cpp_type_registration_declarations() for e in self.elems])}>'
+
+ def remove_const_ref(self) -> "CType":
+ return TupleCType([e.remove_const_ref() for e in self.elems])
+
+
+@dataclass(frozen=True)
+class MutRefCType(CType):
+ elem: "CType"
+
+ def cpp_type(self, *, strip_ref: bool = False) -> str:
+ if strip_ref:
+ return self.elem.cpp_type(strip_ref=strip_ref)
+ return f"{self.elem.cpp_type()} &"
+
+ def cpp_type_registration_declarations(self) -> str:
+ return f"{self.elem.cpp_type_registration_declarations()} &"
+
+ def remove_const_ref(self) -> "CType":
+ return self.elem.remove_const_ref()
+
+
+# A NamedCType is short for Named C++ semantic type. A NamedCType represents a C++ type, plus
+# semantic information about what it represents. For example, consider the
+# argument "bool pin_memory"; its normal C++ type is "bool", but its C++
+# semantic type also keeps track that this represents a "pin_memory"; you can't
+# just use a random other boolean in a context where you need a "pin_memory"!
+#
+
+
+@dataclass(frozen=True)
+class NamedCType:
+ name: ArgName
+ type: CType
+
+ def cpp_type(self, *, strip_ref: bool = False) -> str:
+ return self.type.cpp_type(strip_ref=strip_ref)
+
+ # For BC reasons, we don't want to introduce at:: namespaces to RegistrationDeclarations.yaml
+ # TODO: Kill this when we eventually remove it!
+ def cpp_type_registration_declarations(self) -> str:
+ return self.type.cpp_type_registration_declarations()
+
+ def remove_const_ref(self) -> "NamedCType":
+ return NamedCType(self.name, self.type.remove_const_ref())
+
+ def with_name(self, name: str) -> "NamedCType":
+ return NamedCType(name, self.type)
+
+
+# A binding represents any C++ binding site for a formal parameter.
+# We don't distinguish between binding sites for different APIs;
+# instead, all of the important distinctions are encoded in CType,
+# which you can use to figure out if a given Binding is appropriate
+# for use in another context. (See torchgen.api.translate)
+
+
+@dataclass(frozen=True)
+class Binding:
+ name: str
+ nctype: NamedCType
+ argument: Union[Argument, TensorOptionsArguments, SelfArgument]
+ # TODO: maybe don't represent default here
+ default: Optional[str] = None
+
+ def rename(self, name: str) -> "Binding":
+ return Binding(
+ name=name,
+ nctype=self.nctype,
+ argument=self.argument,
+ default=self.default,
+ )
+
+ @property
+ def type(self) -> str:
+ return self.nctype.cpp_type()
+
+ def no_default(self) -> "Binding":
+ return Binding(
+ name=self.name,
+ nctype=self.nctype,
+ default=None,
+ argument=self.argument,
+ )
+
+ def decl(self, *, func_ptr_cast: bool = False) -> str:
+ mb_default = ""
+ if self.default is not None:
+ mb_default = f"={self.default}"
+
+ # casting only needs to know the type
+ if func_ptr_cast:
+ return f"{self.type}"
+ else:
+ return f"{self.type} {self.name}{mb_default}"
+
+ # For BC reasons, we don't want to introduce at:: namespaces to RegistrationDeclarations.yaml
+ # TODO: Kill this when we eventually remove it!
+ def decl_registration_declarations(self) -> str:
+ type_s = self.nctype.cpp_type_registration_declarations()
+ mb_default = ""
+ if self.default is not None:
+ mb_default = f"={self.default}"
+ return f"{type_s} {self.name}{mb_default}"
+
+ def defn(self) -> str:
+ return f"{self.type} {self.name}"
+
+ def with_name(self, name: str) -> "Binding":
+ return Binding(
+ name=name, nctype=self.nctype, argument=self.argument, default=self.default
+ )
+
+
+# An Expr is a C++ expression. It has a C++ string representing its syntax,
+# as well as a CType saying what it provides.
+
+
+@dataclass(frozen=True)
+class Expr:
+ expr: str
+ type: NamedCType
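A small sketch of how `NamedCType`, `Binding`, and `Expr` fit together: a `Binding` is a named formal parameter whose `decl()`/`defn()` render C++ parameter syntax, while an `Expr` pairs a C++ expression with the `NamedCType` it provides (this is what `translate()` consumes and produces). The schema string passed to `Argument.parse` is only an illustrative example.

```python
# Sketch: building a Binding by hand and rendering it.
from torchgen.api.types import BaseCType, Binding, ConstRefCType, Expr, NamedCType, tensorT
from torchgen.model import Argument

arg = Argument.parse("Tensor self")   # parse a schema-style argument
b = Binding(
    name="self",
    nctype=NamedCType("self", ConstRefCType(BaseCType(tensorT))),
    argument=arg,
)
print(b.decl())                        # const at::Tensor & self
print(b.defn())                        # const at::Tensor & self
print(Expr("self_", b.nctype).expr)    # self_  (an expression providing that NamedCType)
```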
diff --git a/MLPY/Lib/site-packages/torchgen/api/ufunc.py b/MLPY/Lib/site-packages/torchgen/api/ufunc.py
new file mode 100644
index 0000000000000000000000000000000000000000..01d8e9598ab5f74eafc40e86b8e0f917dd0b2e8c
--- /dev/null
+++ b/MLPY/Lib/site-packages/torchgen/api/ufunc.py
@@ -0,0 +1,209 @@
+from dataclasses import dataclass
+from typing import List, Optional
+
+import torchgen.api.types as api_types
+
+from torchgen.api import cpp, structured
+from torchgen.api.types import (
+ ArgName,
+ BaseCppType,
+ BaseCType,
+ Binding,
+ ConstRefCType,
+ CType,
+ NamedCType,
+ scalarT,
+)
+from torchgen.model import (
+ Argument,
+ BaseTy,
+ BaseType,
+ DispatchKey,
+ FunctionSchema,
+ NativeFunctionsGroup,
+ Type,
+)
+
+
+def schema_kernel_name(func: FunctionSchema, dispatch_key: DispatchKey) -> str:
+ assert func.is_out_fn(), "ufunc.kernel_name should only be invoked on out schemas"
+ return f"ufunc_{func.name.name}_{dispatch_key}"
+
+
+def kernel_name(g: NativeFunctionsGroup, dispatch_key: DispatchKey) -> str:
+ return schema_kernel_name(g.out.func, dispatch_key)
+
+
+# Tensors are omitted (as they are stored in TensorIterator), everything else is
+# passed along (technically, we can pass tensors along too, it just wastes
+# argument registers)
+#
+# NB: used for CPU only
+def dispatchstub_type(t: Type, *, binds: ArgName) -> Optional[NamedCType]:
+ # Dispatch stubs are always plain ints
+ r = cpp.valuetype_type(t, binds=binds, symint=False)
+ if r is not None:
+ return r
+
+ if t == BaseType(BaseTy.Scalar):
+ return NamedCType(binds, ConstRefCType(BaseCType(scalarT)))
+ elif t == BaseType(BaseTy.Tensor):
+ return None
+ else:
+ raise AssertionError(f"unrecognized type {repr(t)}")
+
+
+def opmath_type(scalar_t: BaseCppType) -> BaseCppType:
+ if scalar_t == api_types.scalar_t:
+ return api_types.opmath_t
+ raise NotImplementedError
+
+
+# NB: Tensors in constructor are stored in opmath_t, not scalar_t
+# because Tensor in constructor = it's a scalar tensor partially applied =
+# it can be higher precision and we want to compute in that higher precision
+#
+# NB: CUDA only
+def ufunctor_ctor_type(t: Type, *, binds: ArgName, scalar_t: BaseCppType) -> NamedCType:
+ r = cpp.valuetype_type(t, binds=binds, symint=False)
+ if r is not None:
+ return r
+
+ if t == BaseType(BaseTy.Scalar):
+ return NamedCType(binds, BaseCType(opmath_type(scalar_t)))
+ elif t == BaseType(BaseTy.Tensor):
+ return NamedCType(binds, BaseCType(opmath_type(scalar_t)))
+ else:
+ raise AssertionError(f"unrecognized type {repr(t)}")
+
+
+# Only Tensors ever get passed directly to operator()
+#
+# NB: CUDA only
+# (Actually, this works for CPU too)
+def ufunctor_apply_type(
+ t: Type, *, binds: ArgName, scalar_t: BaseCppType
+) -> NamedCType:
+ if t == BaseType(BaseTy.Tensor):
+ return NamedCType(binds, BaseCType(scalar_t))
+ else:
+ raise AssertionError(f"unrecognized type {repr(t)}")
+
+
+# The actual ufunc template function the user writes. Everything here
+# is done in the computation type. compute_t is opmath_t in CUDA and scalar_t
+# in CPU
+def ufunc_type(t: Type, *, binds: ArgName, compute_t: CType) -> NamedCType:
+ r = cpp.valuetype_type(t, binds=binds, symint=False)
+ if r is not None:
+ return r
+
+ if t == BaseType(BaseTy.Scalar):
+ return NamedCType(binds, compute_t)
+ elif t == BaseType(BaseTy.Tensor):
+ return NamedCType(binds, compute_t)
+ else:
+ raise AssertionError(f"unrecognized type {repr(t)}")
+
+
+def ufunctor_ctor_argument(a: Argument, scalar_t: BaseCppType) -> Binding:
+ return Binding(
+ nctype=ufunctor_ctor_type(a.type, binds=a.name, scalar_t=scalar_t),
+ name=a.name,
+ default=None,
+ argument=a,
+ )
+
+
+def ufunctor_apply_argument(a: Argument, scalar_t: BaseCppType) -> Binding:
+ return Binding(
+ nctype=ufunctor_apply_type(a.type, binds=a.name, scalar_t=scalar_t),
+ name=a.name,
+ default=None,
+ argument=a,
+ )
+
+
+def ufunc_argument(a: Argument, compute_t: CType) -> Binding:
+ return Binding(
+ nctype=ufunc_type(a.type, binds=a.name, compute_t=compute_t),
+ name=a.name,
+ default=None,
+ argument=a,
+ )
+
+
+@dataclass(frozen=True)
+class UfunctorBindings:
+ ctor: List[Binding]
+ apply: List[Binding]
+
+
+# ufunctors are a CUDA-only concept representing functors that take some of
+# their arguments on a host-side constructor, and the rest in the device-side
+# apply. E.g.,
+#
+# template <typename scalar_t>
+# struct CUDAFunctorOnSelf_add {
+# using opmath_t = at::opmath_type<scalar_t>;
+# opmath_t other_;
+# opmath_t alpha_;
+# CUDAFunctorOnSelf_add(opmath_t other, opmath_t alpha) : other_(other), alpha_(alpha) {}
+# __device__ scalar_t operator()(scalar_t self) {
+# return ufunc::add(static_cast<opmath_t>(self), other_, alpha_);
+# }
+# };
+#
+# The ctor refers to the constructor CUDAFunctorOnSelf_add, while apply refers
+# to the operator() definition
+def ufunctor_arguments(
+ g: NativeFunctionsGroup, *, scalar_tensor_idx: Optional[int], scalar_t: BaseCppType
+) -> UfunctorBindings:
+ ctor = []
+ apply = []
+ for a in g.functional.func.arguments.flat_non_out:
+ if a.type.is_tensor_like():
+ if scalar_tensor_idx == 0:
+ # put it in the ctor anyway
+ ctor.append(ufunctor_ctor_argument(a, scalar_t=scalar_t))
+ scalar_tensor_idx = None
+ else:
+ if scalar_tensor_idx is not None:
+ scalar_tensor_idx -= 1
+ apply.append(ufunctor_apply_argument(a, scalar_t=scalar_t))
+ else:
+ ctor.append(ufunctor_ctor_argument(a, scalar_t=scalar_t))
+ assert scalar_tensor_idx is None
+ return UfunctorBindings(ctor=ctor, apply=apply)
+
+
+# ufuncs are the inner loop template functions that you wrote in ufunc/add.h
+# which do the actual computation in question. E.g.,
+#
+# template <typename T>
+# C10_HOST_DEVICE T add(T self, T other, T alpha) __ubsan_ignore_undefined__ {
+# return self + alpha * other;
+# }
+#
+# In this file, we refer to T as compute_t which is bound by caller
+def ufunc_arguments(g: NativeFunctionsGroup, *, compute_t: CType) -> List[Binding]:
+ return [
+ ufunc_argument(a, compute_t=compute_t)
+ for a in g.functional.func.arguments.flat_non_out
+ ]
+
+
+# Stubs are the DispatchStub trampolines that CPU kernels use to get to their
+# vectorized versions. E.g.,
+#
+# using structured_binary_fn_alpha = void(*)(TensorIteratorBase&, const Scalar& alpha);
+# DECLARE_DISPATCH(structured_binary_fn_alpha, add_stub);
+def stub_arguments(g: NativeFunctionsGroup) -> List[Binding]:
+ # stubs drop all tensor arguments (they are implicit in the TensorIterator
+ # argument) and keep everything else
+ return [
+ r
+ for a in g.out.func.arguments.flat_non_out
+ if not a.type.is_tensor_like()
+ for r in structured.argument(a)
+ ]
diff --git a/MLPY/Lib/site-packages/torchgen/api/unboxing.py b/MLPY/Lib/site-packages/torchgen/api/unboxing.py
new file mode 100644
index 0000000000000000000000000000000000000000..60d671d024733ed05b69c0f30f043daadd904b11
--- /dev/null
+++ b/MLPY/Lib/site-packages/torchgen/api/unboxing.py
@@ -0,0 +1,248 @@
+from typing import List, Tuple
+
+from torchgen.api import cpp
+from torchgen.api.types import Binding, CppSignatureGroup, CType
+from torchgen.model import (
+ Argument,
+ BaseTy,
+ BaseType,
+ ListType,
+ NativeFunction,
+ OptionalType,
+ Type,
+)
+
+# This file generates the code for unboxing wrappers, i.e., the glue logic to unbox a boxed operator and convert the
+# ivalues from stack to correct arguments to the unboxed kernel, based on corresponding JIT schema. This codegen is
+# an alternative way to generate unboxing wrappers similar to the existing C++ metaprogramming approach but gets the
+# job done statically. These generated unboxing wrappers will be useful under the scenario where we need to register
+# a fixed set of operators known at compile time and thus can save some time in runtime initialization phase.
+#
+# Here's an example on how the codegen works:
+#
+# - Function Schema (source of truth)
+#
+# aten::empty.names(int[] size, *, Dimname[]? names,
+# ScalarType? dtype=None, Layout? layout=None,
+# Device? device=None, bool? pin_memory=None,
+# MemoryFormat? memory_format=None) -> Tensor
+# - Argument Conversion
+# Generates C++ code to convert an ivalue (from stack) to its underlying C++ type.
+# - int[] size
+# ```cpp
+# const c10::List<c10::IValue> size_list_in = (std::move(peek(stack, 0, 7))).toList();
+#
+# std::vector<int64_t> size_vec;
+# for (c10::IValue size_elem: size_list_in) {
+# int64_t size_base = size_elem.to<int64_t>();
+# size_vec.push_back(size_base);
+# }
+# at::ArrayRef<int64_t> size_list_out(size_vec);
+# ~~~~~~~~~~~~~ <-- The converted argument from ivalues in the stack.
+# Will be passed to unboxed kernel.
+# ```
+# - Dimname[]? names
+# ```cpp
+# c10::optional<c10::IValue> names_opt = (std::move(peek(stack, 1, 7))).toOptional<c10::IValue>();
+# c10::optional<at::ArrayRef<at::Dimname>> names_opt_out;
+# if (names_opt.has_value()) {
+# ~~~~~~~~~~~ <-- Unwrapping optional shell
+# const c10::IValue names_opt_in = names_opt.value();
+# const c10::List<c10::IValue> names_list_in = names_opt_in.toList();
+#
+# std::vector<at::Dimname> names_vec;
+# for (c10::IValue names_elem: names_list_in) {
+# ~~~~~~~~~~~~~~~~~~~~~~~~~ <-- Unrolling list, then convert elements one by one.
+# at::Dimname names_base = names_elem.to<at::Dimname>();
+# names_vec.push_back(names_base);
+# }
+# at::ArrayRef<at::Dimname> names_list_out(names_vec);
+#
+# names_opt_out = c10::optional<at::ArrayRef<at::Dimname>>(names_list_out);
+# } else {
+# names_opt_out = c10::optional<at::ArrayRef<at::Dimname>>();
+# }
+# ```
+# - ScalarType? dtype (similarly for the rest of the arguments)
+# ```cpp
+# c10::optional<c10::IValue> dtype_opt = (std::move(peek(stack, 2, 7))).toOptional<c10::IValue>();
+# c10::optional<at::ScalarType> dtype_opt_out;
+# if (dtype_opt.has_value()) {
+# const c10::IValue dtype_opt_in = dtype_opt.value();
+# at::ScalarType dtype_base = dtype_opt_in.to<at::ScalarType>();
+# ~~~~~~~~~~~~~~~~~~~~ <-- For base types, convert ivalue to it
+# directly using ".to<T>()" API.
+# dtype_opt_out = c10::optional<at::ScalarType>(dtype_base);
+# } else {
+# dtype_opt_out = c10::optional<at::ScalarType>();
+# }
+# ```
+#
+# - Unboxed Kernel Call
+# ```cpp
+# auto result_ = torch::empty(
+# size_list_out,
+# names_opt_out,
+# options,
+# memory_format_opt_out
+# );
+# ```
+#
+# - Push Result Back to Stack
+# ```cpp
+# drop(stack, 7);
+# pack(stack, std::move(result_));
+# ```
+connector = "\n\t"
+
+
+# Return unboxing function name for a NativeFunction
+def name(f: NativeFunction) -> str:
+ return f.func.name.unambiguous_name()
+
+
+# Convert all the arguments in a NativeFunction to C++ code
+def convert_arguments(f: NativeFunction) -> Tuple[List[Binding], List[str]]:
+ # we need the 'self' argument so method needs to be False
+ args = (
+ CppSignatureGroup.from_native_function(f, method=False)
+ .most_faithful_signature()
+ .arguments()
+ )
+ code_list = [
+ f"c10::IValue {args[i].name} = std::move(peek(stack, {i}, {len(args)}));"
+ for i in range(len(args))
+ ] + [""]
+ binding_list = []
+ for arg in args:
+ # expecting only Argument
+ if not isinstance(arg.argument, Argument):
+ raise Exception(
+ f"Unexpected argument type, expecting `Argument` but got {arg}"
+ )
+ argument: Argument = arg.argument
+ unboxed_name, _, code, decl = argumenttype_ivalue_convert(
+ argument.type,
+ argument.name,
+ mutable=argument.is_write,
+ )
+ code_list.extend(decl)
+ code_list.extend(code)
+ binding_list.append(arg.with_name(unboxed_name))
+ return binding_list, code_list
+
+
+# Takes in the type, name and mutability corresponding to an argument, and generates a tuple of:
+# (1) the C++ code necessary to unbox the argument
+# (2) A Binding corresponding to the newly created unboxed variable, including variable name and its CType
+def argumenttype_ivalue_convert(
+ t: Type, arg_name: str, *, mutable: bool = False
+) -> Tuple[str, CType, List[str], List[str]]:
+ # Unboxing is for mobile, which doesn't care about SymInts
+ ctype = cpp.argumenttype_type(
+ t=t, mutable=mutable, binds=arg_name, symint=False
+ ).type
+
+ if isinstance(t, BaseType):
+ out_name = f"{arg_name}_base"
+ code, decl = _gen_code_base_type(
+ arg_name=arg_name, out_name=out_name, ctype=ctype
+ )
+ elif isinstance(t, OptionalType):
+ out_name = f"{arg_name}_opt_out"
+ code, decl = _gen_code_optional_type(
+ arg_name=arg_name,
+ out_name=out_name,
+ t=t,
+ ctype=ctype,
+ )
+ elif isinstance(t, ListType):
+ out_name = f"{arg_name}_list_out"
+ code, decl = _gen_code_list_type(
+ arg_name=arg_name,
+ out_name=out_name,
+ t=t,
+ ctype=ctype,
+ )
+ else:
+ raise Exception(f"Cannot handle type {t}. arg_name: {arg_name}")
+ return out_name, ctype, code, decl
+
+
+def _gen_code_base_type(
+ arg_name: str, out_name: str, ctype: CType
+) -> Tuple[List[str], List[str]]:
+ return [
+ f"{ctype.cpp_type(strip_ref=True)} {out_name} = {arg_name}.to<{ctype.cpp_type(strip_ref=True)}>();"
+ ], []
+
+
+def _gen_code_optional_type(
+ arg_name: str, out_name: str, t: OptionalType, ctype: CType
+) -> Tuple[List[str], List[str]]:
+ in_name = f"{arg_name}_opt_in"
+ res_name, _, res_code, decl = argumenttype_ivalue_convert(t.elem, in_name)
+ return (
+ f"""
+c10::optional<c10::IValue> {arg_name}_opt = {arg_name}.toOptional<c10::IValue>();
+{ctype.cpp_type(strip_ref=True)} {out_name};
+if ({arg_name}_opt.has_value()) {{
+ const c10::IValue {in_name} = {arg_name}_opt.value();
+ {connector.join(res_code)}
+ {out_name} = {ctype.cpp_type(strip_ref=True)}({res_name});
+}} else {{
+ {out_name} = {ctype.cpp_type(strip_ref=True)}();
+}}
+ """.split(
+ "\n"
+ ),
+ decl,
+ )
+
+
+def _gen_code_list_type(
+ arg_name: str, out_name: str, t: ListType, ctype: CType
+) -> Tuple[List[str], List[str]]:
+ in_name = f"{arg_name}_list_in"
+ elem_name = f"{arg_name}_elem"
+ code = [f"const c10::List {in_name} = {arg_name}.toList();"]
+ res_name, res_ctype, res_code, decl = argumenttype_ivalue_convert(t.elem, elem_name)
+ # handle list type with size, e.g., bool[4]
+ if isinstance(t.elem, BaseType) and t.elem.name == BaseTy.bool and t.size:
+ code.extend(
+ f"""
+{ctype.cpp_type(strip_ref=True)} {out_name} = as_array<{res_ctype.cpp_type(strip_ref=True)}, {t.size}>({in_name});
+ """.split(
+ "\n"
+ )
+ )
+ # we have to use c10::List for optional elements, e.g., Tensor?[] -> c10::List<c10::optional<at::Tensor>>
+ elif isinstance(t.elem, OptionalType):
+ code.extend(
+ f"""
+{ctype.cpp_type(strip_ref=True)} {out_name};
+for (c10::IValue {elem_name}: {in_name}) {{
+ {connector.join(res_code)}
+ {out_name}.push_back({res_name});
+}}
+ """.split(
+ "\n"
+ )
+ )
+ else:
+ # use ArrayRef as default.
+ vec_name = arg_name + "_vec"
+ # need to bring vector instantiation out of scope so that ArrayRef has valid data
+ decl.append(f"std::vector<{res_ctype.cpp_type(strip_ref=True)}> {vec_name};")
+ code.extend(
+ f"""
+for (c10::IValue {elem_name}: {in_name}) {{
+ {connector.join(res_code)}
+ {vec_name}.push_back({res_name});
+}}
+{ctype.cpp_type(strip_ref=True)} {out_name}({vec_name});
+ """.split(
+ "\n"
+ )
+ )
+ return code, decl
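A hedged sketch of how the helpers above might be driven for a single `NativeFunction`. The `at::` call at the end is a simplification for illustration; the real generated wrappers also drop the consumed stack entries and pack the result back, as described in the module comment.

```python
# Sketch only: stitch together the unboxing snippets for one NativeFunction `f`.
from torchgen.api.unboxing import convert_arguments, name as unboxing_name

def unboxing_snippet(f) -> str:  # f: NativeFunction
    bindings, code = convert_arguments(f)        # per-argument ivalue -> C++ conversion code
    body = "\n".join(code)
    arg_list = ", ".join(b.name for b in bindings)
    # Simplified kernel call; real codegen also handles drop()/pack() on the stack.
    return f"// {unboxing_name(f)}\n{body}\nauto result_ = at::{f.func.name.name}({arg_list});"
```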
diff --git a/MLPY/Lib/site-packages/torchgen/code_template.py b/MLPY/Lib/site-packages/torchgen/code_template.py
new file mode 100644
index 0000000000000000000000000000000000000000..01784303507057a13c95d2853ca84744cd9e237a
--- /dev/null
+++ b/MLPY/Lib/site-packages/torchgen/code_template.py
@@ -0,0 +1,96 @@
+import re
+from typing import Mapping, Match, Optional, Sequence
+
+# Match $identifier or ${identifier} and replace it with the value in env.
+# If the identifier is at the beginning of a line (preceded only by whitespace)
+# and its value is a list, it is treated as a block substitution: each element
+# of the list is put on its own line, indented to that depth.
+# If the identifier is on a line starting with non-whitespace and its value is
+# a list, the elements are comma separated: ${,foo} inserts a comma before the
+# list if the list is not empty, and ${foo,} inserts one after.
+
+
+class CodeTemplate:
+ substitution_str = r"(^[^\n\S]*)?\$([^\d\W]\w*|\{,?[^\d\W]\w*\,?})"
+ substitution = re.compile(substitution_str, re.MULTILINE)
+
+ pattern: str
+ filename: str
+
+ @staticmethod
+ def from_file(filename: str) -> "CodeTemplate":
+ with open(filename) as f:
+ return CodeTemplate(f.read(), filename)
+
+ def __init__(self, pattern: str, filename: str = "") -> None:
+ self.pattern = pattern
+ self.filename = filename
+
+ def substitute(
+ self, env: Optional[Mapping[str, object]] = None, **kwargs: object
+ ) -> str:
+ if env is None:
+ env = {}
+
+ def lookup(v: str) -> object:
+ assert env is not None
+ return kwargs[v] if v in kwargs else env[v]
+
+ def indent_lines(indent: str, v: Sequence[object]) -> str:
+ return "".join(
+ [indent + l + "\n" for e in v for l in str(e).splitlines()]
+ ).rstrip()
+
+ def replace(match: Match[str]) -> str:
+ indent = match.group(1)
+ key = match.group(2)
+ comma_before = ""
+ comma_after = ""
+ if key[0] == "{":
+ key = key[1:-1]
+ if key[0] == ",":
+ comma_before = ", "
+ key = key[1:]
+ if key[-1] == ",":
+ comma_after = ", "
+ key = key[:-1]
+ v = lookup(key)
+ if indent is not None:
+ if not isinstance(v, list):
+ v = [v]
+ return indent_lines(indent, v)
+ elif isinstance(v, list):
+ middle = ", ".join([str(x) for x in v])
+ if len(v) == 0:
+ return middle
+ return comma_before + middle + comma_after
+ else:
+ return str(v)
+
+ return self.substitution.sub(replace, self.pattern)
+
+
+if __name__ == "__main__":
+ c = CodeTemplate(
+ """\
+ int foo($args) {
+
+ $bar
+ $bar
+ $a+$b
+ }
+ int commatest(int a${,stuff})
+ int notest(int a${,empty,})
+ """
+ )
+ print(
+ c.substitute(
+ args=["hi", 8],
+ bar=["what", 7],
+ a=3,
+ b=4,
+ stuff=["things...", "others"],
+ empty=[],
+ )
+ )
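The block-substitution behavior described at the top of the file can also be seen with a smaller template; this is just an extra illustration alongside the `__main__` example above, with the expected output in the trailing comment.

```python
# Sketch: block vs. inline substitution with CodeTemplate.
from torchgen.code_template import CodeTemplate

t = CodeTemplate("void $name() {\n    $body\n}\n")
print(t.substitute(name="f", body=["int x = 0;", "return;"]))
# void f() {
#     int x = 0;
#     return;
# }
```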
diff --git a/MLPY/Lib/site-packages/torchgen/context.py b/MLPY/Lib/site-packages/torchgen/context.py
new file mode 100644
index 0000000000000000000000000000000000000000..9e3b4772b5a4b89996c5b66c18a6d543cd8955ef
--- /dev/null
+++ b/MLPY/Lib/site-packages/torchgen/context.py
@@ -0,0 +1,128 @@
+import contextlib
+
+import functools
+from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, TypeVar, Union
+
+import torchgen.local as local
+from torchgen.model import (
+ BackendIndex,
+ DispatchKey,
+ NativeFunction,
+ NativeFunctionsGroup,
+ NativeFunctionsViewGroup,
+)
+from torchgen.utils import context, S, T
+
+# Helper functions for defining generators on things in the model
+
+F = TypeVar(
+ "F",
+ NativeFunction,
+ NativeFunctionsGroup,
+ NativeFunctionsViewGroup,
+ Union[NativeFunction, NativeFunctionsGroup],
+ Union[NativeFunction, NativeFunctionsViewGroup],
+)
+
+F2 = TypeVar(
+ "F2",
+ NativeFunction,
+ NativeFunctionsGroup,
+ Optional[NativeFunction],
+ bool,
+ str,
+)
+
+F3 = TypeVar("F3", Tuple[NativeFunction, Any], List[NativeFunction])
+
+
+@contextlib.contextmanager
+def native_function_manager(
+ g: Union[NativeFunctionsGroup, NativeFunctionsViewGroup, NativeFunction]
+) -> Iterator[None]:
+ if isinstance(g, NativeFunctionsGroup):
+ # By default, we associate all errors with structured native functions
+ # with the out variant. In some cases, it might be better to have
+ # a more specific place to hang things; if so, use
+ # native_function_manager again on the inside
+ f = g.out
+ elif isinstance(g, NativeFunctionsViewGroup):
+ # We associate errors with the view operator
+ f = g.view
+ else:
+ f = g
+ with context(lambda: f"in native_functions.yaml line {f.loc}:\n {f.func}"):
+ with local.parametrize(
+ use_const_ref_for_mutable_tensors=f.use_const_ref_for_mutable_tensors,
+ use_ilistref_for_tensor_lists=f.part_of_structured_group,
+ ):
+ yield
+
+
+# Given a function that operates on NativeFunction, wrap it into a new function
+# that sets some appropriate context managers for that native function.
+# YOU MUST WRAP FUNCTIONS IN THIS for calls to api modules to be sound
+# (you will get an error if we try to access the local variables without having
+# set them).
+def with_native_function(func: Callable[[F], T]) -> Callable[[F], T]:
+ @functools.wraps(func)
+ def wrapper(f: F) -> T:
+ with native_function_manager(f):
+ return func(f)
+
+ return wrapper
+
+
+def with_native_function_and(func: Callable[[F, F2], T]) -> Callable[[F, F2], T]:
+ @functools.wraps(func)
+ def wrapper(f: F, f2: F2) -> T:
+ # The first native_function is assumed to be the one with the appropriate context.
+ with native_function_manager(f):
+ return func(f, f2)
+
+ return wrapper
+
+
+def method_with_native_function(func: Callable[[S, F], T]) -> Callable[[S, F], T]:
+ @functools.wraps(func)
+ def wrapper(slf: S, f: F) -> T:
+ with native_function_manager(f):
+ return func(slf, f)
+
+ return wrapper
+
+
+def method_with_nested_native_function(
+ func: Callable[[S, F3], T]
+) -> Callable[[S, F3], T]:
+ @functools.wraps(func)
+ def wrapper(slf: S, f: F3) -> T:
+ with native_function_manager(f[0]):
+ return func(slf, f)
+
+ return wrapper
+
+
+# Convenience decorator for functions that explicitly take in a BackendIndex,
+# instead of indirectly taking one in as a closure
+def with_native_function_and_index(
+ func: Callable[[F, BackendIndex], T]
+) -> Callable[[F, BackendIndex], T]:
+ @functools.wraps(func)
+ def wrapper(f: F, backend_index: BackendIndex) -> T:
+ with native_function_manager(f):
+ return func(f, backend_index)
+
+ return wrapper
+
+
+# Convenience decorator for functions that explicitly take in a Dict of BackendIndices
+def with_native_function_and_indices(
+ func: Callable[[F, Dict[DispatchKey, BackendIndex]], T]
+) -> Callable[[F, Dict[DispatchKey, BackendIndex]], T]:
+ @functools.wraps(func)
+ def wrapper(f: F, backend_indices: Dict[DispatchKey, BackendIndex]) -> T:
+ with native_function_manager(f):
+ return func(f, backend_indices)
+
+ return wrapper
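A sketch of how the decorators above are typically used: wrapping a per-function codegen routine so that any error is attributed to the right line of native_functions.yaml and the `torchgen.local` settings are in effect while torchgen.api modules run. The body is illustrative and relies only on APIs shown earlier in this diff.

```python
# Sketch: a per-NativeFunction codegen routine wrapped with with_native_function.
from torchgen.context import with_native_function
from torchgen.model import NativeFunction
from torchgen.api.types import CppSignatureGroup

@with_native_function
def compute_decl(f: NativeFunction) -> str:
    # Errors raised in here are reported with f's location in native_functions.yaml.
    sig = CppSignatureGroup.from_native_function(f, method=False).most_faithful_signature()
    return f"TORCH_API {sig.decl()};"
```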
diff --git a/MLPY/Lib/site-packages/torchgen/dest/__init__.py b/MLPY/Lib/site-packages/torchgen/dest/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..958b5a29017d9886efad778a413884248347ea7c
--- /dev/null
+++ b/MLPY/Lib/site-packages/torchgen/dest/__init__.py
@@ -0,0 +1,19 @@
+from .lazy_ir import (
+ generate_non_native_lazy_ir_nodes as generate_non_native_lazy_ir_nodes,
+ GenLazyIR as GenLazyIR,
+ GenLazyNativeFuncDefinition as GenLazyNativeFuncDefinition,
+ GenLazyShapeInferenceDefinition as GenLazyShapeInferenceDefinition,
+)
+from .native_functions import (
+ compute_native_function_declaration as compute_native_function_declaration,
+)
+from .register_dispatch_key import (
+ gen_registration_headers as gen_registration_headers,
+ gen_registration_helpers as gen_registration_helpers,
+ RegisterDispatchKey as RegisterDispatchKey,
+)
+from .ufunc import (
+ compute_ufunc_cpu as compute_ufunc_cpu,
+ compute_ufunc_cpu_kernel as compute_ufunc_cpu_kernel,
+ compute_ufunc_cuda as compute_ufunc_cuda,
+)
diff --git a/MLPY/Lib/site-packages/torchgen/dest/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torchgen/dest/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9b5e0ce048ef8216fadba086cf42fb5781720c15
Binary files /dev/null and b/MLPY/Lib/site-packages/torchgen/dest/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torchgen/dest/__pycache__/lazy_ir.cpython-39.pyc b/MLPY/Lib/site-packages/torchgen/dest/__pycache__/lazy_ir.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..17d07e8e3eb35d147d0b9908ace144a7f2b97297
Binary files /dev/null and b/MLPY/Lib/site-packages/torchgen/dest/__pycache__/lazy_ir.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torchgen/dest/__pycache__/lazy_ts_lowering.cpython-39.pyc b/MLPY/Lib/site-packages/torchgen/dest/__pycache__/lazy_ts_lowering.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1afe09868023ab018f1cf3176db22837b5cebdab
Binary files /dev/null and b/MLPY/Lib/site-packages/torchgen/dest/__pycache__/lazy_ts_lowering.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torchgen/dest/__pycache__/native_functions.cpython-39.pyc b/MLPY/Lib/site-packages/torchgen/dest/__pycache__/native_functions.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7eb3d1c9dcce7e476952d40f3b9cb8435054a229
Binary files /dev/null and b/MLPY/Lib/site-packages/torchgen/dest/__pycache__/native_functions.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torchgen/dest/__pycache__/register_dispatch_key.cpython-39.pyc b/MLPY/Lib/site-packages/torchgen/dest/__pycache__/register_dispatch_key.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..735b617f81556f6e2760451c9f240fb177ec409d
Binary files /dev/null and b/MLPY/Lib/site-packages/torchgen/dest/__pycache__/register_dispatch_key.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torchgen/dest/__pycache__/ufunc.cpython-39.pyc b/MLPY/Lib/site-packages/torchgen/dest/__pycache__/ufunc.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7f392de98f02c4fe035a540cab53e602d394b780
Binary files /dev/null and b/MLPY/Lib/site-packages/torchgen/dest/__pycache__/ufunc.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torchgen/dest/lazy_ir.py b/MLPY/Lib/site-packages/torchgen/dest/lazy_ir.py
new file mode 100644
index 0000000000000000000000000000000000000000..84a00001e5e5cd415bf835ef07e3621624c9d6ae
--- /dev/null
+++ b/MLPY/Lib/site-packages/torchgen/dest/lazy_ir.py
@@ -0,0 +1,707 @@
+import itertools
+from abc import ABC
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torchgen.api.dispatcher as dispatcher
+from torchgen.api.lazy import (
+ getValueT,
+ isValueType,
+ LazyArgument,
+ LazyIrProperties,
+ LazyIrSchema,
+ tensorListValueT,
+)
+from torchgen.api.translate import translate
+from torchgen.api.types import (
+ BaseCType,
+ Binding,
+ deviceT,
+ DispatcherSignature,
+ kernel_signature,
+ NativeSignature,
+ OptionalCType,
+ VectorCType,
+)
+from torchgen.context import method_with_native_function
+from torchgen.dest.lazy_ts_lowering import ts_lowering_body
+from torchgen.model import (
+ Argument,
+ BackendIndex,
+ BackendMetadata,
+ BaseTy,
+ BaseType,
+ FunctionSchema,
+ ListType,
+ NativeFunction,
+ NativeFunctionsGroup,
+)
+
+
+def node_ctor_arg_rvalue_string(arg: LazyArgument) -> str:
+ """
+ Given a LazyArgument,
+ generate a c++ string for materializing an rvalue of that arg for passing into
+ a lazy Node constructor.
+ """
+
+ # TODO: Matching on CType seems wrong; should be matching on Type
+ if isValueType(arg.lazy_type):
+ if isinstance(arg.lazy_type, BaseCType):
+ if arg.is_wrapped_scalar:
+ return f"node_{arg.name}"
+ elif arg.lazy_type.type is tensorListValueT:
+ return f"lazy_{arg.name}_tensorlist"
+ elif arg.is_symint_or_list:
+ return f"GetSymIntValue({arg.name})"
+ return f"lazy_{arg.name}->GetIrValue()"
+ elif isinstance(arg.lazy_type, OptionalCType):
+ if arg.is_symint_or_list:
+ # TODO: I don't understand when you should put lazy_ in the name
+ # or not
+ return f"{arg.name} ? c10::make_optional(GetSymIntValue(*{arg.name})) : c10::nullopt"
+ elif arg.is_wrapped_scalar:
+ return f"node_{arg.name}"
+ return (
+ f"lazy_{arg.name} ? "
+ f"c10::make_optional(lazy_{arg.name}->GetIrValue()) : "
+ "c10::nullopt"
+ )
+ else:
+ raise AssertionError(
+ f"TODO not sure if there are other valid types to handle here ({arg.lazy_type})"
+ )
+ else:
+ # NB: this is here because right now we aren't treating SymInt[] as a
+ # value type; when we do this needs to move above
+ # NB: we cannot test arg.lazy_type as we've already specified it is an
+ # int64_t and so we cannot distinguish between SymInt and int64_t
+ if isinstance(arg.orig_type, ListType) and arg.orig_type.elem == BaseType(
+ BaseTy.SymInt
+ ):
+ if arg.symint:
+ return f"GetSymIntArrayRefValue({arg.name})"
+ else:
+                return f"std::vector<int64_t>({arg.name}.begin(), {arg.name}.end())"
+ elif isinstance(arg.lazy_type, VectorCType) and isinstance(
+ arg.lazy_type.elem, BaseCType
+ ):
+ return f"std::vector<{arg.lazy_type.elem.type}>({arg.name}.begin(), {arg.name}.end())"
+ elif (
+ isinstance(arg.lazy_type, OptionalCType)
+ and isinstance(arg.lazy_type.elem, VectorCType)
+ and isinstance(arg.lazy_type.elem.elem, BaseCType)
+ ):
+ return f"torch::lazy::ToOptionalVector<{arg.lazy_type.elem.elem.type}>({arg.name})"
+ else:
+ return f"{arg.name}"
+
+
+def node_ctor_inputs(schema: LazyIrSchema) -> str:
+ """
+ Produce a formatted string with the arguments as passed into the constructor of a node class.
+ """
+ node_ctor_values = [
+ node_ctor_arg_rvalue_string(arg) for arg in schema.filtered_args()
+ ]
+ return ", ".join(node_ctor_values)
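+# For a schema such as "add(Tensor self, Tensor other, Scalar alpha)" this yields
+# roughly "lazy_self->GetIrValue(), lazy_other->GetIrValue(), alpha" (illustrative
+# only; the exact rvalues come from node_ctor_arg_rvalue_string above).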
+
+
+def gen_fallback_code(
+ schema: LazyIrSchema,
+ sig: Union[DispatcherSignature, NativeSignature],
+ overload_name: str,
+) -> str:
+ """
+ Generate code that falls back to eager conditioned on a predicate
+ """
+ dispatcher_sig = DispatcherSignature.from_schema(schema.func)
+ exprs = translate(sig.arguments(), dispatcher_sig.arguments())
+ fallback_args = ",\n ".join([a.expr for a in exprs])
+ if len(overload_name):
+ aten_op_str = f"ATEN_OP2({schema.aten_name}, {overload_name})"
+ else:
+ aten_op_str = f"ATEN_OP({schema.aten_name})"
+ return f"""
+ if (force_eager_fallback({aten_symbol(schema)})) {{
+            return at::native::call_fallback_fn_symint<&ltc_eager_fallback, {aten_op_str}>::call(
+ {fallback_args}
+ );
+ }}
+"""
+
+
+def aten_symbol(schema: LazyIrSchema) -> str:
+ missing_interned_strings = {
+ "sigmoid_backward",
+ }
+ if schema.aten_name in missing_interned_strings:
+ return f'c10::Symbol::fromQualString("aten::{schema.aten_name}")'
+
+ if not schema.aten_name.startswith("at::"):
+ return f"at::aten::{schema.aten_name}"
+ else:
+ return schema.aten_name
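+# For instance, an op named "add" maps to "at::aten::add", while names listed in
+# missing_interned_strings fall back to a runtime c10::Symbol::fromQualString lookup.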
+
+
+# converts all tensor-like arguments to meta tensors. Returns:
+# (1) a string containing all of the logic that does the conversions.
+# (2) a context, to be used by translate(), with all of the relevant bindings.
+def convert_to_meta_tensors(sig: DispatcherSignature) -> Tuple[str, List[Binding]]:
+ context: List[Binding] = []
+ unwrapped_tensor_args: List[str] = []
+ for arg in sig.arguments():
+ if isinstance(arg.argument, Argument) and arg.argument.type.is_tensor_like():
+ unwrapped_name = f"{arg.name}_meta"
+ unwrapped_tensor_args.append(
+ f"auto {unwrapped_name} = to_meta({arg.name});"
+ )
+ context.append(arg.with_name(unwrapped_name))
+ else:
+ context.append(arg)
+ unwrap_tensor_args_str = "\n ".join(unwrapped_tensor_args)
+ return unwrap_tensor_args_str, context
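+# Sketch of the generated conversion for a tensor argument named "self" (assuming a
+# to_meta() helper is in scope at the call site): "auto self_meta = to_meta(self);",
+# with the returned context rebinding "self" to "self_meta" for translate().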
+
+
+@dataclass(frozen=True)
+class GenLazyIR(ABC):
+ backend_index: BackendIndex
+ backend_name: str
+ node_base: str
+ use_lazy_shape: bool
+
+ @method_with_native_function
+ def __call__(self, f: Union[NativeFunctionsGroup, NativeFunction]) -> List[str]:
+ func = f.functional.func if isinstance(f, NativeFunctionsGroup) else f.func
+ metadata = self.backend_index.get_kernel(
+ f.functional if isinstance(f, NativeFunctionsGroup) else f
+ )
+ schema = LazyIrSchema(
+ func, symint=metadata is not None and metadata.supports_symint()
+ )
+ return self.gen(schema)
+
+ # there is no lowering functionality generated unless this IR base class is subclassed and
+ # implemented as a backend-specific node
+ def lowering_function(self, schema: LazyIrSchema) -> str:
+ return ""
+
+ def create_function(self, schema: LazyIrSchema, node_ctor_args: str) -> str:
+ return ""
+
+ def can_be_reused_function(self, schema: LazyIrSchema, node_ctor_args: str) -> str:
+ return f"""bool CanBeReused({node_ctor_args}) const {{
+ return false;
+ }}"""
+
+ def node_base_ctor_call(self, schema: LazyIrSchema) -> str:
+ value_args = schema.filtered_args(values=True, scalars=False)
+ # backends can customize the way the node base class constructor is called,
+ # as long as all of its arguments can be generated from information available from the schema
+ base_ctor_value_args_list = []
+ for arg in value_args:
+ if isinstance(arg.lazy_type, (BaseCType, VectorCType)):
+ base_ctor_value_args_list.append(f"{arg.name}")
+ elif isinstance(arg.lazy_type, OptionalCType):
+ base_ctor_value_args_list.append(f"{arg.name}.value_or(kNullValue)")
+ else:
+ raise AssertionError(
+ f"Unsupported type ({arg.lazy_type}) - add support if necessary"
+ )
+ base_ctor_value_args = ", ".join(base_ctor_value_args_list)
+
+ scalar_args = schema.filtered_args(values=False, scalars=True)
+
+ # Shape construction.
+ # Conditionally build shape depending on specified shape property
+ if schema.properties.ShapePrecompute:
+ shape_ctor_arg = "std::move(shapes),"
+ elif schema.properties.ShapeCompute:
+ shape_args = [a.name for a in value_args]
+ shape_args.extend(a.name for a in scalar_args)
+ shape_ctor_arg = f"compute_shape_{schema.name}({', '.join(shape_args)}),"
+ elif schema.properties.ShapeCache:
+ shape_args = [f"operand({i})" for i in range(len(value_args))]
+ shape_args.extend(a.name for a in scalar_args)
+ shape_ctor_arg = f"[&](){{ return compute_shape_{schema.name}({', '.join(shape_args)})[0]; }},"
+ else:
+ shape_ctor_arg = ""
+
+ scalar_hashes = ", ".join(f"{a.name}" for a in scalar_args)
+
+ return f"""{self.node_base}(
+ {schema.node_name}::ClassOpKind(),
+ OpList{{{base_ctor_value_args}}},
+ {shape_ctor_arg}
+ /* num_outputs */ {len(schema.returns)},
+ torch::lazy::MHash({scalar_hashes}))"""
+
+ def gen(self, schema: LazyIrSchema) -> List[str]:
+ opkind = schema.opkind or aten_symbol(schema)
+
+ # for now, we just want one IR class decl and soon after also the method defs
+ # and we use the functional version not out/inplace.
+ all_args = schema.filtered_args()
+ value_args = schema.filtered_args(values=True, scalars=False)
+ scalar_args = schema.filtered_args(values=False, scalars=True)
+
+ ctor_args = [f"const {i.lazy_type.cpp_type()}& {i.name}" for i in all_args]
+ reuse_ctor_args = ", ".join(ctor_args)
+ if self.use_lazy_shape and schema.properties.ShapePrecompute:
+            ctor_args.append("std::vector<torch::lazy::Shape>&& shapes")
+ node_ctor_args = ", ".join(ctor_args)
+
+ scalar_initializers = ",\n ".join(
+ [
+ # This code is just special casing the mapping from string_view -> strings
+ f"{a.name}({a.name}.has_value() ? c10::make_optional(std::string(*{a.name})) : c10::nullopt)"
+                if a.lazy_type.cpp_type() == "c10::optional<c10::string_view>"
+ else f"{a.name}({a.name})"
+ for a in scalar_args
+ ]
+ )
+ if len(scalar_initializers):
+ scalar_initializers = f",\n {scalar_initializers}"
+ scalar_decls = "\n ".join(
+ [
+ f"std::string {a.name};"
+ if a.lazy_type.cpp_type() == "c10::string_view"
+                else f"c10::optional<std::string> {a.name};"
+                if a.lazy_type.cpp_type() == "c10::optional<c10::string_view>"
+ else f"{a.lazy_type.cpp_type()} {a.name};"
+ for a in scalar_args
+ ]
+ )
+ optional_values = [
+ arg.name
+ for arg in schema.filtered_args(values=True, scalars=False)
+ if isinstance(arg.lazy_type, OptionalCType)
+ ]
+ has_optional_decls = "\n ".join(
+ [f"bool has_{value}: 1;" for value in optional_values]
+ )
+ has_optional_defs = "\n ".join(
+ [f"has_{value} = !!{value};" for value in optional_values]
+ )
+ members_to_string = []
+ for arg in scalar_args:
+ if isinstance(arg.lazy_type, OptionalCType):
+ value = f"{arg.name}.value()"
+ if arg.is_generator:
+ value = '"torch.Generator()"'
+ members_to_string.append(
+ f"""if ({arg.name}.has_value()) {{
+ ss << ", {arg.name}=" << {value};
+ }} else {{
+ ss << ", {arg.name}=null";
+ }}"""
+ )
+ else:
+ members_to_string.append(f'ss << ", {arg.name}=" << {arg.name};')
+ members_to_string_str = "\n ".join(members_to_string)
+
+ return [
+ f"""\
+class {schema.node_name} : public {self.node_base} {{
+ public:
+ static torch::lazy::OpKind ClassOpKind() {{
+ return torch::lazy::OpKind({opkind});
+ }}
+
+ {schema.node_name}({node_ctor_args})
+ : {self.node_base_ctor_call(schema)}{scalar_initializers}
+ {{
+ {has_optional_defs}
+ }}
+
+ std::string ToString() const override {{
+ std::stringstream ss;
+ ss << {self.node_base}::ToString();
+ {members_to_string_str}
+ return ss.str();
+ }}
+
+ {self.create_function(schema, reuse_ctor_args)}
+
+ {self.can_be_reused_function(schema, reuse_ctor_args)}
+
+ {self.lowering_function(schema)}
+
+ {scalar_decls}
+ {has_optional_decls}
+
+}};
+
+""",
+ ]
+
+
+@dataclass(frozen=True)
+class GenTSLazyIR(GenLazyIR):
+ def lowering_function(self, schema: LazyIrSchema) -> str:
+ signature = """
+ torch::lazy::TSOpVector Lower(
+        std::shared_ptr<torch::jit::GraphFunction> function,
+ torch::lazy::TSLoweringContext* loctx) const override"""
+
+ if schema.properties.LowerDeclOnly:
+ return f"{signature};"
+ elif schema.properties.Lower:
+ return f"""{signature} {{
+ {ts_lowering_body(schema)}
+ }}
+ """
+ else:
+ return ""
+
+ def create_function(self, schema: LazyIrSchema, node_ctor_args: str) -> str:
+ signature = f"static NodePtr Create({node_ctor_args})"
+ if schema.properties.CreateFnDeclOnly:
+ return f"{signature};"
+ elif not schema.properties.CreateFn:
+ return ""
+ return f"""{signature} {{
+ return ReuseOrMakeNode<{schema.node_name}>(data);
+ }}"""
+
+ def can_be_reused_function(self, schema: LazyIrSchema, node_ctor_args: str) -> str:
+ signature = f"bool CanBeReused({node_ctor_args}) const"
+ if schema.properties.CanBeReusedDeclOnly:
+ return f"{signature};"
+ elif not schema.properties.CanBeReused:
+ return ""
+ value_comparison = []
+ for arg in itertools.chain(schema.positional_values, schema.keyword_values):
+ if isinstance(arg.lazy_type, OptionalCType):
+ value_comparison.append(
+ f"nullable_operand(i++) == {arg.name}.value_or(kNullValue)"
+ )
+ else:
+ value_comparison.append(f"operand(i++) == {arg.name}")
+ for arg in itertools.chain(schema.positional_scalars, schema.keyword_scalars):
+ if isinstance(arg.lazy_type, OptionalCType):
+ value_comparison.append(
+ f"((!this->{arg.name}&&!{arg.name}) || (this->{arg.name}&&{arg.name} && *(this->{arg.name}) == *{arg.name}))"
+ )
+ else:
+ value_comparison.append(f"this->{arg.name} == {arg.name}")
+ value_comparison_str = " &&\n ".join(value_comparison)
+
+ return f"""{signature} {{
+ size_t i = 0;
+ return ({value_comparison_str});
+ }}"""
+
+
+@dataclass(frozen=True)
+class GenLazyNativeFuncDefinition:
+ class_method_name: str
+ backend_index: BackendIndex
+ tensor_class: str
+ gen_forced_fallback_code: bool
+ backend_namespace: str
+ get_tensorlist: str
+ get_tensor_or_wrap_number: str
+ try_get_tensor: str
+ metrics_counter: str
+ create_tensor: str
+ create_from_first_tensor: bool
+ create_aten_from_ltc_tensor: str
+ tuple_aten_from_ltc_tensors: str
+ lazy_tensor_ptr: str
+ get_device_fn: str
+
+ def lazy_tensor_decls(self, func: NativeFunction, schema: LazyIrSchema) -> str:
+ value_args = schema.filtered_args(values=True, scalars=False)
+ # Generates lazy_{name} variables for LazyTensors wrapping input tensors
+ lazy_tensor_decls: List[str] = []
+ for arg in value_args:
+ if arg.is_wrapped_scalar:
+ if isinstance(arg.lazy_type, OptionalCType):
+ lazy_tensor_decls.append(
+ f"""auto node_{arg.name} = {arg.name} ?
+ c10::make_optional(torch::lazy::LazyGraphExecutor::Get()->
+ GetIrValueForScalarFromCodegen(*{arg.name}, *common_device)):
+ c10::nullopt;"""
+ )
+ else:
+ lazy_tensor_decls.append(
+ f"""auto node_{arg.name} = torch::lazy::LazyGraphExecutor::Get()->
+ GetIrValueForScalarFromCodegen({arg.name}, *common_device);"""
+ )
+ elif arg.is_symint_or_list:
+ continue # values are extracted in isValueType
+ elif isinstance(arg.lazy_type, BaseCType):
+ if arg.lazy_type.type is tensorListValueT:
+ lazy_tensor_decls.append(
+ f"auto lazy_{arg.name}_tensorlist = "
+ f"{self.backend_namespace}::{self.get_tensorlist}({arg.name});"
+ )
+ else:
+ lazy_tensor_decls.append(
+ f"{self.lazy_tensor_ptr} lazy_{arg.name} = "
+ f"{self.backend_namespace}::{self.get_tensor_or_wrap_number}({arg.name}, *common_device);"
+ )
+ elif isinstance(arg.lazy_type, OptionalCType):
+ assert arg.lazy_type.elem == BaseCType(getValueT()), arg.lazy_type.elem
+ # TODO(alanwaketan): Maybe we want to apply GetLtcTensorOrCreateForWrappedNumber here, but hold it
+ # until we encounter a real world example.
+ lazy_tensor_decls.append(
+ f"{self.lazy_tensor_ptr} lazy_{arg.name} = "
+ f"{self.backend_namespace}::{self.try_get_tensor}({arg.name}.value_or(at::Tensor()));"
+ )
+ else:
+ raise AssertionError(
+ f"TODO not sure if there are other valid types to handle here ({arg.lazy_type})"
+ )
+ return ("\n ").join(lazy_tensor_decls)
+
+ def force_eager_fallback(
+ self,
+ func: NativeFunction,
+ schema: LazyIrSchema,
+ metadata: BackendMetadata,
+ sig: Union[DispatcherSignature, NativeSignature],
+ ) -> str:
+ if self.gen_forced_fallback_code:
+ return gen_fallback_code(
+ schema, sig, overload_name=func.func.name.overload_name
+ )
+ return ""
+
+ def metrics(self, func: NativeFunction, schema: LazyIrSchema) -> str:
+ return f"{self.metrics_counter};"
+
+ def get_device(self, func: NativeFunction, schema: LazyIrSchema) -> str:
+ value_args = schema.filtered_args(values=True, scalars=False)
+ scalar_args = schema.filtered_args(values=False, scalars=True)
+ value_types_names = [f"{a.name}" for a in value_args if not a.is_wrapped_scalar]
+ optional_device = OptionalCType(BaseCType(deviceT))
+ optional_devices = [
+ a.name for a in scalar_args if a.lazy_type == optional_device
+ ]
+ assert (
+ len(value_types_names) > 0 or len(optional_devices) > 0
+ ), "Expected at least one Value or Device type"
+ get_device_str = (
+ f"{self.get_device_fn}({', '.join(value_types_names + optional_devices)})"
+ )
+ return f"""auto common_device = {get_device_str};
+ TORCH_INTERNAL_ASSERT(common_device);
+ """
+
+ def shape_inference(self, func: NativeFunction, schema: LazyIrSchema) -> str:
+ metadata = self.backend_index.get_kernel(func)
+ assert metadata is not None
+ all_args = schema.filtered_args()
+ returns_length = len(schema.returns)
+ # call the meta kernel if it exists, to compute output shape/dtype for our IR
+ # Note [Generated LTC Shape Functions]
+ # LTC uses meta tensors from core to do shape inference when possible, and otherwise
+ # we generate a shape function declaration that needs to be manually implemented.
+ # How do we detect which ops are eligible to use meta tensors?
+ # In general we should be able to use meta tensors not just on structured operators,
+ # but also on composite operators that are implemented in terms of structured kernels.
+ # We don't currently have a way of knowing at codegen time which ops are implemented that way.
+ # This is the case for all view and view_copy operators however, so we're going to
+ # use them specifically for all of the view_copy ops (instead of manually writing shape rules for all of them).
+ is_view_copy_op = "view_copy" in func.tags
+ is_structured = func.structured or func.structured_delegate is not None
+ if is_structured or is_view_copy_op:
+ meta_out = """
+std::vector<torch::lazy::Shape> shapes{torch::lazy::Shape(out_meta.scalar_type(), out_meta.sizes().vec())};"""
+ if returns_length > 1:
+
+ def this_shape(i: int) -> str:
+ return f"torch::lazy::Shape(std::get<{i}>(out_meta).scalar_type(), std::get<{i}>(out_meta).sizes().vec())"
+
+ shapes_str = ",".join([this_shape(i) for i in range(returns_length)])
+                meta_out = "std::vector<torch::lazy::Shape> shapes{" + shapes_str + "};"
+
+ # Convert tensor args to the meta device and call it.
+ # (We can't pass in the input tensors directly, because they are "functional wrappers".
+ # If any of the meta kernels call a tensor op and redispatch, we don't want to hit the functionalize kernels.)
+ # Even at::meta:: functions might redispatch, e.g. if they call into view ops.
+ dispatcher_sig = DispatcherSignature.from_schema(func.func)
+ meta_conversion_str, meta_call_ctx = convert_to_meta_tensors(dispatcher_sig)
+ meta_call_args = [
+ e.expr
+ for e in translate(
+ meta_call_ctx, dispatcher_sig.arguments(), method=False
+ )
+ ]
+ if is_view_copy_op:
+ # view_copy ops always have a CompositeExplicitAutogradNonFunctional kernel
+ assert func.has_composite_explicit_autograd_non_functional_kernel
+ dispatch_ns = "compositeexplicitautogradnonfunctional"
+ else:
+ dispatch_ns = "meta"
+ aten_name = schema.aten_name
+ # TODO: this is trolling
+ if func.func.has_symint() and metadata.supports_symint():
+ aten_name += "_symint"
+ shape_str = f"""\
+ {meta_conversion_str}
+ auto out_meta = at::{dispatch_ns}::{aten_name}({', '.join(meta_call_args)});
+ {meta_out}"""
+ else:
+ shape_sig = ComputeShapeSignature(
+ metadata.kernel, func, symint=metadata.supports_symint()
+ )
+ shape_str = f"""
+ auto shapes = {shape_sig.shape_call};"""
+
+ shape_str += f"""
+ TORCH_INTERNAL_ASSERT(shapes.size() == {returns_length});"""
+
+ # Calculating which dimensions are symbolic
+ func_schema_str = "aten::" + str(func.func)
+ shape_str += f"""
+ if(torch::lazy::symbolicShapeEnabled()){{
+                std::vector<torch::jit::IValue> inputs = {{ {', '.join(str(a.name) for a in all_args)} }};
+ const char* schema_str = "{func_schema_str}";
+ applySymbolicShapesOnLT(schema_str, inputs, shapes);
+ }}
+ """
+ return shape_str
+
+ def build_ir_node(self, func: NativeFunction, schema: LazyIrSchema) -> str:
+ node_ctor_input_str = node_ctor_inputs(schema)
+ return f"""torch::lazy::NodePtr node = torch::lazy::ReuseNode<{schema.node_name}>({node_ctor_input_str});
+ if (!node) {{
+ {self.shape_inference(func, schema)}
+ node = torch::lazy::MakeNode<{schema.node_name}>({node_ctor_input_str}, std::move(shapes));
+ CacheNode(node);
+ }}
+ """
+
+ def create_lazy_tensor(self, first_tensor_name: Optional[str] = None) -> str:
+ # xla uses an instance method for tensor creation, for the time being
+ if self.create_from_first_tensor:
+ # TODO(whc) remove this if XLA switches to using static method for creation
+ assert (
+ first_tensor_name is not None
+ ), "Requires first tensor to create lazy tensor"
+ return f"{first_tensor_name}.{self.create_tensor}"
+ return f"{self.backend_namespace}::{self.create_tensor}"
+
+ def return_aten_tensor(self, func: NativeFunction, schema: LazyIrSchema) -> str:
+ returns_length = len(schema.returns)
+ value_args = schema.filtered_args(values=True, scalars=False)
+ value_types_names = [f"{a.name}" for a in value_args if not a.is_wrapped_scalar]
+ first_tensor_name = value_types_names[0] if len(value_types_names) > 0 else None
+ bridge_str = f"""auto result = {self.create_aten_from_ltc_tensor}(
+ {self.create_lazy_tensor(first_tensor_name)}(std::move(node), *common_device));"""
+
+ if returns_length > 1:
+ assert (
+ len(value_types_names) > 0
+ ), "Code below assumes there is at least one tensor arg"
+ bridge_str = f"""std::vector<{self.lazy_tensor_ptr}> lazy_tensors;
+ for (int i = 0; i < {returns_length}; i++) {{
+ lazy_tensors.push_back({self.create_lazy_tensor(first_tensor_name)}({getValueT()}(node, i), *common_device));
+ }}
+ auto result = {self.tuple_aten_from_ltc_tensors}<{returns_length}>(lazy_tensors);"""
+
+ if schema.name.name.inplace or func.func.is_out_fn():
+ assert returns_length == 1, (
+ "We assumed there was no such case where an op is an in-place variant "
+ f"and has tuple outputs, but got tuple of len {returns_length}."
+ )
+ bridge_str = f"""lazy_{first_tensor_name}->SetInPlaceIrValue(node);
+ auto& result = {first_tensor_name};"""
+
+ bridge_str += """
+ return result;"""
+ return bridge_str
+
+ @method_with_native_function
+ def __call__(self, func: NativeFunction) -> List[str]:
+ sig = kernel_signature(func, self.backend_index)
+ metadata = self.backend_index.get_kernel(func)
+ assert metadata is not None
+ schema = LazyIrSchema(func.func, symint=metadata.supports_symint())
+ return [
+ f"""\
+ {sig.decl(name=f"{self.class_method_name}::{metadata.kernel}")} {{
+ {self.force_eager_fallback(func, schema, metadata, sig)}
+ {self.metrics(func, schema)}
+ {self.get_device(func, schema)}
+ {self.lazy_tensor_decls(func, schema)}
+ {self.build_ir_node(func, schema)}
+ {self.return_aten_tensor(func, schema)}
+ }}\n
+ """
+ ]
+
+
+class ComputeShapeSignature:
+ """
+ Here we use the base name as the suffix of the signature to avoid generating for in-place variants.
+ """
+
+ def __init__(self, kernel_name: str, f: NativeFunction, *, symint: bool):
+ self.__schema = LazyIrSchema(f.func, symint=symint)
+ self.__dispatch_args = ", ".join(
+ [a.decl() for a in dispatcher.arguments(f.func, symint=symint)]
+ )
+ self.__call_args = ", ".join(
+ [f"{arg.name}" for arg in self.__schema.filtered_args(generator=True)]
+ )
+ self.__kernel_name = kernel_name
+
+ def __decl_suffix(self) -> str:
+ return f"{self.__kernel_name}({self.__dispatch_args})"
+
+ def __call_suffix(self) -> str:
+ return f"{self.__kernel_name}({self.__call_args})"
+
+ @property
+ def shape_decl(self) -> str:
+        return f"TORCH_API std::vector<torch::lazy::Shape> compute_shape_{self.__decl_suffix()}"
+
+ @property
+ def shape_call(self) -> str:
+ return f"torch::lazy::compute_shape_{self.__call_suffix()}"
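+# As an illustration, for a hypothetical kernel named "trace" taking "const at::Tensor & self",
+# shape_decl would read "TORCH_API std::vector<torch::lazy::Shape> compute_shape_trace(const at::Tensor & self)"
+# and shape_call would read "torch::lazy::compute_shape_trace(self)".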
+
+
+@dataclass(frozen=True)
+class GenLazyShapeInferenceDefinition:
+ backend_index: BackendIndex
+ tensor_class: str
+
+ @method_with_native_function
+ def __call__(self, f: NativeFunction) -> List[str]:
+ sig = kernel_signature(f, self.backend_index)
+ metadata = self.backend_index.get_kernel(f)
+ assert metadata is not None
+
+ # See Note [Generated LTC Shape Functions]
+ is_view_copy_op = "view_copy" in f.tags
+ is_structured = f.structured or f.structured_delegate is not None
+ if is_structured or is_view_copy_op:
+ return []
+ else:
+ shape_sig = ComputeShapeSignature(
+ metadata.kernel, f, symint=metadata.supports_symint()
+ )
+ return ["\n".join([f"{shape_sig.shape_decl};"])]
+
+
+def generate_non_native_lazy_ir_nodes(
+ non_native: List[Dict[str, Any]], gen_lazy_ir: GenLazyIR
+) -> List[str]:
+ """Generate the non-native lazy IR node classes"""
+ nodes = []
+ for op in non_native:
+ # Set default properties for Non-Native IRs
+ properties = LazyIrProperties("ShapeCache", "CanBeReused", "LowerDeclOnly")
+ for p in op.get("properties", []):
+ setattr(properties, p, True)
+
+ # non-native is assumed to want symint bindings if you wrote symint
+ schema = LazyIrSchema(FunctionSchema.parse(op["func"]), properties, symint=True)
+ schema.opkind = op.get("opkind")
+ nodes.append(gen_lazy_ir.gen(schema)[0])
+
+ return nodes
diff --git a/MLPY/Lib/site-packages/torchgen/dest/lazy_ts_lowering.py b/MLPY/Lib/site-packages/torchgen/dest/lazy_ts_lowering.py
new file mode 100644
index 0000000000000000000000000000000000000000..1efbd63d7e7722d39c314afdf5474f80a5994c28
--- /dev/null
+++ b/MLPY/Lib/site-packages/torchgen/dest/lazy_ts_lowering.py
@@ -0,0 +1,48 @@
+from torchgen.api.lazy import LazyArgument, LazyIrSchema
+from torchgen.api.types import OptionalCType
+
+
+def ts_lowering_body(schema: LazyIrSchema) -> str:
+ # for now, we just want one IR class decl and soon after also the method defs
+ # and we use the functional version not out/inplace.
+ emplace_arguments = []
+
+ def get_value(arg: LazyArgument) -> str:
+ if isinstance(arg.lazy_type, OptionalCType):
+ return f"has_{arg.name} ? loctx->GetOutputOp(operand(i++)) : nullptr"
+ return "loctx->GetOutputOp(operand(i++))"
+
+ for arg in schema.positional_args:
+ if arg.is_lazy_value:
+ emplace_arguments.append(get_value(arg))
+ continue
+ emplace_arguments.append(f'"{arg.name}", {arg.name}')
+
+ emplace_arguments_str = "\n ".join(
+ [f"arguments.emplace_back({a});" for a in emplace_arguments]
+ )
+ emplace_kwarg_values = [
+ f'"{arg.name}", {get_value(arg)}' for arg in schema.keyword_values
+ ]
+ emplace_kwarg_scalars = [
+ f'"{arg.name}", {arg.name}' for arg in schema.keyword_scalars
+ ]
+ emplace_kwarguments = "\n ".join(
+ [
+ f"kwarguments.emplace_back({a});"
+ for a in emplace_kwarg_values + emplace_kwarg_scalars
+ ]
+ )
+ return f"""\
+    std::vector<torch::jit::NamedValue> arguments;
+    std::vector<torch::jit::NamedValue> kwarguments;
+ arguments.reserve({len(emplace_arguments)});
+ kwarguments.reserve({len(emplace_kwarg_values + emplace_kwarg_scalars)});
+ size_t i = 0;
+ {emplace_arguments_str}
+ {emplace_kwarguments}
+ torch::lazy::TSOpVector {schema.aten_name}_out = torch::lazy::LowerTSBuiltin(function, op().op, arguments, kwarguments);
+ TORCH_CHECK_EQ({schema.aten_name}_out.size(), {len(schema.returns)});
+
+ return {schema.aten_name}_out;
+"""
diff --git a/MLPY/Lib/site-packages/torchgen/dest/native_functions.py b/MLPY/Lib/site-packages/torchgen/dest/native_functions.py
new file mode 100644
index 0000000000000000000000000000000000000000..0dd57cc1a839680b2b8cab10ef80259d5282ee82
--- /dev/null
+++ b/MLPY/Lib/site-packages/torchgen/dest/native_functions.py
@@ -0,0 +1,64 @@
+from typing import List, Optional, Union
+
+import torchgen.api.meta as meta
+import torchgen.api.structured as structured
+from torchgen.api.types import kernel_signature
+
+from torchgen.context import with_native_function_and_index
+from torchgen.model import BackendIndex, NativeFunction, NativeFunctionsGroup
+from torchgen.utils import mapMaybe
+
+
+@with_native_function_and_index
+def gen_unstructured(f: NativeFunction, backend_index: BackendIndex) -> Optional[str]:
+ sig = kernel_signature(f, backend_index)
+ metadata = backend_index.get_kernel(f)
+ if metadata is None:
+ return None
+ if "legacy::" in metadata.kernel:
+ return None
+ else:
+ prefix = "static" if backend_index.external else "TORCH_API"
+ return f"{prefix} {sig.decl(name=metadata.kernel)};"
+
+
+@with_native_function_and_index
+def gen_structured(g: NativeFunctionsGroup, backend_index: BackendIndex) -> List[str]:
+ meta_name = meta.name(g)
+ out_args = structured.impl_arguments(g)
+ metadata = backend_index.get_kernel(g)
+ if metadata is None:
+ return []
+ prefix = "" if backend_index.external else "TORCH_API "
+ return [
+ f"""\
+struct {prefix}structured_{metadata.kernel} : public at::meta::structured_{meta_name} {{
+void impl({', '.join(a.decl() for a in out_args)});
+}};
+"""
+ ]
+
+
+# Generates NativeFunctions.h, a list of forward declarations of all
+# actual kernel definitions we keep in aten/src/ATen/native/
+@with_native_function_and_index
+def compute_native_function_declaration(
+ g: Union[NativeFunctionsGroup, NativeFunction], backend_index: BackendIndex
+) -> List[str]:
+ metadata = backend_index.get_kernel(g)
+ if isinstance(g, NativeFunctionsGroup):
+ if metadata is not None and metadata.structured:
+ if backend_index.external:
+ # Structured hasn't been tested with external backends yet.
+ raise AssertionError(
+ "Structured external backend functions are not implemented yet."
+ )
+ else:
+ return gen_structured(g, backend_index)
+ else:
+ return list(
+ mapMaybe(lambda f: gen_unstructured(f, backend_index), g.functions())
+ )
+ else:
+ x = gen_unstructured(g, backend_index)
+ return [] if x is None else [x]
diff --git a/MLPY/Lib/site-packages/torchgen/dest/register_dispatch_key.py b/MLPY/Lib/site-packages/torchgen/dest/register_dispatch_key.py
new file mode 100644
index 0000000000000000000000000000000000000000..8d7260bd925e7e7cf6902b0572e66877c355bd61
--- /dev/null
+++ b/MLPY/Lib/site-packages/torchgen/dest/register_dispatch_key.py
@@ -0,0 +1,989 @@
+import itertools
+import textwrap
+from dataclasses import dataclass
+from typing import List, Literal, Optional, Tuple, Union
+
+import torchgen.api.cpp as cpp
+import torchgen.api.meta as meta
+import torchgen.api.structured as structured
+from torchgen.api.translate import translate
+from torchgen.api.types import (
+ BaseCType,
+ Binding,
+ ConstRefCType,
+ CppSignature,
+ CppSignatureGroup,
+ DispatcherSignature,
+ Expr,
+ kernel_signature,
+ MutRefCType,
+ NamedCType,
+ NativeSignature,
+ tensorT,
+)
+
+from torchgen.context import method_with_native_function, native_function_manager
+from torchgen.model import (
+ Argument,
+ BackendIndex,
+ DeviceCheckType,
+ DispatchKey,
+ gets_generated_out_inplace_wrapper,
+ is_cuda_dispatch_key,
+ NativeFunction,
+ NativeFunctionsGroup,
+ SchemaKind,
+ TensorOptionsArguments,
+)
+from torchgen.selective_build.selector import SelectiveBuilder
+from torchgen.utils import assert_never, mapMaybe, Target
+
+
+def gen_registration_headers(
+ backend_index: BackendIndex,
+ per_operator_headers: bool,
+ rocm: bool,
+) -> List[str]:
+ if per_operator_headers:
+        headers = ["#include <ATen/ops/as_strided_native.h>"]
+    else:
+        headers = ["#include <ATen/NativeFunctions.h>"]
+
+ if backend_index.dispatch_key in (DispatchKey.CPU, DispatchKey.Meta):
+        headers.append("#include <ATen/EmptyTensor.h>")
+    elif backend_index.dispatch_key == DispatchKey.CUDA:
+        if rocm:
+            headers.append("#include <ATen/hip/EmptyTensor.h>")
+        else:
+            headers.append("#include <ATen/cuda/EmptyTensor.h>")
+    elif backend_index.dispatch_key == DispatchKey.MPS:
+        headers.append("#include <ATen/mps/EmptyTensor.h>")
+    elif per_operator_headers:
+        headers += [
+            "#include <ATen/ops/empty.h>",
+            "#include <ATen/ops/empty_strided.h>",
+            "#include <ATen/ops/_copy_from_and_resize.h>",
+            "#include <ATen/ops/_copy_from.h>",
+        ]
+    else:
+        headers.append("#include <ATen/Functions.h>")
+
+ return headers
+
+
+def gen_empty_impl_names(
+ backend_index: BackendIndex,
+) -> Tuple[Optional[str], Optional[str]]:
+ empty_impl = None
+ empty_strided_impl = None
+
+ if backend_index.dispatch_key in (
+ DispatchKey.Meta,
+ DispatchKey.CPU,
+ DispatchKey.CUDA,
+ DispatchKey.MPS,
+ ):
+ dispatch = str(backend_index.dispatch_key).lower()
+ empty_impl = f"at::detail::empty_{dispatch}"
+ empty_strided_impl = f"at::detail::empty_strided_{dispatch}"
+ elif backend_index.dispatch_key in (
+ DispatchKey.CompositeExplicitAutogradNonFunctional,
+ DispatchKey.QuantizedCPU,
+ DispatchKey.QuantizedCUDA,
+ ):
+ empty_impl = "at::empty"
+ empty_strided_impl = "at::empty_strided"
+
+ return empty_impl, empty_strided_impl
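+# E.g. for DispatchKey.CPU this returns ("at::detail::empty_cpu", "at::detail::empty_strided_cpu");
+# for dispatch keys outside both lists it returns (None, None), which suppresses the
+# create_out / maybe_create_proxy helpers generated below.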
+
+
+def gen_create_out_helper(backend_index: BackendIndex) -> List[str]:
+ if backend_index.dispatch_key == DispatchKey.Meta:
+ empty_options = "options.device(at::kMeta)"
+ else:
+ empty_options = "options"
+
+ empty_impl, empty_strided_impl = gen_empty_impl_names(backend_index)
+ if empty_impl is None:
+ return []
+
+ return [
+ f"""
+Tensor create_out(IntArrayRef sizes, IntArrayRef strides, const TensorOptions &options) {{
+ if (strides.empty()) {{
+ return {empty_impl}(sizes, {empty_options});
+ }} else {{
+ return {empty_strided_impl}(sizes, strides, {empty_options});
+ }}
+}}
+"""
+ ]
+
+
+def gen_maybe_create_proxy_helper(backend_index: BackendIndex) -> List[str]:
+ _, empty_strided_impl = gen_empty_impl_names(backend_index)
+ return (
+ []
+ if empty_strided_impl is None
+ else [
+ f"""
+c10::optional<Tensor> maybe_create_proxy(const Tensor &out, IntArrayRef sizes, IntArrayRef strides, const TensorOptions &options) {{
+ if (out.strides() != strides) {{
+ return {empty_strided_impl}(sizes, strides, options);
+ }}
+ return c10::nullopt;
+}}
+"""
+ ]
+ )
+
+
+def gen_resize_out_helper(backend_index: BackendIndex) -> List[str]:
+ if backend_index.dispatch_key == DispatchKey.CompositeExplicitAutogradNonFunctional:
+ # The function isn't used by this key (since only functional ops have a kernel for this key),
+ # so we need to not include it to avoid a defined-but-not-used error.
+ return []
+ return [
+ """
+void resize_out(const Tensor &out, IntArrayRef sizes, IntArrayRef strides, const TensorOptions &options) {
+ TORCH_CHECK(options.dtype() == out.dtype(),
+ "Expected out tensor to have dtype ", options.dtype(), ", but got ", out.dtype(), " instead");
+ TORCH_CHECK(options.device() == out.device(),
+ "Expected out tensor to have device ", options.device(), ", but got ", out.device(), " instead");
+ const bool resized = at::native::resize_output(out, sizes);
+ // Only restride if a resize occurred; otherwise we ignore the (advisory)
+ // strides from the meta function and directly use the output tensor's
+ // preexisting strides
+ if (resized) {
+ if (!strides.empty()) {
+ TORCH_INTERNAL_ASSERT(!options.memory_format_opt().has_value());
+ // TODO: avoid the redispatch here
+ out.as_strided_(sizes, strides);
+ } else if (options.memory_format_opt().has_value()) {
+ out.unsafeGetTensorImpl()->empty_tensor_restride(*options.memory_format_opt());
+ }
+ }
+}
+"""
+ ]
+
+
+def gen_check_inplace_helper(backend_index: BackendIndex) -> List[str]:
+ return [
+ """
+void check_inplace(const Tensor &self, IntArrayRef sizes, const TensorOptions &options) {
+ // These checks are needed on those operators that:
+ // 1) don't use 'TensorIterator' (e.g. 'addmm' and 'baddbmm')
+ // 2) have particular typing rules (e.g. 'cumsum' and 'cumprod')
+ // For other operators (e.g. 'add'), 'TensorIterator' already checks
+ // these things separately.
+ TORCH_CHECK(options.dtype() == self.dtype(),
+ "Bad in-place call: ",
+ "input tensor dtype ", self.dtype(), " and output tensor dtype ", options.dtype(), " should match");
+ TORCH_CHECK(options.device() == self.device(),
+ "Bad in-place call: ",
+ "input tensor device ", self.device(), " and output tensor device ", options.device(), " should match");
+ TORCH_CHECK(sizes == self.sizes(),
+ "Bad in-place call: ",
+ "input tensor size ", self.sizes(), " and output tensor size ", sizes, " should match");
+}
+"""
+ ]
+
+
+def gen_registration_helpers(backend_index: BackendIndex) -> List[str]:
+ return [
+ *gen_create_out_helper(backend_index),
+ *gen_resize_out_helper(backend_index),
+ *gen_check_inplace_helper(backend_index),
+ *gen_maybe_create_proxy_helper(backend_index),
+ ]
+
+
+# Generates Register{dispatch}.cpp (e.g., RegisterCPU.cpp).
+#
+# - The primary function of this file is to register all of the
+# implementations for the given dispatch key to the dispatcher,
+# so they are available for use in PyTorch. If dispatch is
+# None, we generate schema (def) registrations and catchall
+# registrations.
+# - The secondary function of this file is to generate a wrapper
+# around functions. In CPUType these wrappers do nothing
+# (and should be removed), but in other cases they handle
+# DeviceGuard. A small extra benefit of wrappers is they
+# are not overloaded, so they can be used in the registration
+# API without having to disambiguate which overload you want
+# (as would be the case if you directly registered native::
+# functions).
+# - The tertiary function of this file is to generate *static*
+# cpp API bindings which can be used to bypass dispatcher
+# directly to kernels, but with user-friendly cpp-style API
+@dataclass(frozen=True)
+class RegisterDispatchKey:
+ backend_index: BackendIndex
+
+ target: Literal[
+ Target.ANONYMOUS_DEFINITION,
+ Target.NAMESPACED_DEFINITION,
+ Target.NAMESPACED_DECLARATION,
+ Target.REGISTRATION,
+ ]
+
+ # Selector object to determine which operators to generate
+ # registration code for.
+ selector: SelectiveBuilder
+
+ # Whether or not we are actually code-genning for ROCm
+ rocm: bool
+
+ # Whether or not to generate symint registrations or not. External users
+ # of codegen who don't care about symints can set this to false to get
+ # non-SymInt codegen
+ symint: bool
+
+ # The class that all unstructured native functions live under. This is used to improve
+ # compiler error messages when a kernel writer adds a native function with the wrong signature.
+ # This is only used in unstructured kernels, since structured kernels already live in a class.
+ # Finally, this field is currently Optional because it is only used by external backends.
+ # It would be nice if we can add the same logic to in-tree kernels too, but that requires updating
+ # all of the existing kernel signatures scattered across aten/src/ATen/native.
+ class_method_name: Optional[str]
+
+ # Only set to true in lightweight dispatch. If lightweight dispatch is enabled we are registering
+ # operators into JIT op registry, thus we need to avoid generating code to register into the dispatcher.
+ skip_dispatcher_op_registration: bool
+
+ @staticmethod
+ def gen_device_check(
+ type: DeviceCheckType, args: List[Argument], method_name: str
+ ) -> str:
+ if type == DeviceCheckType.NoCheck:
+ return " // No device check\n"
+
+        device_check = "c10::optional<Device> common_device = nullopt;\n"
+ device_check += "(void)common_device; // Suppress unused variable warning\n"
+ for arg in args:
+ # Only tensor like arguments are eligible
+ if arg.type.is_tensor_like():
+ device_check += f"""
+ c10::impl::check_and_update_common_device(common_device, {arg.name}, "{method_name}", "{arg.name}");"""
+ return device_check
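+    # Sketch of the emitted check for a tensor argument "self" of a wrapper named
+    # "add_out" (both names are illustrative): a shared optional<Device> is threaded
+    # through c10::impl::check_and_update_common_device(common_device, self, "add_out", "self");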
+
+ @method_with_native_function
+ def __call__(self, f: Union[NativeFunctionsGroup, NativeFunction]) -> List[str]:
+ if isinstance(f, NativeFunctionsGroup):
+ g: NativeFunctionsGroup = f
+ # Note: We call gen_structured() if the operator is marked structured, regardless of the backend.
+ # gen_structured() has special logic to handle auto-generated kernels.
+ if g.structured:
+ return self.gen_structured(g)
+ else:
+ return list(
+ mapMaybe(lambda f: self.gen_unstructured(f, g), g.functions())
+ )
+ elif isinstance(f, NativeFunction):
+ r = self.gen_unstructured(f)
+ return [] if r is None else [r]
+ else:
+ assert_never(f)
+
+ def wrapper_kernel_sig(
+ self, f: NativeFunction
+ ) -> Union[NativeSignature, DispatcherSignature]:
+ # The prefix is just to ensure uniqueness. The Dispatcher API doesn't guarantee unique kernel names.
+ return DispatcherSignature.from_schema(
+ f.func,
+ prefix=f"wrapper_{self.backend_index.dispatch_key}_{f.func.name.overload_name}_",
+ symint=self.symint,
+ )
+
+ def gen_out_inplace_wrapper(
+ self, f: NativeFunction, g: Optional[NativeFunctionsGroup]
+ ) -> Optional[str]:
+ if g is None:
+ return None
+ k = f.func.kind()
+ if k is SchemaKind.inplace:
+ copy_op = "at::_copy_from"
+ elif k is SchemaKind.out:
+ copy_op = "at::_copy_from_and_resize"
+ else:
+ raise AssertionError("gen_out_inplace_wrapper called on a functional op")
+
+ sig = self.wrapper_kernel_sig(f)
+ name = sig.name()
+
+ func_res = f"{name}_tmp"
+ return_names = cpp.return_names(f)
+ if len(return_names) > 1:
+ updates = "\n ".join(
+ f"{copy_op}(std::get<{i}>({func_res}), {ret_name});"
+ for i, ret_name in enumerate(return_names)
+ )
+ returns = f'{sig.returns_type().cpp_type()}({", ".join(return_names)})'
+ elif len(return_names) == 1:
+ ret_name = return_names[0]
+ updates = f"{copy_op}({func_res}, {ret_name});"
+ returns = ret_name
+ else:
+ assert len(f.func.arguments.out) == 1
+ returns = ""
+ out_arg = f.func.arguments.out[0]
+ if out_arg.type.is_list_like():
+ updates = f"""\
+ for (int64_t i = 0; i < {func_res}.size(); ++i) {{
+ {copy_op}({func_res}[i], {out_arg.name}[i]);
+ }}"""
+ else:
+ updates = f"{copy_op}({func_res}, {out_arg.name});"
+
+ functional_sig = self.wrapper_kernel_sig(g.functional)
+ wrapper_name = sig.name()
+
+ return f"""\
+{sig.defn(name=wrapper_name)} {{
+ auto {func_res} = {functional_sig.name()}({", ".join(e.expr for e in translate(sig.arguments(), functional_sig.arguments()))});
+ {updates}
+ return {returns};
+}}
+"""
+
+ def gen_structured(self, g: NativeFunctionsGroup) -> List[str]:
+ metadata = self.backend_index.get_kernel(g)
+ if self.backend_index.dispatch_key == DispatchKey.Meta:
+ assert not self.backend_index.has_kernel(g.out), (
+ "Do not explicitly specify Meta dispatch key on structured "
+ "functions, they will be automatically generated for you"
+ )
+ elif (
+ self.backend_index.dispatch_key
+ == DispatchKey.CompositeExplicitAutogradNonFunctional
+ ):
+ assert not self.backend_index.has_kernel(g.out), (
+ "Do not explicitly specify CompositeExplicitAutograd dispatch key on structured "
+ "functions, they will be automatically generated for you"
+ )
+ elif metadata is None or not metadata.structured:
+ return list(mapMaybe(lambda f: self.gen_unstructured(f, g), g.functions()))
+ structured_gen = StructuredRegisterDispatchKey(
+ self.backend_index,
+ self.target,
+ self.selector,
+ self.rocm,
+ self.symint,
+ self.class_method_name,
+ self.skip_dispatcher_op_registration,
+ g,
+ )
+ return list(mapMaybe(structured_gen.gen_one, g.functions()))
+
+ def gen_unstructured(
+ self, f: NativeFunction, g: Optional[NativeFunctionsGroup] = None
+ ) -> Optional[str]:
+ with native_function_manager(f):
+ inplace_meta = False
+ gets_out_inplace_wrapper = False
+ if not self.backend_index.has_kernel(f):
+ if (
+ self.backend_index.dispatch_key == DispatchKey.Meta
+ and f.func.kind() is SchemaKind.inplace
+ and
+ # Defer to composites for meta implementation
+ not f.has_composite_kernel
+ and
+ # Inplace list operations are not supported
+ len(f.func.returns) == 1
+ ):
+ inplace_meta = True
+ elif (
+ not self.backend_index.use_out_as_primary
+ and g is not None
+ and gets_generated_out_inplace_wrapper(f, g, self.backend_index)
+ ):
+ # We want to generate inplace/out wrappers, that don't have a kernel for the backend.
+ gets_out_inplace_wrapper = True
+ else:
+ return None
+ if f.manual_kernel_registration:
+ return None
+
+ if (
+ self.target is Target.REGISTRATION
+ and not self.selector.is_native_function_selected(f)
+ ):
+ return None
+
+ sig = self.wrapper_kernel_sig(f)
+
+ name = sig.name()
+ returns_type = sig.returns_type().cpp_type()
+ args = sig.arguments()
+ args_str = ", ".join(a.defn() for a in args)
+
+ # See Note [Direct dispatch bindings]
+ cpp_sig_group = CppSignatureGroup.from_native_function(
+ f, method=False, fallback_binding=False
+ )
+
+ # TODO: dedupe this with the structured codegen
+ if self.target is Target.NAMESPACED_DECLARATION:
+ result = ""
+ for cpp_sig in cpp_sig_group.signatures(symint=self.symint):
+ result += f"TORCH_API {cpp_sig.decl()};\n"
+ return result
+ elif self.target is Target.NAMESPACED_DEFINITION:
+
+ def generate_defn(cpp_sig: CppSignature) -> str:
+ return f"""
+{cpp_sig.defn()} {{
+return {sig.name()}({', '.join(e.expr for e in translate(cpp_sig.arguments(), sig.arguments()))});
+}}
+"""
+
+ result = ""
+ for cpp_sig in cpp_sig_group.signatures(symint=self.symint):
+ result += generate_defn(cpp_sig)
+ return result
+
+ elif self.target is Target.ANONYMOUS_DEFINITION:
+ # short circuit for inplace_meta
+ if inplace_meta:
+ assert f.func.arguments.self_arg is not None
+ self_arg_name = f.func.arguments.self_arg.argument.name
+ # TODO: handle in place on tensor list
+ return f"""
+{returns_type} {name}({args_str}) {{
+ TORCH_CHECK_NOT_IMPLEMENTED({self_arg_name}.is_meta(),
+ "Cannot inplace into non-meta tensor with meta tensor argument");
+ return {self_arg_name};
+}}
+"""
+
+ # short circuit for generated inplace/out wrappers
+ if gets_out_inplace_wrapper:
+ return self.gen_out_inplace_wrapper(f, g)
+
+ metadata = self.backend_index.get_kernel(f)
+ if metadata is None:
+ return None
+ if self.class_method_name is None:
+ impl_name = f"{metadata.cpp_namespace}::{metadata.kernel}"
+ else:
+ impl_name = f"{metadata.cpp_namespace}::{self.class_method_name}::{metadata.kernel}"
+
+ kernel_sig = kernel_signature(f, self.backend_index)
+
+ args_exprs_str = ", ".join(
+ e.expr
+ for e in translate(
+ sig.arguments(), kernel_sig.arguments(), method=False
+ )
+ )
+
+ device_check = " // No device check\n"
+ # Backends that require device guards presumably also require device checks.
+ if self.backend_index.device_guard:
+ device_check_args = itertools.chain(
+ f.func.arguments.out, f.func.arguments.flat_positional
+ )
+ device_check = RegisterDispatchKey.gen_device_check(
+ f.device_check, list(device_check_args), name
+ )
+
+ device_guard = "// DeviceGuard omitted" # default
+ if f.device_guard and self.backend_index.device_guard:
+ has_tensor_options = any(
+ isinstance(a, TensorOptionsArguments)
+ for a in f.func.arguments.non_out
+ )
+ if has_tensor_options:
+ # kernel is creating a tensor
+ device_guard = """
+ const DeviceGuard device_guard(device_or_default(device));"""
+
+ # CUDA requires special handling
+ if is_cuda_dispatch_key(self.backend_index.dispatch_key):
+ device_guard = (
+ f"globalContext().lazyInitCUDA();\n{device_guard}"
+ )
+ else:
+ # kernel is operating on existing tensors
+
+ # There is precedence for which argument we use to do
+ # device guard. This describes the precedence order.
+ self_arg = (
+ [f.func.arguments.self_arg.argument]
+ if f.func.arguments.self_arg is not None
+ else []
+ )
+ candidate_args = itertools.chain(
+ self_arg,
+ f.func.arguments.out,
+ f.func.arguments.flat_positional,
+ )
+
+ # Only tensor like arguments are eligible
+ device_of = next(
+ (
+ f"{a.name}"
+ for a in candidate_args
+ if a.type.is_tensor_like()
+ ),
+ None,
+ )
+ if device_of is not None:
+ device_guard = f"const OptionalDeviceGuard device_guard(device_of({device_of}));"
+
+ return f"""\
+namespace {{
+
+{returns_type} {name}({args_str}) {{
+ {device_check}
+
+ {device_guard}
+ return {impl_name}({args_exprs_str});
+}}
+
+}} // anonymous namespace
+"""
+
+ elif self.target is Target.REGISTRATION:
+ if f.manual_kernel_registration or self.skip_dispatcher_op_registration:
+ return None
+ else:
+ payload = f"TORCH_FN({name})"
+ return f'm.impl("{f.func.name}",\n{payload});\n'
+ else:
+ assert_never(self.target)
+
+
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
+#
+# STRUCTURED
+#
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
+
+
+@dataclass(frozen=True)
+class StructuredRegisterDispatchKey(RegisterDispatchKey):
+ g: NativeFunctionsGroup
+
+ def gen_class_set_output_functions(
+ self, k: SchemaKind, parent_class: str, generate_super: bool
+ ) -> str:
+ if generate_super:
+ set_output_super = f"{parent_class}::set_output_raw_strided(output_idx, sizes, strides, options, names);"
+ else:
+ set_output_super = ""
+
+ def gen_set_output_function(name: str, maybe_create_proxy: bool) -> str:
+ return f"""
+void set_output_{name}(
+ int64_t output_idx, IntArrayRef sizes, IntArrayRef strides,
+ TensorOptions options, DimnameList names
+) override {{
+{textwrap.indent(self.gen_class_set_output_body(k, maybe_create_proxy), " ")}
+ if (!names.empty()) {{
+ namedinference::propagate_names(outputs_[output_idx], names);
+ }}
+ // super must happen after, so that downstream can use maybe_get_output
+ // to retrieve the output
+{textwrap.indent(set_output_super, " ")}
+}}
+"""
+
+ return f"""
+{gen_set_output_function("strided", maybe_create_proxy=True)}
+{gen_set_output_function("raw_strided", maybe_create_proxy=False)}
+"""
+
+ def gen_class_set_output_body(self, k: SchemaKind, maybe_create_proxy: bool) -> str:
+ if self.backend_index.dispatch_key in [
+ DispatchKey.CUDA,
+ DispatchKey.MPS,
+ DispatchKey.CompositeExplicitAutogradNonFunctional,
+ ]:
+ maybe_set_guard = """
+auto current_device = guard_.current_device();
+if (C10_UNLIKELY(current_device.has_value())) {
+ TORCH_INTERNAL_ASSERT(*current_device == options.device(),
+ "structured kernels don't support multi-device outputs");
+} else {
+ guard_.reset_device(options.device());
+}
+"""
+ maybe_set_guard_line = maybe_set_guard + "\n"
+ else:
+ maybe_set_guard_line = maybe_set_guard = ""
+
+ if maybe_create_proxy:
+ create_proxy = """
+auto maybe_proxy = maybe_create_proxy(out, sizes, strides, options);
+if (C10_UNLIKELY(maybe_proxy.has_value())) {
+ proxy_outputs_[output_idx] = std::move(maybe_proxy).value();
+}
+"""
+ else:
+ create_proxy = ""
+
+ if k is SchemaKind.functional:
+ assert self.backend_index.dispatch_key in (
+ DispatchKey.Meta,
+ DispatchKey.CPU,
+ DispatchKey.CUDA,
+ DispatchKey.MPS,
+ DispatchKey.CompositeExplicitAutogradNonFunctional,
+ )
+ return f"""{maybe_set_guard_line}
+outputs_[output_idx] = create_out(sizes, strides, options);"""
+ elif k is SchemaKind.inplace:
+ return f"""{maybe_set_guard_line}
+const auto& out = outputs_[output_idx].get();
+check_inplace(out, sizes, options);
+{create_proxy}"""
+ elif k is SchemaKind.out:
+ return f"""{maybe_set_guard_line}
+const auto& out = outputs_[output_idx].get();
+resize_out(out, sizes, strides, options);
+{create_proxy}"""
+ elif k is SchemaKind.mutable or k is SchemaKind.scratch:
+ raise AssertionError(
+ f"{k} structured operators are currently not supported"
+ )
+ else:
+ assert_never(k)
+
+ # returns the definition of a ctor, as well as how to construct
+ # this class to a variable named op
+ def gen_class_ctor(self, k: SchemaKind, class_name: str, returns: int) -> str:
+ if k is SchemaKind.functional:
+ return ""
+ elif k is SchemaKind.inplace:
+ # TODO: Make sure out argument is guaranteed to be self
+ return f"{class_name}(Tensor& self) : outputs_{{std::ref(self)}} {{}}"
+ elif k is SchemaKind.out:
+ out_args = ", ".join(f"Tensor& out{i}" for i in range(returns))
+ out_refs = ", ".join(f"std::ref(out{i})" for i in range(returns))
+ return f"{class_name}({out_args}) : outputs_{{ {out_refs} }} {{}}"
+ elif k is SchemaKind.mutable or k is SchemaKind.scratch:
+ raise AssertionError(
+ f"{k} structured operators are currently not supported"
+ )
+ else:
+ assert_never(k)
+
+ def gen_class(
+ self,
+ f: NativeFunction,
+ k: SchemaKind,
+ *,
+ class_name: str,
+ parent_class: str,
+ generate_super: bool,
+ ) -> str:
+ if k is SchemaKind.functional:
+ output_type = "Tensor"
+ output_value = "outputs_[output_idx]"
+ proxy_field = ""
+ elif k is SchemaKind.inplace:
+            output_type = "std::reference_wrapper<Tensor>"
+            output_value = "proxy_outputs_[output_idx].has_value() ? *proxy_outputs_[output_idx] : outputs_[output_idx].get()"
+            proxy_field = f"std::array<c10::optional<Tensor>, {len(f.func.returns)}> proxy_outputs_;"
+        elif k is SchemaKind.out:
+            output_type = "std::reference_wrapper<Tensor>"
+            output_value = "proxy_outputs_[output_idx].has_value() ? *proxy_outputs_[output_idx] : outputs_[output_idx].get()"
+            proxy_field = f"std::array<c10::optional<Tensor>, {len(f.func.returns)}> proxy_outputs_;"
+
+ if self.backend_index.dispatch_key == DispatchKey.CUDA:
+ if self.rocm:
+ guard_field = "c10::hip::OptionalHIPGuardMasqueradingAsCUDA guard_;"
+ else:
+ guard_field = "c10::cuda::OptionalCUDAGuard guard_;"
+ elif (
+ self.backend_index.dispatch_key
+ == DispatchKey.CompositeExplicitAutogradNonFunctional
+ ):
+ guard_field = "c10::OptionalDeviceGuard guard_;"
+ elif self.backend_index.dispatch_key == DispatchKey.MPS:
+ # TODO: Move to OptionalMPSGuard.
+ guard_field = "c10::OptionalDeviceGuard guard_;"
+ else:
+ guard_field = ""
+
+ indent = " " * 4
+ class_ctor_str = self.gen_class_ctor(k, class_name, len(f.func.returns))
+ lines = (
+ f"struct {class_name} final : public {parent_class} {{",
+ f"{textwrap.indent(class_ctor_str, indent)}",
+ f"{textwrap.indent(self.gen_class_set_output_functions(k, parent_class, generate_super), indent)}",
+ " const Tensor& maybe_get_output(int64_t output_idx) override {",
+ f" return {output_value};\n", # type: ignore[possibly-undefined] # TODO: audit
+ " }",
+ f" std::array<{output_type}, {len(f.func.returns)}> outputs_;", # type: ignore[possibly-undefined] # TODO: audit
+ f"{textwrap.indent(proxy_field, indent)}", # type: ignore[possibly-undefined] # TODO: audit
+ f"{textwrap.indent(guard_field, indent)}",
+ "};",
+ )
+ return "\n".join(line for line in lines if line)
+
+ @method_with_native_function
+ def gen_one(self, f: NativeFunction) -> Optional[str]:
+ assert not f.manual_kernel_registration
+
+ if (
+ self.target is Target.REGISTRATION
+ and not self.selector.is_native_function_selected(f)
+ ):
+ return None
+
+ # TODO: Now, there is something interesting going on here. In the code below,
+ # we generate CompositeExplicitAutogradNonFunctional implementations of functional and inplace
+ # based on the out implementation. But in fact, out is definable by
+ # functional too (just not very efficiently), and this is honestly the
+ # MORE likely situation for a backend implementor. How do we pick?
+ # Well, taking a page from Haskell type classes and default methods,
+ # we could conceivably register a circular definition (out in terms
+ # of functional, and functional in terms of out) and just require
+ # someone to implement one or the other. We'd have to do a little bit
+ # of work to not register one of these "weak" definitions unless there
+ # is a strong definition somewhere in the DAG! So it's not implemented yet.
+ if (
+ self.backend_index.dispatch_key
+ == DispatchKey.CompositeExplicitAutogradNonFunctional
+ and f.func.kind() is SchemaKind.out
+ ):
+ # Never generate a default implementation for out, that's what you
+ # have to define as a backend implementor
+ return None
+
+ # Note [Direct dispatch bindings]
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+ # Signature of the non-dispatched function we'll expose in a header
+ # (e.g., at::cpu::add). We don't generate methods (TODO: do this
+ # when CPUTensor class is a thing); nor do we generate fallback
+ # bindings for manual_cpp_binding functions.
+ cpp_sig_group = CppSignatureGroup.from_native_function(
+ f, method=False, fallback_binding=False
+ )
+
+ # Signature of the wrapper function we'll register to the dispatcher
+ kern = self.backend_index.get_kernel(f)
+ sig = NativeSignature(
+ f.func,
+ prefix=f"wrapper_{self.backend_index.dispatch_key}_",
+ symint=kern is not None and kern.supports_symint(),
+ )
+
+ if self.target is Target.NAMESPACED_DECLARATION:
+ result = ""
+ for cpp_sig in cpp_sig_group.signatures(symint=self.symint):
+ result += f"TORCH_API {cpp_sig.decl()};\n"
+ return result
+
+ elif self.target is Target.NAMESPACED_DEFINITION:
+
+ def generate_defn(cpp_sig: CppSignature) -> str:
+ return f"""
+{cpp_sig.defn()} {{
+return {sig.name()}({', '.join(e.expr for e in translate(cpp_sig.arguments(), sig.arguments()))});
+}}
+"""
+
+ result = ""
+ for cpp_sig in cpp_sig_group.signatures(symint=self.symint):
+ result += generate_defn(cpp_sig)
+ return result
+
+ elif self.target is Target.ANONYMOUS_DEFINITION:
+ k = f.func.kind()
+
+ # Construct the body of the wrapper function with signature sig
+ sig_body = []
+ # We'll use context to keep track of any variables we've brought
+ # into scope while generating code
+ context: List[Union[Binding, Expr]] = list(sig.arguments())
+
+ # Initialize the class corresponding to this structured
+ # operator; feeding it the output argument(s) if it is known
+ if self.backend_index.dispatch_key is DispatchKey.Meta:
+ class_name = f"structured_{meta.name(self.g)}_meta_{k.name}"
+ parent_class = f"at::meta::structured_{meta.name(self.g)}"
+ elif (
+ self.backend_index.dispatch_key
+ is DispatchKey.CompositeExplicitAutogradNonFunctional
+ ):
+ # TODO: dedup this branch
+ class_name = f"structured_{meta.name(self.g)}_default_backend_{k.name}"
+ parent_class = f"at::meta::structured_{meta.name(self.g)}"
+ else:
+ metadata = self.backend_index.get_kernel(self.g)
+ assert metadata is not None
+ class_name = f"structured_{metadata.kernel}_{k.name}"
+ parent_class = f"{metadata.cpp_namespace}::structured_{metadata.kernel}"
+
+ if self.backend_index.device_guard:
+ device_check_args = itertools.chain(
+ f.func.arguments.out, f.func.arguments.flat_positional
+ )
+ sig_body.append(
+ RegisterDispatchKey.gen_device_check(
+ f.device_check, list(device_check_args), sig.name()
+ )
+ )
+
+ if k is SchemaKind.functional:
+ sig_body.append(f"{class_name} op;")
+ elif k is SchemaKind.inplace:
+ sig_body.append(f"{class_name} op(self);")
+ elif k is SchemaKind.out:
+ out_args_str = ", ".join(a.name for a in f.func.arguments.out)
+ sig_body.append(f"{class_name} op({out_args_str});")
+
+ # Translate the input native arguments into structured
+ # arguments for the meta call
+ meta_exprs = ", ".join(
+ e.expr
+ for e in translate(
+ context, structured.meta_arguments(self.g), method=False
+ )
+ )
+
+ if self.g.out.precomputed:
+ # If this function group has precomputed elements, the meta function
+ # returns a struct containing them which must be saved so that it
+ # can be unpacked when generating code to call the impl.
+ sig_body.append(f"auto precompute = op.meta({meta_exprs});")
+
+ # Put all of the contents of the precompute struct into the context
+ # so that translate will be able to return the correct args for the
+ # call to the impl.
+ precomputed_values = [
+ *self.g.out.precomputed.replace.values(),
+ self.g.out.precomputed.add,
+ ]
+ for precomputed_elems in precomputed_values:
+ for arg in precomputed_elems:
+ context.append(
+ Expr(
+ expr=f"precompute.{arg.name}",
+ type=structured.argument_type(arg, binds=arg.name),
+ )
+ )
+
+ # Add a use of the precompute struct so FB internal compilers don't
+ # complain that there is an unused variable.
+ sig_body.append("(void)precompute;")
+ else:
+ sig_body.append(f"op.meta({meta_exprs});")
+
+ # After running meta, op.outputs_ is guaranteed to be valid;
+ # add it to the context
+ out_args = structured.out_arguments(self.g)
+ for i, out_arg in enumerate(out_args):
+ assert ConstRefCType(BaseCType(tensorT)) == out_arg.nctype.type
+
+ if k is SchemaKind.out:
+ expr = f"op.maybe_get_output({i})"
+ else:
+ expr = f"op.outputs_[{i}]"
+
+ context.append(
+ Expr(
+ expr=expr,
+ # TODO: Stop hardcoding that the output type is a Tensor. Note
+ # that for the codegen here this is fine because outputs_ is
+ # hardcoded to be tensor already
+ type=NamedCType(
+ out_arg.nctype.name, MutRefCType(BaseCType(tensorT))
+ ),
+ )
+ )
+
+ # With the expanded context, do the impl call (if not a meta
+ # function)
+ if (
+ self.backend_index.dispatch_key
+ == DispatchKey.CompositeExplicitAutogradNonFunctional
+ ):
+ # TODO: https://github.com/pytorch/pytorch/issues/53023
+ out_sig_group = CppSignatureGroup.from_native_function(
+ self.g.out, method=False, fallback_binding=f.manual_cpp_binding
+ )
+ out_sig = out_sig_group.most_faithful_signature()
+ api_name = out_sig.name()
+ out_exprs = ", ".join(
+ e.expr
+ for e in translate(context, out_sig.arguments(), method=False)
+ )
+ # TODO: I think this means structured won't work with method
+ # only functions (but maybe you're saved by faithful? iunno.)
+ # NB: Originally I wrote this as an at::redispatch call, but
+ # I got in trouble because that meant I needed a DispatchKeySet
+ # in the wrapper function, which meant I needed a DispatchKeySet
+ # in the DispatchKeyFunctions declarations, but the defined API
+ # there does NOT permit a dispatch key set. I think you can
+ # probably unwind this by calling some function to do the TLS
+ # fetch and get the DispatchKeySet when you don't have it, but
+ # I didn't do it for this version
+ sig_body.append(f"at::{api_name}({out_exprs});")
+ elif self.backend_index.dispatch_key != DispatchKey.Meta:
+ impl_exprs = ", ".join(
+ e.expr
+ for e in translate(
+ context, structured.impl_arguments(self.g), method=False
+ )
+ )
+ sig_body.append(f"op.impl({impl_exprs});")
+
+ # Go over each output, and check if there is a proxy created for it.
+ # If so, copy it over to the original output.
+ if k is SchemaKind.out or k is SchemaKind.inplace:
+ for i in range(len(f.func.returns)):
+ sig_body.append(
+ f"if (op.proxy_outputs_[{i}].has_value()) op.outputs_[{i}].get().copy_(*op.proxy_outputs_[{i}]);"
+ )
+
+ # Destructively return the final tensors
+ # TODO: Do this in translate instead
+ if k is SchemaKind.functional:
+ if len(f.func.returns) == 1:
+ ret_expr = "std::move(op.outputs_[0])" # small optimization
+ else:
+ moved = ", ".join(
+ f"std::move(op.outputs_[{i}])"
+ for i in range(len(f.func.returns))
+ )
+ ret_expr = f"std::make_tuple({moved})"
+ elif k is SchemaKind.inplace:
+ ret_expr = "self"
+ elif k is SchemaKind.out:
+ if len(f.func.returns) == 1:
+ ret_expr = f.func.arguments.out[0].name
+ else:
+ refs = ", ".join(a.name for a in f.func.arguments.out)
+ ret_expr = f"std::forward_as_tuple({refs})"
+ sig_body.append(f"return {ret_expr};") # type: ignore[possibly-undefined] # TODO: audit
+
+ sig_body_str = "\n".join(sig_body)
+
+ # For an overview of what this template code looks like, see
+ # https://github.com/pytorch/rfcs/pull/9
+ return f"""\
+{self.gen_class(
+f, k,
+class_name=class_name,
+parent_class=parent_class,
+generate_super=self.g.out.structured_inherits is not None
+)}
+
+{sig.defn()} {{
+{sig_body_str}
+}}
+"""
+
+ elif self.target is Target.REGISTRATION:
+ return f'm.impl("{f.func.name}", TORCH_FN({sig.name()}));'
+ else:
+ assert_never(self.target)
+ # Silence mypy's "Missing return statement" error
+ return None
diff --git a/MLPY/Lib/site-packages/torchgen/dest/ufunc.py b/MLPY/Lib/site-packages/torchgen/dest/ufunc.py
new file mode 100644
index 0000000000000000000000000000000000000000..52268451ccfd19ea683fac751132581b1d3115db
--- /dev/null
+++ b/MLPY/Lib/site-packages/torchgen/dest/ufunc.py
@@ -0,0 +1,545 @@
+from dataclasses import dataclass
+from typing import Dict, List, Optional, Sequence, Tuple, Union
+
+import torchgen.api.ufunc as ufunc
+from torchgen.api.translate import translate
+from torchgen.api.types import (
+ BaseCType,
+ Binding,
+ CType,
+ Expr,
+ NamedCType,
+ opmath_t,
+ scalar_t,
+ StructuredImplSignature,
+ VectorizedCType,
+)
+from torchgen.api.ufunc import UfunctorBindings
+from torchgen.context import with_native_function
+from torchgen.model import (
+ Argument,
+ BaseTy,
+ BaseType,
+ DispatchKey,
+ NativeFunctionsGroup,
+ ScalarType,
+ UfuncKey,
+)
+from torchgen.utils import OrderedSet
+
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
+#
+# CUDA STUFF
+#
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
+
+# NB: not bothering to generate dispatch stub forward declaration in header,
+# we can just paste it wherever necessary
+
+# TODO: use BackendIndex
+# dispatch_key: DispatchKey # only CPU/CUDA right now
+
+
+# Represents functors for implementing CUDA ufuncs.
+# Functors are templated by scalar_t because when USERS instantiate functors
+# they are templated. A functor looks something like this:
+#
+# template <typename scalar_t>
+# struct CUDAFunctorOnSelf_add {
+# using opmath_t = at::opmath_type<scalar_t>;
+# opmath_t other_;
+# opmath_t alpha_;
+# CUDAFunctorOnSelf_add(opmath_t other, opmath_t alpha)
+# : other_(other), alpha_(alpha) {}
+# __device__ scalar_t operator()(scalar_t self) {
+# return ufunc::add(static_cast<opmath_t>(self), other_, alpha_);
+# }
+# };
+#
+@dataclass(frozen=True)
+class UfunctorSignature:
+ g: NativeFunctionsGroup
+ scalar_tensor_idx: Optional[int]
+ name: str
+
+ def arguments(self) -> UfunctorBindings:
+ return ufunc.ufunctor_arguments(
+ self.g, scalar_tensor_idx=self.scalar_tensor_idx, scalar_t=scalar_t
+ )
+
+ def fields(self) -> List[Binding]:
+ # fields are renamed to have a trailing underscore, as is conventional
+ return [b.rename(f"{b.name}_") for b in self.arguments().ctor]
+
+ def returns_type(self) -> CType:
+ # TODO: don't hardcode; return type will be inferred based on tags on
+ # the native function
+ return BaseCType(scalar_t)
+
+ def decl_fields(self) -> str:
+ return "\n".join(f"{f.type} {f.name};" for f in self.fields())
+
+ def inline_defn_ctor(self) -> str:
+ args_str = ", ".join(a.decl() for a in self.arguments().ctor)
+ # NB: hypothetically could do this with translate but the
+ # transition here is very regular
+ init_str = ", ".join(f"{a.name}_({a.name})" for a in self.arguments().ctor)
+ return f"{self.name}({args_str}) : {init_str} {{}}"
+
+ def decl_apply(self) -> str:
+ args_str = ", ".join(a.decl() for a in self.arguments().apply)
+ return f"{self.returns_type().cpp_type()} operator()({args_str}) const"
+
+
+@dataclass(frozen=True)
+class UfuncSignature:
+ g: NativeFunctionsGroup
+ name: str
+ compute_t: CType
+
+ def arguments(self) -> List[Binding]:
+ return ufunc.ufunc_arguments(self.g, compute_t=self.compute_t)
+
+ def call(self, ctx: Sequence[Union[Binding, Expr]]) -> str:
+ return f"{self.name}({', '.join(a.expr for a in translate(ctx, self.arguments()))})"
+
+
+# steps:
+# 1. take the functional signature
+# 2. use api.ufunc to convert it to template signature. this establishes
+# the type of the template function
+# 3. use api.ufunc (II) to generate a split struct / operator() signature.
+# this establish context in which we call the template signature
+#
+# StructuredImplSignature context
+# ~> functor constructor sig
+#
+# Functor constructor context
+# ~> functor fields sig
+#
+# Functor apply context (functor fields + functor apply sig)
+# ~> template sig
+#
+
+
+def eligible_for_binary_scalar_specialization(g: NativeFunctionsGroup) -> bool:
+ num_tensors = sum(
+ 1 for a in g.functional.func.arguments.flat_non_out if a.type.is_tensor_like()
+ )
+ return num_tensors == 2
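+# Illustrative: a schema such as "add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1)"
+# has exactly two tensor-like arguments and therefore qualifies for the binary
+# scalar specializations; a single-Tensor unary op does not.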
+
+
+def compute_ufunc_cuda_functors(
+ g: NativeFunctionsGroup,
+) -> Tuple[Dict[ScalarType, Dict[UfuncKey, UfunctorSignature]], str]:
+ # First, build the functors.
+ ufunctor_sigs: Dict[ScalarType, Dict[UfuncKey, UfunctorSignature]] = {}
+ ufunctors: List[str] = []
+ loops = g.out.ufunc_inner_loop
+ scalar_tensor_idx_lookup = {
+ UfuncKey.CUDAFunctorOnSelf: 1,
+ UfuncKey.CUDAFunctorOnOther: 0,
+ UfuncKey.CUDAFunctor: None,
+ }
+ if eligible_for_binary_scalar_specialization(g):
+ keys = [
+ UfuncKey.CUDAFunctorOnSelf,
+ UfuncKey.CUDAFunctorOnOther,
+ UfuncKey.CUDAFunctor,
+ ]
+ else:
+ keys = [UfuncKey.CUDAFunctor]
+ for k in [UfuncKey.CUDAFunctorOnSelf, UfuncKey.CUDAFunctorOnOther]:
+ assert k not in loops, f"cannot use {k} on non-binary function"
+ for k in keys:
+ # If the key was directly defined, skip functor codegen; we assume the
+ # user has already done it for us
+ if k in loops:
+ ufunctor_sig = UfunctorSignature(
+ g, scalar_tensor_idx=scalar_tensor_idx_lookup[k], name=loops[k].name
+ )
+ for dtype in loops[k].supported_dtypes:
+ ufunctor_sigs.setdefault(dtype, {})[k] = ufunctor_sig
+ continue
+
+ # Note [ScalarOnly and Generic must match names for CUDA]
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+ # Otherwise, look in ANY of the generic entries. For simplicity of
+ # codegen, if both ScalarOnly and Generic are defined, the ufunc name
+ # must match (if they didn't match, we'd have to generate distinct
+ # functors per dtype, which is awful, so we're not going to do it unless
+ # someone really forces us to)
+ ufunc_name = None
+ supported_dtypes: OrderedSet[ScalarType] = OrderedSet()
+ for lk in [UfuncKey.ScalarOnly, UfuncKey.Generic]:
+ if lk not in loops:
+ continue
+ if ufunc_name is None:
+ ufunc_name = loops[lk].name
+ else:
+ # See Note [ScalarOnly and Generic must match names for CUDA]
+ assert (
+ ufunc_name == loops[lk].name
+ ), "ScalarOnly and Generic must have same ufunc name"
+ supported_dtypes |= loops[lk].supported_dtypes
+ assert ufunc_name is not None
+
+ name = f"{k}_{ufunc_name}"
+ ufunctor_sig = UfunctorSignature(
+ g, scalar_tensor_idx=scalar_tensor_idx_lookup[k], name=name
+ )
+ for dtype in supported_dtypes:
+ ufunctor_sigs.setdefault(dtype, {})[k] = ufunctor_sig
+
+ ufunc_sig = UfuncSignature(
+ g, name=f"ufunc::{ufunc_name}", compute_t=BaseCType(opmath_t)
+ )
+ apply_ctx = ufunctor_sig.fields() + ufunctor_sig.arguments().apply
+ ufunctors.append(
+ f"""
+template <typename scalar_t>
+struct {ufunctor_sig.name} {{
+ using opmath_t = at::opmath_type<scalar_t>;
+ {ufunctor_sig.decl_fields()}
+ {ufunctor_sig.inline_defn_ctor()}
+ __device__ {ufunctor_sig.decl_apply()} {{
+ return {ufunc_sig.call(apply_ctx)};
+ }}
+}};
+"""
+ )
+
+ return ufunctor_sigs, "\n".join(ufunctors)
+
+
+@dataclass(frozen=True)
+class BinaryScalarSpecializationConfig:
+ scalar_idx: int
+ ctor_tensor: str
+ ufunc_key: UfuncKey
+
+
+BinaryScalarSpecializationConfigs = [
+ BinaryScalarSpecializationConfig(
+ scalar_idx=0,
+ ctor_tensor="self",
+ ufunc_key=UfuncKey.CUDAFunctorOnOther,
+ ),
+ BinaryScalarSpecializationConfig(
+ scalar_idx=1,
+ ctor_tensor="other",
+ ufunc_key=UfuncKey.CUDAFunctorOnSelf,
+ ),
+]
+
+
+def compute_ufunc_cuda_dtype_body(
+ g: NativeFunctionsGroup,
+ dtype: ScalarType,
+ inner_loops: Dict[UfuncKey, UfunctorSignature],
+ parent_ctx: Sequence[Binding],
+) -> str:
+ body = "using opmath_t = at::opmath_type;"
+ body += "if (false) {}\n" # for ease of codegen
+ for config in BinaryScalarSpecializationConfigs:
+ if config.ufunc_key not in inner_loops:
+ continue
+ ufunctor_sig = inner_loops[config.ufunc_key]
+ scalar_idx = config.scalar_idx + 1
+ # Make a copy and at the same time widen the type (not permissible
+ # without copy; we don't want to mutate the input argument anyway)
+ ctx: List[Union[Expr, Binding]] = list(parent_ctx)
+ ctx.append(
+ Expr(
+ expr=f"iter.scalar_value({scalar_idx})",
+ type=NamedCType(config.ctor_tensor, BaseCType(opmath_t)),
+ )
+ )
+ ufunctor_ctor_exprs_str = ", ".join(
+ a.expr for a in translate(ctx, ufunctor_sig.arguments().ctor)
+ )
+
+ # NB: ufunctor must be allocated before iter.remove_operand is called,
+ # as it relies on iter
+ body += f"""\
+else if (iter.is_cpu_scalar({scalar_idx})) {{
+ {ufunctor_sig.name}<scalar_t> ufunctor({ufunctor_ctor_exprs_str});
+ iter.remove_operand({scalar_idx});
+ gpu_kernel(iter, ufunctor);
+}}"""
+
+ ufunctor_sig = inner_loops[UfuncKey.CUDAFunctor]
+ ufunctor_ctor_exprs_str = ", ".join(
+ a.expr for a in translate(parent_ctx, ufunctor_sig.arguments().ctor)
+ )
+ body += f"""
+else {{
+ gpu_kernel(iter, {ufunctor_sig.name}<scalar_t>({ufunctor_ctor_exprs_str}));
+}}
+ """
+ return body
+
+
+@with_native_function
+def compute_ufunc_cuda(g: NativeFunctionsGroup) -> str:
+ # First, build the functors, indexing them by dtype
+ ufunctor_sigs, ufunctors = compute_ufunc_cuda_functors(g)
+
+ # Next, build the conditionals
+ sig = StructuredImplSignature(g, ufunc.kernel_name(g, DispatchKey.CUDA))
+ dtype_cases = []
+ for dtype, inner_ufunc_sigs in ufunctor_sigs.items():
+ dtype_cases.append(
+ f"""
+AT_DISPATCH_CASE(at::ScalarType::{dtype},
+ [&]() {{
+ {compute_ufunc_cuda_dtype_body(g, dtype, inner_ufunc_sigs, sig.arguments())}
+ }}
+)
+"""
+ )
+
+ dtype_cases_str = "\n".join(dtype_cases)
+
+ stub_sig = StubSignature(g)
+
+ return f"""
+{ufunctors}
+
+{stub_sig.type_defn()};
+{stub_sig.dispatch_decl()};
+
+{stub_sig.kernel_defn()} {{
+ AT_DISPATCH_SWITCH(iter.common_dtype(), "{sig.name}",
+ {dtype_cases_str}
+ );
+}}
+REGISTER_DISPATCH({stub_sig.name}, &{stub_sig.kernel_name});
+
+{sig.defn()} {{
+ {stub_sig.direct_call(sig.arguments())};
+}}
+"""
+
+
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
+#
+# CPU STUFF
+#
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
+
+
+@dataclass(frozen=True)
+class StubSignature:
+ g: NativeFunctionsGroup
+
+ @property
+ def name(self) -> str:
+ return f"{str(self.g.functional.func.name.name)}_stub"
+
+ @property
+ def kernel_name(self) -> str:
+ return f"{str(self.g.functional.func.name.name)}_kernel"
+
+ @property
+ def type_name(self) -> str:
+ return f"{str(self.g.functional.func.name.name)}_fn"
+
+ def arguments(self) -> List[Binding]:
+ return ufunc.stub_arguments(self.g)
+
+ def type(self) -> str:
+ cpp_args = self.arguments()
+ return f"void(*)(TensorIteratorBase&, {', '.join(a.type for a in cpp_args)})"
+
+ def dispatch_decl(self) -> str:
+ return f"DECLARE_DISPATCH({self.type_name}, {self.name})"
+
+ def dispatch_defn(self) -> str:
+ return f"DEFINE_DISPATCH({self.name})"
+
+ def kernel_defn(self) -> str:
+ return f"void {self.kernel_name}(TensorIteratorBase& iter, {', '.join(a.defn() for a in self.arguments())})"
+
+ def type_defn(self) -> str:
+ return f"using {self.type_name} = {self.type()}"
+
+ # must be called from context where this is TensorIteratorBase*
+ def call(self, ctx: Sequence[Binding]) -> str:
+ return f"{self.name}(device_type(), *this, {', '.join(a.expr for a in translate(ctx, self.arguments()))})"
+
+ # used in CUDA to skip the unnecessary dynamic dispatch
+ def direct_call(self, ctx: Sequence[Binding]) -> str:
+ return f"{self.kernel_name}(*this, {', '.join(a.expr for a in translate(ctx, self.arguments()))})"
+
+
+@with_native_function
+def compute_ufunc_cpu(g: NativeFunctionsGroup) -> str:
+ stub_sig = StubSignature(g)
+ sig = StructuredImplSignature(g, ufunc.kernel_name(g, DispatchKey.CPU))
+
+ return f"""
+{stub_sig.type_defn()};
+{stub_sig.dispatch_decl()};
+{stub_sig.dispatch_defn()};
+
+{sig.defn()} {{
+ {stub_sig.call(sig.arguments())};
+}}
+"""
+
+
+def compute_ufunc_cpu_dtype_body(
+ g: NativeFunctionsGroup,
+ dtype: ScalarType,
+ inner_loops: Dict[UfuncKey, UfuncSignature],
+ parent_ctx: Sequence[Binding],
+) -> str:
+ assert UfuncKey.CPUScalar in inner_loops, f"{dtype}, {inner_loops.keys()}"
+ assert inner_loops.keys() <= {UfuncKey.CPUScalar, UfuncKey.CPUVector}
+ scalar_loop = inner_loops[UfuncKey.CPUScalar]
+ vec_loop = None
+ if UfuncKey.CPUVector in inner_loops:
+ vec_loop = inner_loops[UfuncKey.CPUVector]
+
+ # NB: We DON'T use translate here, because translate is
+ # incapable of CSE'ing the scalar accesses in case it is also
+ # used by Vectorized; also, the unpacking here is very simple
+ # and only affects Scalar; everything else is implicitly captured
+ # by the lambda
+
+ # Setup scalar in scope
+ body = []
+ ctx = []
+ for b in parent_ctx:
+ if isinstance(b.argument, Argument) and b.argument.type != BaseType(
+ BaseTy.Scalar
+ ):
+ continue
+ body.append(f"auto _s_{b.name} = {b.name}.to();")
+ ctx.append(Expr(f"_s_{b.name}", NamedCType(b.nctype.name, BaseCType(scalar_t))))
+ if vec_loop is not None:
+ for b in parent_ctx:
+ if isinstance(b.argument, Argument) and b.argument.type != BaseType(
+ BaseTy.Scalar
+ ):
+ continue
+ body.append(
+ f"auto _v_{b.name} = at::vec::Vectorized(_s_{b.name});"
+ )
+ ctx.append(
+ Expr(
+ f"_v_{b.name}",
+ NamedCType(b.nctype.name, VectorizedCType(BaseCType(scalar_t))),
+ )
+ )
+
+ # Setup lambda signature
+ # NB: simplified version of ufunctor_arguments
+ scalar_bindings = []
+ vec_bindings = []
+ for a in g.functional.func.arguments.flat_non_out:
+ if not a.type.is_tensor_like():
+ continue
+ assert a.type == BaseType(BaseTy.Tensor)
+ scalar_bindings.append(
+ Binding(
+ name=a.name,
+ nctype=NamedCType(a.name, BaseCType(scalar_t)),
+ argument=a,
+ )
+ )
+ if vec_loop is not None:
+ vec_bindings.append(
+ Binding(
+ name=a.name,
+ nctype=NamedCType(a.name, VectorizedCType(BaseCType(scalar_t))),
+ argument=a,
+ )
+ )
+
+ def with_ctx(b: Sequence[Binding]) -> List[Union[Expr, Binding]]:
+ r: List[Union[Expr, Binding]] = []
+ r.extend(ctx)
+ r.extend(b)
+ return r
+
+ body_str = "\n".join(body)
+ if vec_loop is not None:
+ return f"""
+{body_str}
+cpu_kernel_vec(iter,
+ [=]({', '.join(b.decl() for b in scalar_bindings)}) {{ return {scalar_loop.call(with_ctx(scalar_bindings))}; }},
+ [=]({', '.join(b.decl() for b in vec_bindings)}) {{ return {vec_loop.call(with_ctx(vec_bindings))}; }}
+);
+"""
+ else:
+ return f"""
+{body_str}
+cpu_kernel(iter,
+ [=]({', '.join(b.decl() for b in scalar_bindings)}) {{ return {scalar_loop.call(with_ctx(scalar_bindings))}; }}
+);
+"""
+
+
+@with_native_function
+def compute_ufunc_cpu_kernel(g: NativeFunctionsGroup) -> str:
+ stub_sig = StubSignature(g)
+
+ # Reindex the ufunc by dtypes; processing generic/scalaronly as well
+ loops = g.out.ufunc_inner_loop
+ ufunc_sigs: Dict[ScalarType, Dict[UfuncKey, UfuncSignature]] = {}
+ for k in [UfuncKey.CPUScalar, UfuncKey.CPUVector]:
+ lks = []
+ # ORDER MATTERS: this specifies overriding precedence
+ if k in loops: # should happen rarely
+ lks.append(k)
+ if UfuncKey.ScalarOnly in loops and k is UfuncKey.CPUScalar:
+ lks.append(UfuncKey.ScalarOnly)
+ if UfuncKey.Generic in loops:
+ lks.append(UfuncKey.Generic)
+ # TODO: don't hardcode ufunc:: namespace here, should be centralized smh
+ for lk in lks:
+ for dtype in loops[lk].supported_dtypes:
+ compute_t: CType
+ if k is UfuncKey.CPUScalar:
+ compute_t = BaseCType(scalar_t)
+ elif k is UfuncKey.CPUVector:
+ compute_t = VectorizedCType(BaseCType(scalar_t))
+ else:
+ raise AssertionError()
+ inner_ufunc_sigs = ufunc_sigs.setdefault(dtype, {})
+ if k not in inner_ufunc_sigs:
+ inner_ufunc_sigs[k] = UfuncSignature(
+ g, name=f"ufunc::{loops[lk].name}", compute_t=compute_t
+ )
+
+ # Build the conditionals
+ dtype_cases = []
+ for dtype, inner_ufunc_sigs in ufunc_sigs.items():
+ dtype_cases.append(
+ f"""
+AT_DISPATCH_CASE(at::ScalarType::{dtype},
+ [&]() {{
+ {compute_ufunc_cpu_dtype_body(g, dtype, inner_ufunc_sigs, stub_sig.arguments())}
+ }}
+)
+"""
+ )
+
+ dtype_cases_str = "\n".join(dtype_cases)
+ return f"""
+namespace {{
+
+{stub_sig.kernel_defn()} {{
+ AT_DISPATCH_SWITCH(iter.common_dtype(), "{stub_sig.name}",
+ {dtype_cases_str}
+ );
+}}
+
+}} // anonymous namespace
+
+{stub_sig.type_defn()};
+{stub_sig.dispatch_decl()};
+REGISTER_DISPATCH({stub_sig.name}, &{stub_sig.kernel_name});
+"""
diff --git a/MLPY/Lib/site-packages/torchgen/executorch/__init__.py b/MLPY/Lib/site-packages/torchgen/executorch/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/MLPY/Lib/site-packages/torchgen/executorch/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torchgen/executorch/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..55d4acb331d655ed83b137760e9f9a42e4a53e51
Binary files /dev/null and b/MLPY/Lib/site-packages/torchgen/executorch/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torchgen/executorch/__pycache__/model.cpython-39.pyc b/MLPY/Lib/site-packages/torchgen/executorch/__pycache__/model.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2838abb7f2ff928f533ef48eb785b23486fbceb9
Binary files /dev/null and b/MLPY/Lib/site-packages/torchgen/executorch/__pycache__/model.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torchgen/executorch/__pycache__/parse.cpython-39.pyc b/MLPY/Lib/site-packages/torchgen/executorch/__pycache__/parse.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1fec643674a9ea0bd081348674b3f6d93fcabeba
Binary files /dev/null and b/MLPY/Lib/site-packages/torchgen/executorch/__pycache__/parse.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torchgen/executorch/api/__init__.py b/MLPY/Lib/site-packages/torchgen/executorch/api/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/MLPY/Lib/site-packages/torchgen/executorch/api/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torchgen/executorch/api/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..baeecad59b41a42c38ce93e9a63c6758b0884d31
Binary files /dev/null and b/MLPY/Lib/site-packages/torchgen/executorch/api/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torchgen/executorch/api/__pycache__/custom_ops.cpython-39.pyc b/MLPY/Lib/site-packages/torchgen/executorch/api/__pycache__/custom_ops.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7b8d9ca8a8895e68584df331a6941f80bc4815b1
Binary files /dev/null and b/MLPY/Lib/site-packages/torchgen/executorch/api/__pycache__/custom_ops.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torchgen/executorch/api/__pycache__/et_cpp.cpython-39.pyc b/MLPY/Lib/site-packages/torchgen/executorch/api/__pycache__/et_cpp.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e4f8ab3e0a78ddfd052420d5d5397f1e92938e13
Binary files /dev/null and b/MLPY/Lib/site-packages/torchgen/executorch/api/__pycache__/et_cpp.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torchgen/executorch/api/__pycache__/unboxing.cpython-39.pyc b/MLPY/Lib/site-packages/torchgen/executorch/api/__pycache__/unboxing.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b6b2ea7ca57d58ec900533657396dfb33dffd598
Binary files /dev/null and b/MLPY/Lib/site-packages/torchgen/executorch/api/__pycache__/unboxing.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torchgen/executorch/api/custom_ops.py b/MLPY/Lib/site-packages/torchgen/executorch/api/custom_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..2f606cbd00d9227437823ea7b1756b6538e8c55b
--- /dev/null
+++ b/MLPY/Lib/site-packages/torchgen/executorch/api/custom_ops.py
@@ -0,0 +1,142 @@
+from collections import defaultdict
+
+from dataclasses import dataclass
+from typing import Dict, List, Optional, Sequence, Tuple
+
+from torchgen import dest
+
+# disable import sorting to avoid circular dependency.
+from torchgen.api.types import DispatcherSignature # isort:skip
+from torchgen.context import method_with_native_function
+from torchgen.executorch.model import ETKernelIndex
+from torchgen.model import BaseTy, BaseType, DispatchKey, NativeFunction, Variant
+from torchgen.selective_build.selector import SelectiveBuilder
+from torchgen.utils import concatMap, Target
+
+
+# Generates RegisterKernelStub.cpp, which provides placeholder kernels for custom operators. This will be used on
+# the model authoring side.
+@dataclass(frozen=True)
+class ComputeNativeFunctionStub:
+ @method_with_native_function
+ def __call__(self, f: NativeFunction) -> Optional[str]:
+ if Variant.function not in f.variants:
+ return None
+
+ sig = DispatcherSignature.from_schema(
+ f.func, prefix=f"wrapper_CPU_{f.func.name.overload_name}_", symint=False
+ )
+ assert sig is not None
+ if len(f.func.returns) == 0:
+ ret_name = ""
+ elif len(f.func.returns) == 1:
+ if f.func.arguments.out:
+ ret_name = f.func.arguments.out[0].name
+ else:
+ ret_name = next(
+ (
+ a.name
+ for a in f.func.arguments.flat_non_out
+ if a.type == f.func.returns[0].type
+ ),
+ "",
+ )
+ if not ret_name:
+ # if return type is tensor
+ if f.func.returns[0].type == BaseType(BaseTy.Tensor):
+ # Returns an empty tensor
+ ret_name = "at::Tensor()"
+ else:
+ raise Exception(f"Can't handle this return type {f.func}")
+ elif len(f.func.arguments.out) == len(f.func.returns):
+ # Returns a tuple of out arguments
+ tensor_type = "at::Tensor &"
+ comma = ", "
+ ret_name = f"""::std::tuple<{comma.join([tensor_type] * len(f.func.returns))}>(
+ {comma.join([r.name for r in f.func.arguments.out])}
+ )"""
+ else:
+ assert all(
+ a.type == BaseType(BaseTy.Tensor) for a in f.func.returns
+ ), f"Only support tensor returns but got {f.func.returns}"
+ # Returns a tuple of empty tensors
+ tensor_type = "at::Tensor"
+ comma = ", "
+ ret_name = f"""::std::tuple<{comma.join([tensor_type] * len(f.func.returns))}>(
+ {comma.join(["at::Tensor()" for _ in f.func.returns])}
+ )"""
+ ret_str = f"return {ret_name};" if len(f.func.returns) > 0 else ""
+ return f"""
+{sig.defn()} {{
+ {ret_str}
+}}
+ """
+
+
+def gen_custom_ops_registration(
+ *,
+ native_functions: Sequence[NativeFunction],
+ selector: SelectiveBuilder,
+ kernel_index: ETKernelIndex,
+ rocm: bool,
+) -> Tuple[str, str]:
+ """
+ Generate custom ops registration code for dest.RegisterDispatchKey.
+
+ :param native_functions: a sequence of `NativeFunction`
+ :param selector: for selective build.
+ :param kernel_index: kernels for all the ops.
+ :param rocm: bool for dest.RegisterDispatchKey.
+ :return: generated C++ code to register custom operators into PyTorch
+ """
+
+ # convert kernel index to BackendIndex. This is because we can't handle ETKernelIndex yet.
+ # TODO larryliu: evaluate if this code is still needed. If yes let it handle ETKernelIndex.
+
+ dispatch_key = DispatchKey.CPU
+ backend_index = kernel_index._to_backend_index()
+ static_init_dispatch_registrations = ""
+ ns_grouped_native_functions: Dict[str, List[NativeFunction]] = defaultdict(list)
+ for native_function in native_functions:
+ ns_grouped_native_functions[native_function.namespace].append(native_function)
+
+ for namespace, functions in ns_grouped_native_functions.items():
+ if len(functions) == 0:
+ continue
+ dispatch_registrations_body = "\n".join(
+ list(
+ concatMap(
+ dest.RegisterDispatchKey(
+ backend_index,
+ Target.REGISTRATION,
+ selector,
+ rocm=rocm,
+ symint=False,
+ class_method_name=None,
+ skip_dispatcher_op_registration=False,
+ ),
+ functions,
+ )
+ )
+ )
+ static_init_dispatch_registrations += f"""
+TORCH_LIBRARY_IMPL({namespace}, {dispatch_key}, m) {{
+{dispatch_registrations_body}
+}};"""
+ anonymous_definition = "\n".join(
+ list(
+ concatMap(
+ dest.RegisterDispatchKey(
+ backend_index,
+ Target.ANONYMOUS_DEFINITION,
+ selector,
+ rocm=rocm,
+ symint=False,
+ class_method_name=None,
+ skip_dispatcher_op_registration=False,
+ ),
+ native_functions,
+ )
+ )
+ )
+ return anonymous_definition, static_init_dispatch_registrations
diff --git a/MLPY/Lib/site-packages/torchgen/executorch/api/et_cpp.py b/MLPY/Lib/site-packages/torchgen/executorch/api/et_cpp.py
new file mode 100644
index 0000000000000000000000000000000000000000..0b65745e277b1272aa8fdfd4cfc85c17ee9256d2
--- /dev/null
+++ b/MLPY/Lib/site-packages/torchgen/executorch/api/et_cpp.py
@@ -0,0 +1,368 @@
+from typing import List, Optional, Sequence, Set, Union
+
+from torchgen import local
+from torchgen.api.types import (
+ ArgName,
+ ArrayCType,
+ BaseCType,
+ Binding,
+ ConstRefCType,
+ CType,
+ MutRefCType,
+ NamedCType,
+ SpecialArgName,
+ TupleCType,
+ VectorCType,
+ voidT,
+)
+from torchgen.model import (
+ Argument,
+ Arguments,
+ BaseTy,
+ BaseType,
+ ListType,
+ NativeFunction,
+ OptionalType,
+ Return,
+ SelfArgument,
+ TensorOptionsArguments,
+ Type,
+)
+from torchgen.utils import assert_never
+from .types import (
+ ArrayRefCType,
+ BaseTypeToCppMapping,
+ OptionalCType,
+ scalarT,
+ tensorListT,
+ tensorT,
+)
+
+"""
+This file describes the translation of JIT schema to the public C++ API, which is what people use when they call
+functions like at::add. It also serves as a native function API, which is the signature of kernels,
+since in Executorch CppSignature is the same as NativeSignature.
+
+Difference between this file and torchgen.api.cpp.py:
+
+ - Executorch doesn't support TensorOptions, however in this file we still keep the logic here to be compatible with
+ torchgen.api.cpp, so that we can do stuff like ATen mode (running ATen kernels in Executorch).
+
+ - Executorch doesn't support Dimname.
+
+ - Executorch runtime doesn't support SymInt, will treat it as int.
+"""
+
+
+# Translation of "value types" in JIT schema to C++ API type. Value
+# types look the same no matter if they are argument types or return
+# types. Returns None if the type in question is not a value type.
+def valuetype_type(
+ t: Type,
+ *,
+ binds: ArgName,
+ remove_non_owning_ref_types: bool = False,
+) -> Optional[NamedCType]:
+ if isinstance(t, BaseType):
+ if t.name == BaseTy.Tensor or t.name == BaseTy.Scalar:
+ return None
+ # For SymInt we simply treat it as int.
+ elif str(t) == "SymInt":
+ return NamedCType(binds, BaseCType(BaseTypeToCppMapping[BaseTy.int]))
+ if remove_non_owning_ref_types:
+ if t.name == BaseTy.str:
+ raise AssertionError(
+ "string ref->value conversion: not implemented yet"
+ )
+ # All other BaseType currently map directly to BaseCppTypes.
+ return NamedCType(binds, BaseCType(BaseTypeToCppMapping[t.name]))
+ elif isinstance(t, OptionalType):
+ elem = valuetype_type(t.elem, binds=binds)
+ if elem is None:
+ return None
+ return NamedCType(binds, OptionalCType(elem.type))
+ elif isinstance(t, ListType):
+ if str(t.elem) == "bool":
+ assert t.size is not None
+ return NamedCType(
+ binds, ArrayCType(BaseCType(BaseTypeToCppMapping[BaseTy.bool]), t.size)
+ )
+ else:
+ return None
+ else:
+ raise AssertionError(f"unrecognized type {repr(t)}")
+
+
+# Translation of types occurring in JIT arguments to a C++ argument type.
+# If remove_non_owning_ref_types is set, we'll guarantee that the output CType is not a non-owning reference type.
+# For example, we'll return std::vector<int> instead of IntArrayRef.
+# See Note [translation from C++ reference to value types]
+def argumenttype_type(
+ t: Type,
+ *,
+ mutable: bool,
+ binds: ArgName,
+ remove_non_owning_ref_types: bool = False,
+) -> NamedCType:
+ # If it's a value type, do the value type translation
+ r = valuetype_type(
+ t,
+ binds=binds,
+ remove_non_owning_ref_types=remove_non_owning_ref_types,
+ )
+ if r is not None:
+ return r
+ if isinstance(t, BaseType):
+ if t.name == BaseTy.Tensor:
+ if mutable and not local.use_const_ref_for_mutable_tensors():
+ return NamedCType(binds, MutRefCType(BaseCType(tensorT)))
+ else:
+ return NamedCType(binds, ConstRefCType(BaseCType(tensorT)))
+ elif t.name == BaseTy.Scalar:
+ return NamedCType(binds, ConstRefCType(BaseCType(scalarT)))
+ else:
+ raise AssertionError(f"base type should have been value type {t}")
+ elif isinstance(t, OptionalType):
+ if str(t.elem) == "Tensor":
+ if mutable and not local.use_const_ref_for_mutable_tensors():
+ return NamedCType(
+ binds, MutRefCType(BaseCType(tensorT))
+ ) # TODO: fix this discrepancy
+ else:
+ return NamedCType(
+ binds, ConstRefCType(OptionalCType(BaseCType(tensorT)))
+ )
+ elif str(t.elem) == "Scalar":
+ return NamedCType(binds, ConstRefCType(OptionalCType(BaseCType(scalarT))))
+ elem = argumenttype_type(t.elem, mutable=mutable, binds=binds)
+ return NamedCType(binds, OptionalCType(elem.type))
+ elif isinstance(t, ListType):
+ # TODO: keeping these special cases for Tensor[] and Tensor?[] so that we can hookup with ATen kernels.
+ if str(t.elem) == "Tensor":
+ return NamedCType(binds, BaseCType(tensorListT))
+ elif str(t.elem) == "Dimname":
+ raise NotImplementedError("Executorch doesn't support Dimname")
+ elif str(t.elem) == "Tensor?":
+ return NamedCType(binds, ArrayRefCType(OptionalCType(BaseCType(tensorT))))
+ elem = argumenttype_type(t.elem, mutable=mutable, binds=binds)
+ return NamedCType(binds, ArrayRefCType(elem.type))
+ else:
+ raise AssertionError(f"unrecognized type {repr(t)}")
+
+
+# Translate a JIT argument into its C++ type
+def argument_type(a: Argument, *, binds: ArgName) -> NamedCType:
+ return argumenttype_type(a.type, mutable=a.is_write, binds=binds)
+
+
+# Translation of a (non-multi) return type from JIT to C++
+# N.B: returntype_type returns a CType, not a NamedCType.
+# This is mostly because of the mismatch between return types and return names.
+# e.g. a function with a return type of 'void' has 0 return names,
+# and a function with a return type of 'std::tuple' has >1 return name.
+def returntype_type(t: Type, *, mutable: bool) -> CType:
+ # placeholder is ignored
+ r = valuetype_type(t, binds="__placeholder__")
+ if r is not None:
+ return r.type
+
+ if isinstance(t, BaseType):
+ if t.name == BaseTy.Tensor:
+ if mutable:
+ if local.use_const_ref_for_mutable_tensors():
+ return ConstRefCType(BaseCType(tensorT))
+ else:
+ return MutRefCType(BaseCType(tensorT))
+ else:
+ # Note [Tensor Copy Returns]
+ # Currently, we use "Argument.is_write" to determine
+ # whether or not Tensor return types should be copies or references.
+ # If that ever changes, take a look at other locations of this note!
+ return BaseCType(tensorT)
+ elif t.name == BaseTy.Scalar:
+ return BaseCType(scalarT)
+ elif isinstance(t, ListType):
+ assert (
+ not mutable
+ ), "Native functions should never return a mutable tensor list. They should return void."
+ elem = returntype_type(t.elem, mutable=False)
+ assert t.size is None, f"fixed size list returns not supported: {t}"
+ return VectorCType(elem)
+
+ raise AssertionError(f"unrecognized return type {t}")
+
+
+# Translation of a single return to its C++ type
+def return_type(r: Return) -> CType:
+ return returntype_type(r.type, mutable=r.is_write)
+
+
+# Translation of a full (possibly multi) return from JIT to its C++ type
+def returns_type(rs: Sequence[Return]) -> CType:
+ if len(rs) == 0:
+ return BaseCType(voidT)
+ elif len(rs) == 1:
+ return return_type(rs[0])
+ else:
+ return TupleCType([return_type(r) for r in rs])
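+# Illustrative only: zero returns map to `void`, a single Tensor return to
+# `torch::executor::Tensor`, and two Tensor returns to roughly
+# `::std::tuple<torch::executor::Tensor,torch::executor::Tensor>`.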
+
+
+def return_names(f: NativeFunction, *, fallback_name: str = "result") -> Sequence[str]:
+ returns: List[str] = []
+ for i, r in enumerate(f.func.returns):
+ # If we have an inplace function, the return argument is
+ # implicitly named self.
+ # TODO: Consider incorporating this into the data model
+ if f.func.name.name.inplace:
+ assert i == 0, "illegal inplace function with multiple returns"
+ name = "self"
+ # If we are out function, the name is the name of the
+ # corresponding output function (r.name will get recorded
+ # in field_name later.)
+ elif f.func.is_out_fn():
+ name = f.func.arguments.out[i].name
+ # If the return argument is explicitly named...
+ elif r.name:
+ name_conflict = any(
+ r.name == a.name for a in f.func.schema_order_arguments()
+ )
+ if name_conflict and not f.func.is_out_fn():
+ name = f"{r.name}_return"
+ else:
+ name = r.name
+ # If there is no explicit name and no fallback name was passed in, we just name the output result,
+ # unless it's a multi-return, in which case it's result0,
+ # result1, etc (zero-indexed)
+ else:
+ name = fallback_name if len(f.func.returns) == 1 else f"{fallback_name}{i}"
+ returns.append(name)
+ return returns
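+# Illustrative only: for an out variant "foo.out(Tensor self, *, Tensor(a!) out)"
+# this returns ["out"]; two unnamed Tensor returns fall back to ["result0", "result1"].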
+
+
+JIT_TO_CPP_DEFAULT = {
+ "False": "false",
+ "True": "true",
+ "None": "torch::executorch::nullopt", # UGH this one is type directed
+ "[]": "{}",
+ "contiguous_format": "torch::executorch::MemoryFormat::Contiguous",
+ "long": "torch::executorch::kLong",
+}
+
+
+# Convert a JIT default into C++ expression representing the default
+def default_expr(d: str, t: Type) -> str:
+ if d == "None" and str(t) == "Tensor?":
+ return "{}"
+ if isinstance(t, BaseType) and t.name is BaseTy.str:
+ # Schema allows single quotes but C++ needs double
+ if len(d) >= 2 and d[0] == "'" and d[-1] == "'":
+ s = ""
+ i = 1
+ while i + 1 < len(d):
+ if d[i] != "\\":
+ if d[i] == '"':
+ s += '\\"'
+ else:
+ s += d[i]
+ i += 1
+ else:
+ if d[i + 1] == "'":
+ s += "'"
+ else:
+ s += d[i : i + 2]
+ i += 2
+
+ return f'"{s}"'
+
+ if isinstance(t, OptionalType):
+ if d == "None":
+ return "torch::executor::nullopt"
+
+ return default_expr(d, t.elem)
+
+ if isinstance(t, ListType):
+ if d.startswith("[") and d.endswith("]"):
+ return "{" + d[1:-1] + "}"
+ elif t.size is None:
+ # NOTE: Sized lists can have scalar defaults
+ raise ValueError(f"Expected a list default '[...]' but found: '{d}'")
+
+ return JIT_TO_CPP_DEFAULT.get(d, d)
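+# Illustrative only:
+# default_expr("True", <bool type>) evaluates to "true",
+# default_expr("[0, 1]", <int[] type>) to "{0, 1}",
+# and default_expr("None", <Tensor? type>) to "{}".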
+
+
+# Convert an argument into its C++ API form
+
+
+def argument(
+ a: Union[Argument, TensorOptionsArguments, SelfArgument],
+ *,
+ cpp_no_default_args: Set[str],
+ method: bool,
+ faithful: bool,
+ has_tensor_options: bool,
+) -> List[Binding]:
+ def sub_argument(
+ a: Union[Argument, TensorOptionsArguments, SelfArgument]
+ ) -> List[Binding]:
+ return argument(
+ a,
+ cpp_no_default_args=cpp_no_default_args,
+ method=method,
+ faithful=faithful,
+ has_tensor_options=has_tensor_options,
+ )
+
+ if isinstance(a, Argument):
+ binds: ArgName
+ if a.name == "memory_format" and has_tensor_options:
+ binds = SpecialArgName.possibly_redundant_memory_format
+ else:
+ binds = a.name
+ default: Optional[str] = None
+ if a.name not in cpp_no_default_args and a.default is not None:
+ default = default_expr(a.default, a.type)
+ return [
+ Binding(
+ nctype=argument_type(a, binds=binds),
+ name=a.name,
+ default=default,
+ argument=a,
+ )
+ ]
+ elif isinstance(a, TensorOptionsArguments):
+ raise NotImplementedError("Need to implement type resolution for TensorOptions")
+ elif isinstance(a, SelfArgument):
+ if method:
+ # Caller is responsible for installing implicit this in context!
+ return []
+ else:
+ return sub_argument(a.argument)
+ else:
+ assert_never(a)
+
+
+def arguments(
+ arguments: Arguments,
+ *,
+ faithful: bool,
+ method: bool,
+ cpp_no_default_args: Set[str],
+) -> List[Binding]:
+ args: List[Union[Argument, TensorOptionsArguments, SelfArgument]] = []
+ if faithful:
+ args.extend(arguments.non_out)
+ args.extend(arguments.out)
+ else:
+ args.extend(arguments.out)
+ args.extend(arguments.non_out)
+ return [
+ r.no_default() if faithful else r
+ for a in args
+ for r in argument(
+ a,
+ faithful=faithful,
+ method=method,
+ has_tensor_options=arguments.tensor_options is not None,
+ cpp_no_default_args=cpp_no_default_args,
+ )
+ ]
diff --git a/MLPY/Lib/site-packages/torchgen/executorch/api/types/__init__.py b/MLPY/Lib/site-packages/torchgen/executorch/api/types/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..883459aedfda7ec339de17e1f83da5d6f955f297
--- /dev/null
+++ b/MLPY/Lib/site-packages/torchgen/executorch/api/types/__init__.py
@@ -0,0 +1,2 @@
+from .types import *
+from .signatures import * # isort:skip
diff --git a/MLPY/Lib/site-packages/torchgen/executorch/api/types/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torchgen/executorch/api/types/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..000dcac3485d852136b63454ca209c3573ed40f0
Binary files /dev/null and b/MLPY/Lib/site-packages/torchgen/executorch/api/types/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torchgen/executorch/api/types/__pycache__/signatures.cpython-39.pyc b/MLPY/Lib/site-packages/torchgen/executorch/api/types/__pycache__/signatures.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a8b2f4c7c5bf77a4285ef36ac2bf6bc5c6e4417a
Binary files /dev/null and b/MLPY/Lib/site-packages/torchgen/executorch/api/types/__pycache__/signatures.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torchgen/executorch/api/types/__pycache__/types.cpython-39.pyc b/MLPY/Lib/site-packages/torchgen/executorch/api/types/__pycache__/types.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f32cb3d3d3acd94415a3d47dc8a8133d0d172f3c
Binary files /dev/null and b/MLPY/Lib/site-packages/torchgen/executorch/api/types/__pycache__/types.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torchgen/executorch/api/types/signatures.py b/MLPY/Lib/site-packages/torchgen/executorch/api/types/signatures.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c9c4dd95f5d1a85932e71691dce2e12b87077c3
--- /dev/null
+++ b/MLPY/Lib/site-packages/torchgen/executorch/api/types/signatures.py
@@ -0,0 +1,73 @@
+from dataclasses import dataclass
+from typing import List, Optional, Set
+
+import torchgen.api.cpp as aten_cpp
+
+from torchgen.api.types import Binding, CType
+from torchgen.model import FunctionSchema, NativeFunction
+
+from .types import contextArg
+
+
+@dataclass(frozen=True)
+class ExecutorchCppSignature:
+ """
+ This signature is merely a CppSignature with Executorch types (optionally
+ contains KernelRuntimeContext as well). The inline definition of
+ CppSignature is generated in Functions.h and it's used by unboxing
+ functions.
+ """
+
+ # The schema this signature is derived from
+ func: FunctionSchema
+
+ # The set of C++ arguments which should not have defaults applied to them
+ cpp_no_default_args: Set[str]
+
+ # Allows you to prepend an arbitrary prefix to the signature name.
+ # This is useful for parts of the codegen that generate wrappers around kernels,
+ # and need to avoid naming collisions.
+ prefix: str = ""
+
+ def arguments(self, *, include_context: bool = True) -> List[Binding]:
+ return ([contextArg] if include_context else []) + et_cpp.arguments(
+ self.func.arguments,
+ faithful=True, # always faithful, out argument at the end
+ method=False, # method not supported
+ cpp_no_default_args=self.cpp_no_default_args,
+ )
+
+ def name(self) -> str:
+ return self.prefix + aten_cpp.name(
+ self.func,
+ faithful_name_for_out_overloads=True,
+ )
+
+ def decl(self, name: Optional[str] = None, *, include_context: bool = True) -> str:
+ args_str = ", ".join(
+ a.decl() for a in self.arguments(include_context=include_context)
+ )
+ if name is None:
+ name = self.name()
+ return f"{self.returns_type().cpp_type()} {name}({args_str})"
+
+ def defn(self, name: Optional[str] = None) -> str:
+ args = [a.defn() for a in self.arguments()]
+ args_str = ", ".join(args)
+ if name is None:
+ name = self.name()
+ return f"{self.returns_type().cpp_type()} {name}({args_str})"
+
+ def returns_type(self) -> CType:
+ return et_cpp.returns_type(self.func.returns)
+
+ @staticmethod
+ def from_native_function(
+ f: NativeFunction, *, prefix: str = ""
+ ) -> "ExecutorchCppSignature":
+ return ExecutorchCppSignature(
+ func=f.func, prefix=prefix, cpp_no_default_args=f.cpp_no_default_args
+ )
+
+
+from torchgen.executorch.api import et_cpp
diff --git a/MLPY/Lib/site-packages/torchgen/executorch/api/types/types.py b/MLPY/Lib/site-packages/torchgen/executorch/api/types/types.py
new file mode 100644
index 0000000000000000000000000000000000000000..28b2f03b4b3e1fdc74511920ef88d0f54981dbe1
--- /dev/null
+++ b/MLPY/Lib/site-packages/torchgen/executorch/api/types/types.py
@@ -0,0 +1,81 @@
+from dataclasses import dataclass
+from typing import Dict
+
+from torchgen.api.types import (
+ BaseCppType,
+ BaseCType,
+ Binding,
+ boolT,
+ CType,
+ doubleT,
+ Expr,
+ longT,
+ MutRefCType,
+ NamedCType,
+)
+from torchgen.model import BaseTy
+
+halfT = BaseCppType("torch::executor", "Half")
+bfloat16T = BaseCppType("torch::executor", "BFloat16")
+stringT = BaseCppType("torch::executor", "string_view")
+scalarTypeT = BaseCppType("torch::executor", "ScalarType")
+tensorT = BaseCppType("torch::executor", "Tensor")
+tensorListT = BaseCppType("torch::executor", "TensorList")
+scalarT = BaseCppType("torch::executor", "Scalar")
+memoryFormatT = BaseCppType("torch::executor", "MemoryFormat")
+intArrayRefT = BaseCppType("torch::executor", "IntArrayRef")
+optionalT = BaseCppType("torch::executor", "optional")
+contextT = BaseCppType("torch::executor", "KernelRuntimeContext")
+
+contextExpr = Expr(
+ expr="context",
+ type=NamedCType(name="context", type=MutRefCType(BaseCType(contextT))),
+)
+
+contextArg = Binding(
+ name="context",
+ nctype=contextExpr.type,
+ argument=None, # type: ignore[arg-type]
+ default=None,
+)
+
+BaseTypeToCppMapping: Dict[BaseTy, BaseCppType] = {
+ BaseTy.int: longT,
+ BaseTy.float: doubleT,
+ BaseTy.bool: boolT,
+ BaseTy.str: stringT,
+ BaseTy.ScalarType: scalarTypeT,
+ BaseTy.Tensor: tensorT,
+ BaseTy.Scalar: scalarT,
+ BaseTy.MemoryFormat: memoryFormatT,
+}
+
+
+@dataclass(frozen=True)
+class OptionalCType(CType):
+ elem: "CType"
+
+ def cpp_type(self, *, strip_ref: bool = False) -> str:
+ # Do not pass `strip_ref` recursively.
+ return f"torch::executor::optional<{self.elem.cpp_type()}>"
+
+ def cpp_type_registration_declarations(self) -> str:
+ return f"torch::executor::optional<{self.elem.cpp_type_registration_declarations()}>"
+
+ def remove_const_ref(self) -> "CType":
+ return OptionalCType(self.elem.remove_const_ref())
+
+
+@dataclass(frozen=True)
+class ArrayRefCType(CType):
+ elem: "CType"
+
+ def cpp_type(self, *, strip_ref: bool = False) -> str:
+ # Do not pass `strip_ref` recursively.
+ return f"torch::executor::ArrayRef<{self.elem.cpp_type()}>"
+
+ def cpp_type_registration_declarations(self) -> str:
+ return f"torch::executor::ArrayRef<{self.elem.cpp_type_registration_declarations()}>"
+
+ def remove_const_ref(self) -> "CType":
+ return ArrayRefCType(self.elem.remove_const_ref())
diff --git a/MLPY/Lib/site-packages/torchgen/executorch/api/unboxing.py b/MLPY/Lib/site-packages/torchgen/executorch/api/unboxing.py
new file mode 100644
index 0000000000000000000000000000000000000000..9df3a929c1fdeab35f50ec9ab661c1ff9ad6c3af
--- /dev/null
+++ b/MLPY/Lib/site-packages/torchgen/executorch/api/unboxing.py
@@ -0,0 +1,213 @@
+from dataclasses import dataclass
+from typing import Callable, List, Sequence, Tuple
+
+from torchgen.api.types import Binding, CType, NamedCType
+from torchgen.model import (
+ Argument,
+ BaseTy,
+ BaseType,
+ ListType,
+ NativeFunction,
+ OptionalType,
+ Type,
+)
+
+connector = "\n\t"
+
+
+# Return unboxing function name for a NativeFunction
+def name(f: NativeFunction) -> str:
+ return f.func.name.unambiguous_name()
+
+
+@dataclass(frozen=True)
+class Unboxing:
+ """
+ Takes a sequence of Bindings and unboxes EValues into these Bindings. Returns generated code that performs the correct unboxing.
+ A sample generated code:
+ // aten::mul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+ void mul_out(EValue** stack) {
+ EValue& self = *stack[0];
+ EValue& other = *stack[1];
+ EValue& out = *stack[2];
+ const torch::executor::Tensor & self_base = self.to<torch::executor::Tensor>();
+ const torch::executor::Tensor & other_base = other.to<torch::executor::Tensor>();
+ torch::executor::Tensor & out_base = out.to<torch::executor::Tensor>();
+
+ EXECUTORCH_SCOPE_PROF("native_call_mul.out");
+ torch::executor::mul_outf(self_base, other_base, out_base);
+
+
+ }
+ """
+
+ # this is a callable that converts a JIT argument, into its C++ type.
+ # Translates (type, mutability, binds) to NamedCType. E.g., torchgen.api.cpp.argumenttype_type.
+ argument_type_gen: Callable[
+ ...,
+ NamedCType,
+ ]
+
+ # Convert all the arguments in a NativeFunction to C++ code
+ def convert_arguments(
+ self, args: Sequence[Binding]
+ ) -> Tuple[List[Binding], List[str]]:
+ code_list = [f"EValue& {args[i].name} = *stack[{i}];" for i in range(len(args))]
+ binding_list = []
+ for arg in args:
+ # expecting only Argument
+ if not isinstance(arg.argument, Argument):
+ raise Exception(
+ f"Unexpected argument type, expecting `Argument` but got {arg}"
+ )
+ argument: Argument = arg.argument
+ unboxed_name, _, code, decl = self.argumenttype_evalue_convert(
+ argument.type, argument.name, mutable=argument.is_write
+ )
+ code_list.extend(decl)
+ code_list.extend(code)
+ binding_list.append(arg.with_name(unboxed_name))
+ return binding_list, code_list
+
+ def argumenttype_evalue_convert(
+ self, t: Type, arg_name: str, *, mutable: bool = False
+ ) -> Tuple[str, CType, List[str], List[str]]:
+ """
+ Takes in the type, name and mutability corresponding to an argument, and generates a tuple of:
+ (1) the C++ code necessary to unbox the argument
+ (2) A Binding corresponding to the newly created unboxed variable, including variable name and its CType
+ :param t: a `Type` of an argument
+ :param arg_name: argument name
+ :param mutable: boolean for whether this argument type is mutable
+ :return: unboxed result
+ """
+ ctype = self.argument_type_gen(t, mutable=mutable, binds=arg_name).type
+
+ if isinstance(t, BaseType):
+ out_name = f"{arg_name}_base"
+ code, decl = self._gen_code_base_type(
+ arg_name=arg_name, out_name=out_name, ctype=ctype
+ )
+ elif isinstance(t, OptionalType):
+ out_name = f"{arg_name}_opt_out"
+ code, decl = self._gen_code_optional_type(
+ arg_name=arg_name, out_name=out_name, t=t, ctype=ctype
+ )
+ elif isinstance(t, ListType):
+ out_name = f"{arg_name}_list_out"
+ code, decl = self._gen_code_list_type(
+ arg_name=arg_name, out_name=out_name, t=t, ctype=ctype
+ )
+ else:
+ raise Exception(f"Cannot handle type {t}. arg_name: {arg_name}")
+ return out_name, ctype, code, decl
+
+ def _gen_code_base_type(
+ self, arg_name: str, out_name: str, ctype: CType
+ ) -> Tuple[List[str], List[str]]:
+ return [
+ f"{ctype.cpp_type()} {out_name} = {arg_name}.to<{ctype.cpp_type(strip_ref=True)}>();"
+ ], []
+
+ def _gen_code_optional_type(
+ self, arg_name: str, out_name: str, t: OptionalType, ctype: CType
+ ) -> Tuple[List[str], List[str]]:
+ in_name = f"{arg_name}_opt_in"
+ res_name, base_type, res_code, decl = self.argumenttype_evalue_convert(
+ t.elem, in_name
+ )
+ return (
+ f"""
+ {ctype.cpp_type(strip_ref=True)} {out_name} = {arg_name}.toOptional<{base_type.cpp_type(strip_ref=True)}>();
+ """.split(
+ "\n"
+ ),
+ decl,
+ )
+
+ def _gen_code_list_type(
+ self, arg_name: str, out_name: str, t: ListType, ctype: CType
+ ) -> Tuple[List[str], List[str]]:
+ in_name = f"{arg_name}_list_in"
+ elem_name = f"{arg_name}_elem"
+ code = []
+ res_name, res_ctype, res_code, decl = self.argumenttype_evalue_convert(
+ t.elem, elem_name
+ )
+
+ if isinstance(t.elem, BaseType) and t.elem.name == BaseTy.Tensor:
+ code.extend(
+ f"""
+ {ctype.cpp_type(strip_ref=True)} {out_name} = {arg_name}.toTensorList();
+ """.split(
+ "\n"
+ )
+ )
+ elif isinstance(t.elem, BaseType) and (
+ t.elem.name == BaseTy.int or t.elem.name == BaseTy.SymInt
+ ):
+ code.extend(
+ f"""
+ {ctype.cpp_type(strip_ref=True)} {out_name} = {arg_name}.toIntList();
+ """.split(
+ "\n"
+ )
+ )
+ elif isinstance(t.elem, BaseType) and t.elem.name == BaseTy.float:
+ code.extend(
+ f"""
+ {ctype.cpp_type(strip_ref=True)} {out_name} = {arg_name}.toDoubleList();
+ """.split(
+ "\n"
+ )
+ )
+ elif isinstance(t.elem, BaseType) and t.elem.name == BaseTy.bool:
+ # handle list type with size, e.g., bool[4]
+ code.extend(
+ f"""
+ {ctype.cpp_type(strip_ref=True)} {out_name} = {arg_name}.toBoolList();
+ """.split(
+ "\n"
+ )
+ )
+ # pytorch codegen:
+ # we have to use c10::List for optional element. e.g., Tensor?[] -> c10::List<c10::optional<at::Tensor>>
+ elif (
+ isinstance(t.elem, OptionalType)
+ and isinstance(t.elem.elem, BaseType)
+ and t.elem.elem.name == BaseTy.Tensor
+ ):
+ code.extend(
+ f"""
+#ifdef USE_ATEN_LIB
+at::ArrayRef<c10::optional<at::Tensor>> {in_name} = {arg_name}.toListOptionalTensor();
+c10::List<c10::optional<at::Tensor>> {out_name};
+for (auto {elem_name}: {in_name}) {{
+ {out_name}.push_back({elem_name});
+}}
+#else
+torch::executor::ArrayRef<torch::executor::optional<torch::executor::Tensor>> {out_name} = {arg_name}.toListOptionalTensor();
+#endif
+ """.split(
+ "\n"
+ )
+ )
+ else:
+ # use ArrayRef as default.
+ vec_name = arg_name + "_vec"
+ # need to bring vector instantiation out of scope so that ArrayRef has valid data
+ decl.append(
+ f"std::vector<{res_ctype.cpp_type(strip_ref=True)}> {vec_name};"
+ )
+ code.extend(
+ f"""
+ for (EValue {elem_name}: {in_name}) {{
+ {connector.join(res_code)}
+ {vec_name}.push_back({res_name});
+ }}
+ {ctype.cpp_type(strip_ref=True)} {out_name}({vec_name});
+ """.split(
+ "\n"
+ )
+ )
+ return code, decl
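The conversion above is easiest to see with a small sketch. The toy function below is not the real torchgen API; it only mimics the shape of the unboxing output that argumenttype_evalue_convert produces for a base-type argument versus a list argument, with schematic C++ on the right-hand side (a hoisted vector declaration plus a per-element loop in the list case).

    from typing import List, Tuple

    def toy_evalue_convert(arg_name: str, cpp_type: str, is_list: bool) -> Tuple[str, List[str], List[str]]:
        """Toy stand-in returning (out_name, code, decl); the real converter also returns a CType."""
        if not is_list:
            out_name = f"{arg_name}_base"
            return out_name, [f"{cpp_type} {out_name} = {arg_name}.to<{cpp_type}>();"], []
        # List case: the backing vector is declared out of scope (decl) so the
        # ArrayRef constructed from it stays valid after the loop.
        out_name, vec_name = f"{arg_name}_list_out", f"{arg_name}_vec"
        decl = [f"std::vector<{cpp_type}> {vec_name};"]
        code = [
            f"for (EValue {arg_name}_elem: {arg_name}_list_in) {{",
            f"  {vec_name}.push_back({arg_name}_elem.to<{cpp_type}>());",
            "}",
            f"ArrayRef<{cpp_type}> {out_name}({vec_name});",
        ]
        return out_name, code, decl

    name, code, decl = toy_evalue_convert("dim", "int64_t", is_list=False)
    # name == "dim_base"; code == ["int64_t dim_base = dim.to<int64_t>();"]; decl == []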
diff --git a/MLPY/Lib/site-packages/torchgen/executorch/model.py b/MLPY/Lib/site-packages/torchgen/executorch/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..48384e687403f4b04dbaa8a4ecfc97c4934606d0
--- /dev/null
+++ b/MLPY/Lib/site-packages/torchgen/executorch/model.py
@@ -0,0 +1,220 @@
+# Represents all kernels used by an Executorch model.
+# It maintains a Dict[OperatorName, Dict[ETKernelKey, BackendMetadata]] structure.
+
+import itertools
+from collections import defaultdict, namedtuple
+from dataclasses import dataclass
+from enum import IntEnum
+from typing import Dict, List, Tuple, Union
+
+from torchgen.model import (
+ BackendIndex,
+ BackendMetadata,
+ DispatchKey,
+ NativeFunction,
+ NativeFunctionsGroup,
+ OperatorName,
+)
+from torchgen.utils import assert_never
+
+KERNEL_KEY_VERSION = 1
+
+
+# TODO: Duplicated Subset from codegen.tool.gen_oplist, remove declaration in codegen
+class ScalarType(IntEnum):
+ Byte = 0
+ Char = 1
+ Short = 2
+ Int = 3
+ Long = 4
+ Float = 6
+ Double = 7
+ Bool = 11
+
+
+ETParsedYaml = namedtuple("ETParsedYaml", ["native_functions", "kernel_index"])
+
+
+@dataclass(frozen=True)
+class ETKernelKeyOpArgMeta:
+ arg_name: str
+ dtype: str
+ # The order of the dimensions if entry is a Tensor
+ dim_order: Tuple[int, ...]
+
+ def to_native_string(self) -> str:
+ dtype_str = ScalarType[self.dtype].value
+ dim_str = str(self.dim_order)[1:-1].replace(" ", "")
+ return f"{dtype_str};{dim_str}"
+
+
+@dataclass(frozen=True)
+class ETKernelKey:
+ # Field undefined is default = True
+ arg_meta: Tuple[ETKernelKeyOpArgMeta, ...] = ()
+
+ # Indicator for this kernel being used as a catch all
+ default: bool = False
+
+ version: int = KERNEL_KEY_VERSION
+
+ @staticmethod
+ def gen_from_yaml(
+ args: Dict[str, Tuple[str, str]],
+ type_alias_map: Dict[str, List[str]], # TODO: Support unwrapped str val
+ dim_order_alias_map: Dict[str, List[int]],
+ ) -> List["ETKernelKey"]:
+ """Generate ETKernelKeys from arg kernel specs
+ Multiple ETKernelKeys may be returned because every dtype permutation allowed by
+ type_alias_map is materialized as its own kernel key.
+
+ Args:
+ args: Mapping from argument name to kernel specs
+ Kernel specs are a tuple of (dtype, dim_order).
+ Currently tuple entries must be aliased via the alias map arguments
+ type_alias_map: Mapping from type alias to potential type enums
+ i.e { T0 : [Double, Int] } means T0 can be either Double or Int
+ Used for lookup by args
+ dim_order_alias_map: Mapping from alias to a list of dimension orders
+ Used for lookup by args
+ """
+ # Cast dim order alias values to int
+ dim_order_alias_map = {
+ k: [int(alias) for alias in v] for k, v in dim_order_alias_map.items()
+ }
+ kernel_keys = []
+
+ # Get all used Dtype Alias
+ dtype_alias_used = set()
+ for type_alias, dim_order in args.values():
+ # Enforce usage of alias initially
+ # TODO: Support inlined arguments
+ assert type_alias in type_alias_map, "Undefined type alias: " + str(
+ type_alias
+ )
+ assert (
+ dim_order in dim_order_alias_map
+ ), "Undefined dim_order alias: " + str(dim_order)
+ dtype_alias_used.add(type_alias)
+
+ # Generate all permutations of dtype alias values
+ alias_dtypes = [
+ [(alias, dtype) for dtype in type_alias_map[alias]]
+ for alias in dtype_alias_used
+ ]
+ alias_permutations = [
+ dict(permutation) for permutation in list(itertools.product(*alias_dtypes))
+ ]
+
+ # Using each alias value permutation, generate kernel keys
+ op_arg_cache = {}
+ for permutation in alias_permutations:
+ arg_list = []
+ for arg_name, arg_spec in args.items():
+ dtype = permutation[arg_spec[0]]
+ dim_order = dim_order_alias_map[arg_spec[1]] # type: ignore[assignment]
+ if (
+ cache_key := (arg_name, dtype, tuple(dim_order))
+ ) not in op_arg_cache:
+ op_arg_cache[cache_key] = ETKernelKeyOpArgMeta(*cache_key) # type: ignore[arg-type]
+
+ arg_list.append(op_arg_cache[cache_key])
+ kernel_keys.append(ETKernelKey(tuple(arg_list)))
+
+ return kernel_keys
+
+ def to_native_string(self) -> str:
+ if self.default:
+ return "default"
+ return (
+ "v"
+ + str(KERNEL_KEY_VERSION)
+ + "/"
+ + "|".join([arg.to_native_string() for arg in self.arg_meta])
+ )
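To make the alias expansion in gen_from_yaml concrete, here is a standalone sketch of its permutation step (the argument and alias names are invented): with T0 bound to two dtypes and T1 to one, exactly two dtype permutations, and hence two kernel keys, are produced.

    import itertools

    # Hypothetical kernel spec: argument name -> (type alias, dim-order alias)
    args = {"self": ("T0", "D0"), "other": ("T1", "D0")}
    type_alias_map = {"T0": ["Double", "Int"], "T1": ["Float"]}

    used_aliases = sorted({type_alias for type_alias, _ in args.values()})
    permutations = [
        dict(p)
        for p in itertools.product(
            *[[(alias, dtype) for dtype in type_alias_map[alias]] for alias in used_aliases]
        )
    ]
    assert permutations == [
        {"T0": "Double", "T1": "Float"},
        {"T0": "Int", "T1": "Float"},
    ]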
+
+
+@dataclass(frozen=True)
+class ETKernelIndex:
+ index: Dict[OperatorName, Dict[ETKernelKey, BackendMetadata]]
+
+ def has_kernels(self, g: Union[NativeFunction, NativeFunctionsGroup]) -> bool:
+ m = self.get_kernels(g)
+ return m is not None
+
+ def get_kernels(
+ self, g: Union[NativeFunction, NativeFunctionsGroup]
+ ) -> Dict[ETKernelKey, BackendMetadata]:
+ if isinstance(g, NativeFunction):
+ f = g
+ elif isinstance(g, NativeFunctionsGroup):
+ f = g.functional
+ else:
+ assert_never(g)
+ if f.func.name not in self.index:
+ return {}
+ return self.index[f.func.name]
+
+ @staticmethod
+ def grow_from_backend_indices(
+ kernel_index: Dict[OperatorName, Dict[ETKernelKey, BackendMetadata]],
+ backend_indices: Dict[DispatchKey, Dict[OperatorName, BackendMetadata]],
+ ) -> None:
+ for dk in backend_indices:
+ index = backend_indices[dk]
+ for op, backend_metadata in index.items():
+ if op in kernel_index:
+ kernel_index[op][ETKernelKey(default=True)] = backend_metadata
+ else:
+ kernel_index[op] = {ETKernelKey(default=True): backend_metadata}
+
+ @staticmethod
+ def from_backend_indices(
+ backend_indices: Dict[DispatchKey, Dict[OperatorName, BackendMetadata]]
+ ) -> "ETKernelIndex":
+ kernel_index: Dict[
+ OperatorName, Dict[ETKernelKey, BackendMetadata]
+ ] = defaultdict(dict)
+ ETKernelIndex.grow_from_backend_indices(kernel_index, backend_indices)
+ return ETKernelIndex(kernel_index)
+
+ def grow(
+ self, backend_indices: Dict[DispatchKey, Dict[OperatorName, BackendMetadata]]
+ ) -> "ETKernelIndex":
+ ETKernelIndex.grow_from_backend_indices(self.index, backend_indices)
+ return self
+
+ def _to_backend_index(self) -> BackendIndex:
+ """
+ WARNING: this will be deprecated once all the codegen places know how to handle ETKernelIndex.
+ """
+ index: Dict[OperatorName, BackendMetadata] = {}
+ for op in self.index:
+ kernel_dict = self.index[op]
+ assert (
+ len(kernel_dict.values()) == 1
+ ), f"Can't convert ETKernelIndex to BackendIndex because {op} has more than one kernels. Got {kernel_dict}"
+ index[op] = kernel_dict.get(
+ ETKernelKey(default=True),
+ BackendMetadata(kernel="", structured=False, cpp_namespace=""),
+ )
+ return BackendIndex(
+ dispatch_key=DispatchKey.CPU,
+ use_out_as_primary=False,
+ device_guard=False,
+ external=False,
+ index=index,
+ )
+
+ # Note duplicate ETKernelKey from index_b will clobber the metadata from index_a
+ @staticmethod
+ def merge_indices(
+ index_a: "ETKernelIndex", index_b: "ETKernelIndex"
+ ) -> "ETKernelIndex":
+ combined = defaultdict(dict, index_a.index.copy())
+
+ for op, entry in index_b.index.items():
+ for key, metadata in entry.items():
+ combined[op][key] = metadata
+
+ return ETKernelIndex(combined)
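A quick worked example of the serialized key format defined above; the values follow directly from the ScalarType enum and KERNEL_KEY_VERSION in this file (the argument name is arbitrary).

    from torchgen.executorch.model import ETKernelKey, ETKernelKeyOpArgMeta

    meta = ETKernelKeyOpArgMeta(arg_name="self", dtype="Float", dim_order=(0, 1))
    assert meta.to_native_string() == "6;0,1"                        # ScalarType.Float == 6
    assert ETKernelKey(arg_meta=(meta,)).to_native_string() == "v1/6;0,1"
    assert ETKernelKey(default=True).to_native_string() == "default"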
diff --git a/MLPY/Lib/site-packages/torchgen/executorch/parse.py b/MLPY/Lib/site-packages/torchgen/executorch/parse.py
new file mode 100644
index 0000000000000000000000000000000000000000..275fb339a84972613bbac0988fc2d9116552c675
--- /dev/null
+++ b/MLPY/Lib/site-packages/torchgen/executorch/parse.py
@@ -0,0 +1,151 @@
+from collections import defaultdict, namedtuple
+from typing import Any, Dict, List, Optional, Set, Tuple
+
+import yaml
+
+from torchgen.executorch.model import ETKernelIndex, ETKernelKey
+
+from torchgen.gen import LineLoader, parse_native_yaml
+from torchgen.model import (
+ BackendMetadata,
+ DispatchKey,
+ FunctionSchema,
+ NativeFunction,
+ OperatorName,
+)
+from torchgen.utils import NamespaceHelper
+
+# Parse native_functions.yaml into a sequence of NativeFunctions and ET Backend Indices.
+ETParsedYaml = namedtuple("ETParsedYaml", ["native_functions", "et_kernel_indices"])
+
+# Fields in native_functions.yaml used to determine which kernels should be used
+ET_FIELDS = ["kernels", "type_alias", "dim_order_alias"]
+
+
+def parse_from_yaml(ei: Dict[str, object]) -> Dict[ETKernelKey, BackendMetadata]:
+ """Given a loaded yaml representing kernel assignment information, extract the
+ mapping from `kernel keys` to `BackendMetadata` (the latter representing the kernel instance)
+
+ Args:
+ ei: Dict keys {kernels, type_alias, dim_order_alias}
+ See ETKernelKey for description of arguments
+ """
+ e = ei.copy()
+ if (kernels := e.pop("kernels", None)) is None:
+ return {}
+
+ type_alias: Dict[str, List[str]] = e.pop("type_alias", {}) # type: ignore[assignment]
+ dim_order_alias: Dict[str, List[str]] = e.pop("dim_order_alias", {}) # type: ignore[assignment]
+ dim_order_alias.pop("__line__", None)
+
+ kernel_mapping: Dict[ETKernelKey, BackendMetadata] = {}
+
+ for entry in kernels: # type: ignore[attr-defined]
+ arg_meta = entry.get("arg_meta")
+ if arg_meta is not None:
+ arg_meta.pop("__line__")
+
+ kernel_name = entry.get("kernel_name")
+ namespace_helper = NamespaceHelper.from_namespaced_entity(
+ kernel_name, max_level=3
+ )
+ kernel_namespace = namespace_helper.get_cpp_namespace(default="at")
+ backend_metadata = BackendMetadata(
+ kernel=namespace_helper.entity_name,
+ structured=False,
+ cpp_namespace=(kernel_namespace + "::native"),
+ )
+
+ kernel_keys = (
+ [ETKernelKey((), default=True)]
+ if arg_meta is None
+ else ETKernelKey.gen_from_yaml(arg_meta, type_alias, dim_order_alias) # type: ignore[arg-type]
+ )
+
+ for kernel_key in kernel_keys:
+ assert kernel_key not in kernel_mapping, (
+ "Duplicate kernel key: " + str(kernel_key) + " " + str(e)
+ )
+ kernel_mapping[kernel_key] = backend_metadata
+
+ return kernel_mapping
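Roughly, a loaded yaml entry consumed by parse_from_yaml has the shape sketched below. The operator, kernel name, and alias values are invented for illustration, and a real entry loaded through LineLoader would additionally carry __line__ keys that the parser strips.

    entry = {
        "kernels": [
            {
                "arg_meta": {"self": ("T0", "D0"), "out": ("T0", "D0")},
                "kernel_name": "custom::add_out",
            },
            # A kernel without arg_meta becomes the catch-all ETKernelKey(default=True).
        ],
        "type_alias": {"T0": ["Float", "Double"]},
        "dim_order_alias": {"D0": ["0", "1"]},
    }
    # For "custom::add_out" the resulting BackendMetadata has kernel="add_out" and
    # cpp_namespace="custom::native"; each dtype permutation of T0 gets its own key.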
+
+
+def parse_et_yaml_struct(es: object) -> ETKernelIndex:
+ """Given a loaded yaml representing a list of operators, for each op extract the mapping
+ of `kernel keys` to `BackendMetadata` (the latter representing the kernel instance
+ that should be used by the kernel key).
+ """
+ indices: Dict[OperatorName, Dict[ETKernelKey, BackendMetadata]] = {}
+ for ei in es: # type: ignore[attr-defined]
+ e = ei.copy()
+
+ funcs = e.pop("func")
+ assert isinstance(funcs, str), f"not a str: {funcs}"
+ namespace_helper = NamespaceHelper.from_namespaced_entity(
+ namespaced_entity=funcs, max_level=1
+ )
+ opname = FunctionSchema.parse(namespace_helper.entity_name).name
+
+ assert opname not in indices, f"Duplicate func found in yaml: {opname} already defined"
+
+ if len(index := parse_from_yaml(e)) != 0:
+ indices[opname] = index
+
+ return ETKernelIndex(indices)
+
+
+def extract_kernel_fields(es: object) -> Dict[OperatorName, Dict[str, Any]]:
+ """Given a loaded yaml representing a list of operators, extract the
+ kernel key related fields indexed by the operator name.
+ """
+ fields: Dict[OperatorName, Dict[str, Any]] = defaultdict(dict)
+ for ei in es: # type: ignore[attr-defined]
+ funcs = ei.get("func")
+ assert isinstance(funcs, str), f"not a str: {funcs}"
+ namespace_helper = NamespaceHelper.from_namespaced_entity(
+ namespaced_entity=funcs, max_level=1
+ )
+ opname = FunctionSchema.parse(namespace_helper.entity_name).name
+
+ for field in ET_FIELDS:
+ if (value := ei.get(field)) is not None:
+ fields[opname][field] = value
+
+ return fields
+
+
+def parse_et_yaml(
+ path: str,
+ tags_yaml_path: str,
+ ignore_keys: Optional[Set[DispatchKey]] = None,
+ skip_native_fns_gen: bool = False,
+) -> Tuple[List[NativeFunction], Dict[OperatorName, Dict[str, Any]]]:
+ """Parse native_functions.yaml into NativeFunctions and an Operator Indexed Dict
+ of fields to persist from native_functions.yaml to functions.yaml
+ """
+ with open(path) as f:
+ es = yaml.load(f, Loader=LineLoader)
+
+ et_kernel = extract_kernel_fields(es)
+
+ # Remove ET specific fields from entries for BC compatibility
+ strip_et_fields(es)
+
+ native_yaml = parse_native_yaml(
+ path,
+ tags_yaml_path,
+ ignore_keys,
+ skip_native_fns_gen=skip_native_fns_gen,
+ loaded_yaml=es,
+ )
+ return native_yaml.native_functions, et_kernel
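A hedged usage sketch of this entry point (both paths are placeholders):

    native_functions, et_fields = parse_et_yaml(
        "native_functions.yaml",   # placeholder path
        "tags.yaml",               # placeholder path
        ignore_keys=None,
        skip_native_fns_gen=True,
    )
    # et_fields maps OperatorName -> {"kernels": ..., "type_alias": ..., "dim_order_alias": ...}
    # for the operators that declared any of those fields.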
+
+
+def strip_et_fields(es: object) -> None:
+ """Given a loaded yaml representing a list of operators,
+ remove ET-specific fields from every entry for BC compatibility
+ """
+ for entry in es: # type: ignore[attr-defined]
+ for field in ET_FIELDS:
+ entry.pop(field, None)
diff --git a/MLPY/Lib/site-packages/torchgen/gen.py b/MLPY/Lib/site-packages/torchgen/gen.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e39839c020d04948fe9758bf3200c6230766831
--- /dev/null
+++ b/MLPY/Lib/site-packages/torchgen/gen.py
@@ -0,0 +1,2937 @@
+import argparse
+import functools
+import json
+import os
+import pathlib
+from collections import defaultdict, namedtuple, OrderedDict
+from dataclasses import dataclass, field
+from typing import (
+ Any,
+ Callable,
+ Dict,
+ List,
+ Literal,
+ Optional,
+ Sequence,
+ Set,
+ Tuple,
+ TypeVar,
+ Union,
+)
+
+import yaml
+
+import torchgen.api.dispatcher as dispatcher
+import torchgen.api.meta as meta
+import torchgen.api.native as native
+import torchgen.api.structured as structured
+import torchgen.dest as dest
+
+from torchgen.api import cpp
+from torchgen.api.translate import translate
+from torchgen.api.types import (
+ Binding,
+ CppSignature,
+ CppSignatureGroup,
+ DispatcherSignature,
+ NamedCType,
+ NativeSignature,
+ SpecialArgName,
+)
+from torchgen.context import (
+ method_with_native_function,
+ native_function_manager,
+ with_native_function,
+ with_native_function_and_indices,
+)
+from torchgen.gen_aoti_c_shim import (
+ gen_aoti_c_shim,
+ gen_static_dispatch_backend_call_signature,
+ get_backend_index_for_aoti,
+)
+from torchgen.gen_functionalization_type import (
+ gen_functionalization_definition,
+ gen_functionalization_registration,
+ gen_functionalization_view_inverse_declaration,
+ GenCompositeViewCopyKernel,
+)
+from torchgen.gen_vmap_plumbing import gen_all_vmap_plumbing
+
+from torchgen.model import (
+ Argument,
+ BackendIndex,
+ BackendMetadata,
+ BaseOperatorName,
+ DEFAULT_KERNEL_NAMESPACE,
+ DispatchKey,
+ FRAGMENT_NAMESPACES,
+ FunctionSchema,
+ is_cuda_dispatch_key,
+ is_generic_dispatch_key,
+ is_ufunc_dispatch_key,
+ Location,
+ NativeFunction,
+ NativeFunctionsGroup,
+ NativeFunctionsViewGroup,
+ OperatorName,
+ OptionalType,
+ SchemaKind,
+ SelfArgument,
+ STRUCTURED_DISPATCH_KEYS,
+ TensorOptionsArguments,
+ Type,
+ Variant,
+ ViewSchemaKind,
+)
+from torchgen.native_function_generation import (
+ add_generated_native_functions,
+ gen_composite_functional_kernel,
+ gen_composite_out_kernel,
+ pre_group_native_functions,
+)
+from torchgen.selective_build.selector import SelectiveBuilder
+from torchgen.utils import (
+ assert_never,
+ concatMap,
+ context,
+ FileManager,
+ make_file_manager,
+ mapMaybe,
+ NamespaceHelper,
+ Target,
+)
+from torchgen.yaml_utils import YamlDumper, YamlLoader
+
+T = TypeVar("T")
+
+# Welcome to the ATen code generator v2! The ATen code generator is
+# responsible for parsing native_functions.yaml and then generating
+# various generated files (e.g., TypeDefault.cpp) based on the operators
+# defined in this file. This means that the code generator knows how to
+# parse function schema, and then translate this into various C++ types
+# and boilerplate code.
+#
+# Some things to know about this file when you modify it:
+#
+# - This file has STRICT mypy typechecking. Typecheck it with
+# `mypy --config mypy-strict.ini` in the root source directory
+#
+# - Most of the heavy lifting lives in external modules:
+# - 'model' has the data model for native_functions.yaml. The classes
+# in that file represent what you see when you look at
+# a native_functions.yaml
+# - 'api' has conversions for how to translate JIT schema into
+# the various C++ APIs that the codegen interacts with. There
+# are in fact THREE different C++ APIs: the public C++ API,
+# the dispatcher API, and the legacy dispatcher API. See each
+# of these respective files for more information
+
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
+#
+# HELPER FUNCTIONS
+#
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
+
+
+# A custom loader for YAML to let us also keep track of line numbers
+# of each entry in the YAML file
+class LineLoader(YamlLoader):
+ def construct_mapping(self, node, deep=False): # type: ignore[no-untyped-def]
+ mapping = super().construct_mapping(node, deep=deep) # type: ignore[no-untyped-call]
+ # Add 1 so line numbering starts at 1
+ mapping["__line__"] = node.start_mark.line + 1
+ return mapping
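A minimal illustration of the line tracking (the yaml text is invented): every mapping constructed by LineLoader gains a __line__ key holding its 1-based starting line.

    import yaml

    text = "- func: add.Tensor(Tensor self, Tensor other) -> Tensor\n  variants: function\n"
    es = yaml.load(text, Loader=LineLoader)
    assert es[0]["__line__"] == 1
    assert es[0]["func"] == "add.Tensor(Tensor self, Tensor other) -> Tensor"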
+
+
+# Parse native_functions.yaml into a sequence of NativeFunctions and Backend Indices.
+ParsedYaml = namedtuple("ParsedYaml", ["native_functions", "backend_indices"])
+
+
+_GLOBAL_PARSE_NATIVE_YAML_CACHE: Dict[str, ParsedYaml] = {}
+_GLOBAL_PARSE_TAGS_YAML_CACHE: Dict[str, Set[str]] = {}
+
+
+def parse_native_yaml_struct(
+ es: object,
+ valid_tags: Set[str],
+ ignore_keys: Optional[Set[DispatchKey]] = None,
+ path: str = "",
+ skip_native_fns_gen: bool = False,
+) -> ParsedYaml:
+ assert isinstance(es, list)
+ rs: List[NativeFunction] = []
+ bs: Dict[DispatchKey, Dict[OperatorName, BackendMetadata]] = defaultdict(dict)
+ for e in es:
+ assert isinstance(e.get("__line__"), int), e
+ loc = Location(path, e["__line__"])
+ funcs = e.get("func")
+ with context(lambda: f"in {loc}:\n {funcs}"):
+ func, m = NativeFunction.from_yaml(e, loc, valid_tags, ignore_keys)
+ rs.append(func)
+ BackendIndex.grow_index(bs, m)
+ error_check_native_functions(rs)
+ # Default dict is to prevent the codegen from barfing when we have a dispatch key that has no kernels yet.
+ indices: Dict[DispatchKey, BackendIndex] = defaultdict(
+ lambda: BackendIndex(
+ dispatch_key=DispatchKey.Undefined,
+ use_out_as_primary=True,
+ external=False,
+ device_guard=False,
+ # I'm actually not sure about this; undefined could be hit on
+ # empty TensorList, hypothetically that could have sizes in it
+ index={},
+ )
+ )
+ if not skip_native_fns_gen:
+ add_generated_native_functions(rs, bs)
+ for k, v in bs.items():
+ # All structured in-tree operators are implemented in terms of their out operator.
+ indices[k] = BackendIndex(
+ dispatch_key=k,
+ use_out_as_primary=True,
+ external=False,
+ # Only cuda-like devices in tree require device guards
+ device_guard=is_cuda_dispatch_key(k),
+ index=v,
+ )
+ return ParsedYaml(rs, indices)
+
+
+def parse_tags_yaml_struct(es: object, path: str = "") -> Set[str]:
+ assert isinstance(es, list)
+ rs: Set[str] = set()
+ for e in es:
+ assert isinstance(e.get("__line__"), int), e
+ loc = Location(path, e["__line__"])
+ tags = e.get("tag")
+ with context(lambda: f"in {loc}:\n {tags}"):
+ e_i = e.copy()
+ name = e_i.pop("tag")
+ desc = e_i.pop("desc", "")
+ # ensure that each tag has a non-empty description
+ assert desc != ""
+ rs.add(name)
+ return rs
+
+
+@functools.lru_cache(maxsize=None)
+def parse_tags_yaml(path: str) -> Set[str]:
+ global _GLOBAL_PARSE_TAGS_YAML_CACHE
+ if path not in _GLOBAL_PARSE_TAGS_YAML_CACHE:
+ with open(path) as f:
+ es = yaml.load(f, Loader=LineLoader)
+ _GLOBAL_PARSE_TAGS_YAML_CACHE[path] = parse_tags_yaml_struct(es, path=path)
+
+ return _GLOBAL_PARSE_TAGS_YAML_CACHE[path]
+
+
+def parse_native_yaml(
+ path: str,
+ tags_yaml_path: str,
+ ignore_keys: Optional[Set[DispatchKey]] = None,
+ *,
+ skip_native_fns_gen: bool = False,
+ loaded_yaml: Optional[object] = None,
+) -> ParsedYaml:
+ global _GLOBAL_PARSE_NATIVE_YAML_CACHE
+ if path not in _GLOBAL_PARSE_NATIVE_YAML_CACHE:
+ valid_tags = parse_tags_yaml(tags_yaml_path)
+
+ # if a loaded yaml is provided, use that instead of reading from path
+ if loaded_yaml is None:
+ with open(path) as f:
+ es = yaml.load(f, Loader=LineLoader)
+ else:
+ es = loaded_yaml
+
+ _GLOBAL_PARSE_NATIVE_YAML_CACHE[path] = parse_native_yaml_struct(
+ es,
+ valid_tags,
+ ignore_keys,
+ path=path,
+ skip_native_fns_gen=skip_native_fns_gen,
+ )
+
+ return _GLOBAL_PARSE_NATIVE_YAML_CACHE[path]
+
+
+# Some assertions are already performed during parsing, but those are only within a single NativeFunction.
+# Assertions here are meant to be performed across NativeFunctions.
+def error_check_native_functions(funcs: Sequence[NativeFunction]) -> None:
+ func_map: Dict[OperatorName, NativeFunction] = {}
+ base_func_map: Dict[BaseOperatorName, List[NativeFunction]] = defaultdict(list)
+ for f in funcs:
+ func_map[f.func.name] = f
+ base_func_map[f.func.name.name].append(f)
+ for f in funcs:
+ if f.structured_delegate is not None:
+ delegate_func = func_map[f.structured_delegate]
+ assert delegate_func.structured, (
+ f"{f.func.name} is marked as a structured_delegate pointing to "
+ f"{f.structured_delegate}, but {f.structured_delegate} is not marked as structured. "
+ f"Consider adding 'structured=True' to the delegated operator"
+ )
+ # See Note [resize_ in Functionalization]
+ # resize_() is technically an inplace view op (and therefore needs the tag),
+ # but it would be overkill to add a true "view" variant of resize.
+ # Instead, resize_() gets special treatment in functionalization,
+ # and we have a resize() op that is non-aliasing + functional.
+ if (
+ "inplace_view" in f.tags
+ and str(f.func.name) != "resize_"
+ and str(f.func.name) != "resize_as_"
+ and str(f.func.name.name) != "set_"
+ ):
+ base_name = f.func.name.name
+ overload_name = f.func.name.overload_name
+ assert base_name.inplace, (
+ f"{f.func.name} is marked with tag: inplace_view, but it doesn't follow the naming "
+ "convention for inplace ops - the codegen expects the base name to have a trailing underscore. "
+ )
+ out_of_place_base_name = BaseOperatorName(
+ base_name.base, False, base_name.dunder_method
+ )
+ assert len(base_func_map[out_of_place_base_name]) > 0, (
+ f"{f.func.name} is marked with tag: inplace_view. The codegen expects there to be a corresponding "
+ f"out-of-place view op with the name '{base_name}' and matching schema, but it didn't find one. "
+ )
+
+
+def cpp_string(s: str) -> str:
+ """Convert a python string into a c++ string literal"""
+ s = s.replace("\\", "\\\\")
+ s = s.replace('"', '\\"')
+ s = s.replace("\a", "\\a")
+ s = s.replace("\b", "\\b")
+ s = s.replace("\f", "\\f")
+ s = s.replace("\n", "\\n")
+ s = s.replace("\v", "\\v")
+ s = s.replace("\t", "\\t")
+ return f'"{s}"'
+
+
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
+#
+# C++ CODE GENERATION
+#
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
+
+# Most functions in this section are curried: they consist of a function
+# that takes some parameters (e.g., what is to be generated) which itself
+# returns a function that actually maps NativeFunction to the code
+# to be generated. This pattern makes it convenient to use map, concatMap
+# and similar functional combinators.
+
+
+def static_dispatch_keys(backends: List[BackendIndex]) -> List[DispatchKey]:
+ if len(backends) == 0:
+ return []
+ else:
+ return [backend.dispatch_key for backend in backends] + [
+ DispatchKey.CompositeImplicitAutograd,
+ DispatchKey.CompositeImplicitAutogradNestedTensor,
+ DispatchKey.CompositeExplicitAutograd,
+ DispatchKey.CompositeExplicitAutogradNonFunctional,
+ ]
+
+
+def get_static_dispatch_backend(
+ f: NativeFunction, backend_index: BackendIndex
+) -> Optional[DispatchKey]:
+ if f.structured_delegate is not None or backend_index.has_kernel(f):
+ # TODO: for ops with structured_delegate it should check the dispatch table of
+ # the out variant instead. For now, these structured ops all have CPU/CUDA kernels
+ # so we always dispatch to the `backend`, but this could be wrong when we
+ # migrate math/default_backend ops to use structured delegate.
+ return backend_index.dispatch_key
+ elif f.has_composite_explicit_autograd_kernel:
+ return DispatchKey.CompositeExplicitAutograd
+ elif f.has_composite_explicit_autograd_non_functional_kernel:
+ return DispatchKey.CompositeExplicitAutogradNonFunctional
+ elif f.has_composite_implicit_autograd_kernel:
+ return DispatchKey.CompositeImplicitAutograd
+ elif f.has_composite_implicit_autograd_nested_tensor_kernel:
+ return DispatchKey.CompositeImplicitAutogradNestedTensor
+ return None
+
+
+def static_dispatch_ops_header(
+ f: NativeFunction, backend_index: List[BackendIndex]
+) -> Optional[str]:
+ if backend_index is None or f.manual_kernel_registration:
+ return None
+
+ output = []
+ for index in backend_index:
+ dispatch_key = get_static_dispatch_backend(f, index)
+ if dispatch_key is not None:
+ output.append(
+ f"#include "
+ )
+ return "\n".join(output)
+
+
+def static_dispatch_extra_headers(backends: List[BackendIndex]) -> List[str]:
+ return [
+ f"#include "
+ for dispatch_key in static_dispatch_keys(backends)
+ ]
+
+
+# Translates arguments of `sig` to CppSignature bindings.
+# Note that we have a special case for `memory_format` argument and this case is not covered by
+# tools.codegen.api.translate() yet as its application is limited to static dispatch.
+def translate_args(
+ sig: Union[CppSignature, DispatcherSignature],
+ cpp_sig: CppSignature,
+) -> str:
+ # Adds SpecialArgName.possibly_redundant_memory_format NamedCType for memory_format bindings
+ def add_spl_memory_format_binding(input_bindings: List[Binding]) -> List[Binding]:
+ output_bindings: List[Binding] = []
+ for binding in input_bindings:
+ if binding.name == "memory_format":
+ spl_mem_format_binding = Binding(
+ nctype=NamedCType(
+ SpecialArgName.possibly_redundant_memory_format,
+ binding.nctype.type,
+ ),
+ name=binding.name,
+ default=binding.default,
+ argument=binding.argument,
+ )
+ output_bindings.append(spl_mem_format_binding)
+ else:
+ output_bindings.append(binding)
+ return output_bindings
+
+ src_bindings = list(sig.arguments())
+ goal_bindings = list(cpp_sig.arguments())
+ # When last argument of CPP signature has SpecialArgName.possibly_redundant_memory_format NCType,
+ # get memory_format bindings of dispatcher signature to have the same NCType as well
+ for arg in goal_bindings:
+ if arg.nctype.name == SpecialArgName.possibly_redundant_memory_format:
+ src_bindings = add_spl_memory_format_binding(src_bindings)
+ break
+ exprs = translate(src_bindings, goal_bindings)
+ return ", ".join(a.expr for a in exprs)
+
+
+def generate_static_dispatch_backend_call(
+ sig: Union[CppSignature, DispatcherSignature],
+ f: NativeFunction,
+ backend_index: BackendIndex,
+) -> str:
+ cpp_sig = gen_static_dispatch_backend_call_signature(sig, f)
+ name = cpp_sig.name()
+ exprs = translate_args(sig, cpp_sig)
+ backend_metadata = backend_index.get_kernel(f)
+ kernel_ns = (
+ backend_metadata.cpp_namespace
+ if backend_metadata and backend_metadata.cpp_namespace
+ else DEFAULT_KERNEL_NAMESPACE
+ )
+ ns = kernel_ns.replace("::native", "")
+ return f"return {ns}::{backend_index.dispatch_key.lower()}::{name}({exprs});"
+
+
+def generate_static_dispatch_fallback_call(
+ sig: Union[CppSignature, DispatcherSignature],
+ f: NativeFunction,
+ backend_indices: List[BackendIndex],
+) -> str:
+ cpp_sigs = CppSignatureGroup.from_native_function(
+ f, method=False, fallback_binding=False
+ )
+ if sig.symint and f.func.has_symint():
+ cpp_sig = cpp_sigs.symint_signature
+ else:
+ cpp_sig = cpp_sigs.signature
+ assert cpp_sig is not None
+ name = cpp_sig.name()
+ exprs = translate_args(sig, cpp_sig)
+ ns = DEFAULT_KERNEL_NAMESPACE.replace("::native", "")
+ if f.has_composite_explicit_autograd_kernel:
+ return f"return {ns}::{DispatchKey.CompositeExplicitAutograd.lower()}::{name}({exprs});"
+ elif f.has_composite_explicit_autograd_non_functional_kernel:
+ return f"return {ns}::{DispatchKey.CompositeExplicitAutogradNonFunctional.lower()}::{name}({exprs});"
+ elif f.has_composite_implicit_autograd_kernel:
+ return f"return {ns}::{DispatchKey.CompositeImplicitAutograd.lower()}::{name}({exprs});"
+ elif f.has_composite_implicit_autograd_nested_tensor_kernel:
+ return f"return {ns}::{DispatchKey.CompositeImplicitAutogradNestedTensor.lower()}::{name}({exprs});"
+ else:
+ return f"""TORCH_CHECK(false, "Static dispatch does not support {name} for\
+{', '.join([str(index.dispatch_key) for index in backend_indices])} ");"""
+
+
+def static_dispatch(
+ sig: Union[CppSignature, DispatcherSignature],
+ f: NativeFunction,
+ backend_indices: List[BackendIndex],
+) -> str:
+ """
+ For a given `NativeFunction`, find out the corresponding backend and dispatch to it. If more than one
+ backend exists, fall back to determining the dispatch key from the inputs.
+ Arguments:
+ sig: A CppSignature or DispatcherSignature for this native function we want to use.
+ f: NativeFunction to generate static dispatch.
+ backend_indices: All available backends.
+ Return:
+ C++ code to call backend-specific functions, e.g., "return at::cpu::add(self, other, scale);"
+ """
+ if len(backend_indices) == 0 or f.manual_kernel_registration:
+ return ""
+
+ keys = [
+ b
+ for b in backend_indices
+ if b.has_kernel(f)
+ or (
+ f.structured_delegate is not None
+ and b.dispatch_key in STRUCTURED_DISPATCH_KEYS
+ )
+ ]
+ if len(keys) == 1:
+ return generate_static_dispatch_backend_call(sig, f, keys[0])
+ elif len(keys) == 0:
+ return generate_static_dispatch_fallback_call(sig, f, backend_indices)
+
+ native_tensor_args = [
+ a.name
+ for a in sig.arguments()
+ if isinstance(a.argument, SelfArgument)
+ or isinstance(a.argument, Argument)
+ and a.argument.type.is_tensor_like()
+ ]
+ tensor_args = ", ".join(native_tensor_args)
+ tensor_opts = f.func.arguments.tensor_options
+
+ stmts = []
+ subexprs: List[str] = []
+ if tensor_opts is not None:
+ subexprs.append(
+ "DispatchKeySet(c10::computeDispatchKey(dtype, layout, device))"
+ )
+ if tensor_args != "":
+ subexprs.append(f"c10::detail::multi_dispatch_key_set({tensor_args})")
+ stmts.append(f"""DispatchKeySet _dk_set = {' | '.join(subexprs)};""")
+ stmts.append("DispatchKey _dk = c10::highestPriorityBackendTypeId(_dk_set);")
+
+ dispatch_code = []
+ for index in keys:
+ dispatch_code.append(f"""case DispatchKey::{index.dispatch_key}:""")
+ dispatch_code.append(
+ f"""\t{generate_static_dispatch_backend_call(sig, f, index)};"""
+ )
+
+ fallback = generate_static_dispatch_fallback_call(sig, f, backend_indices)
+ connector = "\n\t\t"
+
+ return f"""
+ {connector.join(stmts)}
+ switch (_dk) {{
+ {connector.join(dispatch_code)}
+ default:
+ {fallback}
+ }}
+ """
+
+
+# Generates RegisterSchema.cpp. Depending on the selector, either
+# all schemas are registered, or only some are (in the case of
+# selective build)
+@dataclass(frozen=True)
+class RegisterSchema:
+ selector: SelectiveBuilder
+ known_tags: Dict[str, int] = field(default_factory=dict)
+
+ @method_with_native_function
+ def __call__(self, f: NativeFunction) -> Optional[str]:
+ if not self.selector.is_native_function_selected(f):
+ return None
+ tags = "{" + ", ".join(f"at::Tag::{tag}" for tag in sorted(f.tags)) + "}"
+ if tags == "{}":
+ return f"m.def({cpp_string(str(f.func))}, {{}});\n"
+ maybe_tags = ""
+ if tags not in self.known_tags:
+ idx = len(self.known_tags)
+ self.known_tags[tags] = idx
+ maybe_tags = f"const std::vector tags_{idx} = {tags};\n"
+ return f"{maybe_tags}m.def({cpp_string(str(f.func))}, tags_{self.known_tags[tags]});\n"
+
+
+# Generates Operators.h and Operators.cpp.
+# These provide macros that, given an operator and overload name, allow users
+# to access an "un-overloaded" function version of the operator. This
+# is useful for extension writers who (1) want to decltype the operator
+# and (2) don't want to worry about method-only operators.
+@dataclass(frozen=True)
+class ComputeOperators:
+ target: Literal[Target.DECLARATION, Target.DEFINITION]
+ static_dispatch_backend_indices: List[BackendIndex]
+
+ @method_with_native_function
+ def __call__(self, f: NativeFunction) -> str:
+ sig = DispatcherSignature.from_schema(f.func)
+ name = f.func.name.unambiguous_name()
+
+ if self.target is Target.DECLARATION:
+ # Note [The ATen Operators API]
+ # The ATen Operators API lives in the at::_ops namespace, and contains compile-time
+ # metadata about each operator + entry points into the Dispatcher.
+ # The C++ function, method, and redispatch API's are all implemented as wrappers
+ # into various bits of the structs defined here.
+ #
+ # Important characteristics about the Operators API:
+ # (1) It follows the Dispatcher API.
+ # This is kind of necessary to avoid overhead.
+ # For example: if it followed the C++ API, then all of the faithful C++ factory functions
+ # would need to wrap their arguments into TensorOptions only to unwrap them again.
+ # (2) Overload names are disambiguated.
+ # This is helpful for pytorch extenders who would like to decltype() an aten operator,
+ # that has overloads, e.g. decltype(at::_ops::mul_Tensor::call)
+ # (3) No argument defaulting is allowed.
+ # This is more of an implementation detail to avoid #include cycles,
+ # since TensorBody.h (which defines the Tensor class) needs to include this file.
+ # (4) manual_cpp_bindings and faithful names are not included in the API.
+ # This applies to stuff like __dispatch__is_complex(), and add_outf().
+ # These aren't "real aten ops", they're just additional functions provided by the C++ API.
+ # They're implemented as wrappers in Functions.h that call into the actual operators
+ # defined here, i.e. at::_ops::is_complex::call() and at::_ops::add_out::call().
+ # This means that ATEN_OP(is_complex) will not fastpath, and will go through the dispatcher.
+ return f"""
+struct TORCH_API {name} {{
+ using schema = {sig.type()};
+ using ptr_schema = schema*;
+ // See Note [static constexpr char* members for windows NVCC]
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::{f.func.name.name}")
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "{f.func.name.overload_name}")
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, {cpp_string(str(f.func))})
+ static {sig.defn(name="call", is_redispatching_fn=False)};
+ static {sig.defn(name="redispatch", is_redispatching_fn=True)};
+}};"""
+
+ elif self.target is Target.DEFINITION:
+ defns = f"""
+STATIC_CONST_STR_OUT_OF_LINE_FOR_WIN_CUDA({name}, name, "aten::{f.func.name.name}")
+STATIC_CONST_STR_OUT_OF_LINE_FOR_WIN_CUDA({name}, overload_name, "{f.func.name.overload_name}")
+STATIC_CONST_STR_OUT_OF_LINE_FOR_WIN_CUDA({name}, schema_str, {cpp_string(str(f.func))})
+
+// aten::{f.func}
+static C10_NOINLINE c10::TypedOperatorHandle<{name}::schema> create_{name}_typed_handle() {{
+ return c10::Dispatcher::singleton()
+ .findSchemaOrThrow({name}::name, {name}::overload_name)
+ .typed<{name}::schema>();
+}}
+"""
+ for is_redispatching_fn in [False, True]:
+ if is_redispatching_fn:
+ dispatcher_exprs_str = ", ".join(
+ ["dispatchKeySet"] + [a.name for a in sig.arguments()]
+ )
+ method_base = "redispatch"
+ else:
+ dispatcher_exprs_str = ", ".join([a.name for a in sig.arguments()])
+ method_base = "call"
+
+ dispatcher_call = method_base
+ method_name = f"{name}::{method_base}"
+
+ fn_body = f"""
+ static auto op = create_{name}_typed_handle();
+ return op.{dispatcher_call}({dispatcher_exprs_str});"""
+
+ if (
+ not is_redispatching_fn
+ and len(self.static_dispatch_backend_indices) > 0
+ ):
+ # call() should go through static dispatch
+ fn_body = static_dispatch(
+ sig, f, backend_indices=self.static_dispatch_backend_indices
+ )
+ defns += f"""
+// aten::{f.func}
+{sig.defn(name=method_name, is_redispatching_fn=is_redispatching_fn)} {{
+ {fn_body}
+}}
+"""
+ return defns
+ else:
+ assert_never(self.target)
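To make the DECLARATION branch concrete, the emitted struct for a hypothetical mul.Tensor overload would look roughly like the string below. This is schematic only; the real signature strings come from DispatcherSignature, so details will differ.

    expected_shape = """
    struct TORCH_API mul_Tensor {
      using schema = at::Tensor (const at::Tensor &, const at::Tensor &);
      using ptr_schema = schema*;
      STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::mul")
      STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "Tensor")
      STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "mul.Tensor(Tensor self, Tensor other) -> Tensor")
      static at::Tensor call(const at::Tensor & self, const at::Tensor & other);
      static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other);
    };
    """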
+
+
+# Generates Functions.h, which provides the functional public C++ API,
+# and the scaffolding to call into the dispatcher from these functions.
+@dataclass(frozen=True)
+class ComputeFunction:
+ @method_with_native_function
+ def __call__(self, f: NativeFunction) -> Optional[str]:
+ sig_group = CppSignatureGroup.from_native_function(
+ f, method=False, fallback_binding=f.manual_cpp_binding
+ )
+ has_symint = f.func.has_symint()
+
+ result = ""
+ for sig in sig_group.signatures():
+ # See Note [The ATen Operators API]
+ target_sig = DispatcherSignature.from_schema(f.func)
+ exprs = translate(sig.arguments(), target_sig.arguments())
+ exprs_str = ", ".join([e.expr for e in exprs])
+
+ if sig.symint:
+ intlike_t = "c10::SymInt"
+ else:
+ intlike_t = "int64_t"
+
+ if Variant.function in f.variants:
+ result += f"""
+// aten::{f.func}
+inline {sig.decl()} {{
+ return at::_ops::{f.func.name.unambiguous_name()}::call({exprs_str});
+}}"""
+
+ # The template function can be used in template contexts where you
+ # want to switch between the symint and non-symint version
+ # depending on a template argument
+ #
+ # NB: we ALWAYS generate this even for methods. But we put it in
+ # this header so it can take advantage of per-op headers
+ if has_symint:
+ result += f"""
+namespace symint {{
+ template <typename T, typename = std::enable_if_t<std::is_same<T, {intlike_t}>::value>>
+ {sig.decl(suppress_symint_suffix=True)} {{
+ return at::_ops::{f.func.name.unambiguous_name()}::call({exprs_str});
+ }}
+}}
+"""
+ return result
+
+
+# Generates TensorBody.h. This file provides the object-oriented (method-based)
+# public C++ API, and the scaffolding to call into the dispatcher from these functions.
+@dataclass(frozen=True)
+class ComputeTensorMethod:
+ target: Literal[Target.DECLARATION, Target.DEFINITION]
+ static_dispatch_backend_indices: List[BackendIndex]
+
+ @method_with_native_function
+ def __call__(self, f: NativeFunction) -> Optional[str]:
+ if Variant.method not in f.variants:
+ return None
+
+ assert not f.func.is_out_fn()
+ assert f.func.arguments.self_arg is not None
+
+ sig_group = CppSignatureGroup.from_native_function(
+ f, method=True, fallback_binding=f.manual_cpp_binding
+ )
+
+ if self.target is Target.DECLARATION:
+ result = ""
+ for sig in sig_group.signatures():
+ result += f"{sig.decl()} const;\n"
+ return result
+
+ if self.target is not Target.DEFINITION:
+ assert_never(self.target)
+
+ result = ""
+
+ for sig in sig_group.signatures():
+ target_sig = DispatcherSignature.from_schema(f.func)
+ exprs = translate(sig.arguments(), target_sig.arguments(), method=True)
+ exprs_str = ", ".join([e.expr for e in exprs])
+
+ result += f"""
+// aten::{f.func}
+inline {sig.defn(prefix="Tensor::")} const {{
+ return at::_ops::{f.func.name.unambiguous_name()}::call({exprs_str});
+}}
+"""
+
+ return result
+
+
+# Generates RedispatchFunctions.h.
+# This is similar to the C++ API defined in Functions.h, but provides access
+# to the dispatcher's redispatch API.
+@dataclass(frozen=True)
+class ComputeRedispatchFunction:
+ @method_with_native_function
+ def __call__(self, f: NativeFunction) -> Optional[str]:
+ # We unconditionally generate function variants of the redispatch API.
+ # This is mainly because we can namespace functions separately, but not methods,
+ sig_group = CppSignatureGroup.from_native_function(
+ f, method=False, fallback_binding=f.manual_cpp_binding
+ )
+
+ result = ""
+ for sig in sig_group.signatures():
+ target_sig = DispatcherSignature.from_schema(f.func)
+ exprs = translate(sig.arguments(), target_sig.arguments())
+ exprs_str = ", ".join(["dispatchKeySet"] + [a.expr for a in exprs])
+
+ result += f"""
+// aten::{f.func}
+inline {sig.decl(is_redispatching_fn=True)} {{
+ return at::_ops::{f.func.name.unambiguous_name()}::redispatch({exprs_str});
+}}
+"""
+
+ return result
+
+
+# Generates ATenOpList.cpp, a runtime accessible list of all aten
+# operators.
+# TODO: This was historically used to help some JIT interop code
+# figure out whether or not to treat aten namespace'd operators
+# one way or another, we should reevaluate if this is actually needed.
+@with_native_function
+def compute_aten_op(f: NativeFunction) -> str:
+ return f'{{"aten::{f.func.name.name}", "{f.func.name.overload_name}"}},'
+
+
+# Generates MetaFunctions.h
+def compute_meta_function_declaration(g: NativeFunctionsGroup) -> Optional[str]:
+ if not g.structured:
+ return None
+ with native_function_manager(g.out):
+ name = meta.name(g)
+ args = structured.meta_arguments(g)
+ args_str = ", ".join(a.decl() for a in args)
+ parent_class = g.out.structured_inherits
+ if parent_class is None:
+ parent_class = "at::impl::MetaBase"
+ meta_return = "void"
+ precomputed = g.out.precomputed if g.structured else None
+
+ if precomputed:
+ # Generate the template declaration with one bool parameter for each
+ # precomputed element. Each parameter is true if the corresponding (in
+ # terms of position) precomputed element has been set.
+ precomputed_values = [*precomputed.replace.values(), precomputed.add]
+ precomputed_elements = [
+ elem for replace_list in precomputed_values for elem in replace_list
+ ]
+ precomputed_template_parameters = [
+ elem.name.upper() for elem in precomputed_elements
+ ]
+ precomputed_template_params_str = ", ".join(
+ f"bool {param} = false" for param in precomputed_template_parameters
+ )
+ precompute_template_decl = f"template <{precomputed_template_params_str}>"
+
+ # Generate a string containing declarations of all precomputed elements.
+ precomputed_elements_with_cpp_types = [
+ structured.argument_type(elem, binds=elem.name)
+ for elem in precomputed_elements
+ ]
+
+ precomputed_elements_decl = ";\n".join(
+ f"{elem.cpp_type(strip_ref=True)} {elem.name}"
+ for elem in precomputed_elements_with_cpp_types
+ )
+
+ # Generate "setter" methods for each precomputed element. Each method will return
+ # a new instance of precompute_out with the template parameter that corresponds to
+ # the member set by the method to true (to indicate that it has been set).
+ setter_methods = []
+ for i, elem in enumerate(precomputed_elements):
+ # Generate the signature. The return type will be the same
+ # as the type of `this` but with the template parameter
+ # corresponding to the element set by this method set to true.
+ # The assert generated below will ensure that this template
+ # parameter is false on the type of `this`.
+ return_ty_templates = ", ".join(
+ precomputed_template_parameters[:i]
+ + ["true"]
+ + precomputed_template_parameters[i + 1 :]
+ )
+ return_ty = f"precompute_out<{return_ty_templates}>"
+ elem_cpp_ty = precomputed_elements_with_cpp_types[i].cpp_type(
+ strip_ref=True
+ )
+ signature = f"{return_ty} set_{elem.name}({elem_cpp_ty} value)"
+
+ # Generate an assert which checks that the
+ # template parameter corresponding to the precomputed
+ # element that is set by this method is false on the
+ # class corresponding to the object that `this` points to.
+ # This ensures that each element can be set only once.
+ assert_msg = f'"{precomputed_elements[i].name} already set"'
+ assert_stmt = f"static_assert({precomputed_template_parameters[i]} == false, {assert_msg});"
+
+ # Generate the new object construction block. All state
+ # except the element that this method sets is copied from the
+ # object that `this` points to. The value for the element that
+ # the method sets is taken from a method parameter.
+ construction_stmts = []
+ construction_stmts.append(f"{return_ty} ret;")
+
+ for j, elem in enumerate(precomputed_elements):
+ if i == j:
+ construction_stmts.append(f"ret.{elem.name} = value;")
+ else:
+ construction_stmts.append(
+ f"ret.{elem.name} = this->{elem.name};"
+ )
+
+ construction_stmts.append("return ret;")
+ construction_block = "\n".join(construction_stmts)
+
+ setter_methods.append(
+ f"""
+ {signature} {{
+ {assert_stmt}
+ {construction_block}
+ }}
+ """
+ )
+ setter_methods_decl = "\n".join(setter_methods)
+
+ # Meta should return an instance of the struct containing the precomputed elements.
+ meta_return_template_params = ", ".join(
+ ["true"] * len(precomputed_template_parameters)
+ )
+ # This typedef (actually a using statement) is needed so that TORCH_META_FUNC can reuse the return
+ # type (which has a variable number of template parameters).
+ meta_return_typedef = f"using meta_return_ty = precompute_out <{meta_return_template_params}>;"
+ meta_return = "meta_return_ty"
+ precomputed_decl = f"""
+ {precompute_template_decl}
+ struct TORCH_API precompute_out {{
+ {setter_methods_decl}
+ {precomputed_elements_decl};
+ }};"""
+ else:
+ meta_return_typedef = ""
+ precomputed_decl = ""
+
+ return f"""\
+struct TORCH_API structured_{name} : public {parent_class} {{
+ {precomputed_decl}
+ {meta_return_typedef}
+ {meta_return} meta({args_str});
+}};
+"""
+
+
+def needs_backend_select(f: NativeFunction, selector: SelectiveBuilder) -> bool:
+ name = str(f.func.name.name)
+ if name.endswith("_like") or name.startswith("new_"):
+ return False
+ if f.func.arguments.tensor_options is None:
+ return False
+ return selector.is_native_function_selected(f)
+
+
+# Generates RegisterBackendSelect.cpp, a series of kernels which provide
+# specialized computation of dispatch key for operator signatures which cannot
+# be easily done automatically using templating.
+@dataclass(frozen=True)
+class ComputeBackendSelect:
+ target: Literal[Target.DEFINITION, Target.REGISTRATION]
+
+ # Selector object to determine which operators to generate
+ # registration code for.
+ selector: SelectiveBuilder
+
+ @method_with_native_function
+ def __call__(self, f: NativeFunction) -> Optional[str]:
+ if not needs_backend_select(f, self.selector):
+ return None
+
+ name = native.name(f.func)
+ # BackendSelect can go to Meta, so it must preserve symints
+ native_sig = NativeSignature(f.func, symint=True)
+
+ native_tensor_args = [
+ a
+ for a in native_sig.arguments()
+ if isinstance(a.argument, Argument) and a.argument.type.is_tensor_like()
+ ]
+
+ dispatcher_sig = DispatcherSignature.from_schema(f.func)
+
+ sig: Union[NativeSignature, DispatcherSignature]
+ sig = dispatcher_sig
+ dispatcher_exprs = dispatcher_sig.exprs()
+ dispatch_key = "c10::computeDispatchKey(dtype, layout, device)"
+
+ if self.target is Target.DEFINITION:
+ # I don't think there's actually a good reason to generate
+ # these two cases differently
+ # The first case could probably be improved, though: it calls computeDispatchKeySet(),
+ # which looks at TLS dispatch keys; there should not be any by the time we reach backend select.
+ if native_tensor_args:
+ assert f.func.arguments.has_tensor_arg()
+ tensor_args = ", ".join(a.name for a in native_tensor_args)
+ compute_dk = f"""\
+DispatchKeySet _dk_set = c10::DispatchKeySet({dispatch_key}) | c10::detail::multi_dispatch_key_set({tensor_args});
+DispatchKeySet _dk_mask = c10::DispatchKeySet(DispatchKeySet::FULL_AFTER, DispatchKey::BackendSelect);
+DispatchKeySet _dk = c10::impl::computeDispatchKeySet(_dk_set, _dk_mask);"""
+ else:
+ assert not f.func.arguments.has_tensor_arg()
+ compute_dk = (
+ f"DispatchKeySet _dk = c10::DispatchKeySet({dispatch_key});"
+ )
+ return f"""\
+// aten::{f.func}
+C10_ALWAYS_INLINE
+{sig.defn(name)} {{
+ {compute_dk}
+ return at::_ops::{f.func.name.unambiguous_name()}::redispatch(
+ _dk, {', '.join(a.expr for a in dispatcher_exprs)});
+}}
+"""
+ elif self.target is Target.REGISTRATION:
+ return f"""m.impl("aten::{f.func.name}", TORCH_FN({name}));"""
+ else:
+ assert_never(self.target)
+
+
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
+#
+# YAML CODE GENERATION
+#
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
+
+
+def format_yaml(data: object) -> str:
+ # Ignore alias in Dumper
+ YamlDumper.ignore_aliases = lambda self, data: True # type: ignore[assignment]
+
+ # Support serializing OrderedDict
+ def dict_representer(dumper: Any, data: Any) -> Any:
+ return dumper.represent_dict(data.items())
+
+ YamlDumper.add_representer(OrderedDict, dict_representer) # type: ignore[no-untyped-call]
+ # Some yaml parsers (e.g. Haskell's) don't understand line breaks.
+ # width=1e9 turns off optional line breaks and improves
+ # the portability of the outputted yaml.
+ return yaml.dump(data, default_flow_style=False, Dumper=YamlDumper, width=1e9) # type: ignore[no-any-return, call-overload]
+
+
+# For some reason, some defaults we write to YAML are written as native
+# YAML objects, rather than doing them uniformly as strings. This
+# function detects those cases and converts them into native Python
+# objects.
+def pythonify_default(s: str) -> object:
+ if s == "true":
+ return True
+ elif s == "false":
+ return False
+
+ try:
+ return int(s)
+ except ValueError:
+ try:
+ return float(s)
+ except ValueError:
+ return s
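A few concrete conversions performed by this helper:

    assert pythonify_default("true") is True
    assert pythonify_default("1") == 1
    assert pythonify_default("0.5") == 0.5
    assert pythonify_default("contiguous_format") == "contiguous_format"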
+
+
+# What is a dynamic type? Over time, the semantic meaning of
+# dynamic type has degraded to meaninglessness (in the old days,
+# it captured dtype-ness of types, but that has gone away with
+# the removal of TH). These days, it's mostly the same thing as
+# the C++ API argument type, except that Tensor and Tensor?
+# arguments simply present as Tensor.
+#
+# TODO: Get rid of dynamic_type, after getting tools/autograd
+# to use the new codegen framework
+def dynamic_type(t: Type) -> str:
+ if isinstance(t, OptionalType):
+ return dynamic_type(t.elem)
+ # Note we don't use t.is_tensor_like() here because it would
+ # also include Tensor[]
+ if str(t) == "Tensor":
+ return "at::Tensor"
+ # This is a legacy concept, so never report SymInt
+ return cpp.argumenttype_type(
+ t, mutable=False, binds="__placeholder__", symint=False
+ ).cpp_type()
+
+
+def compute_method_of_yaml(variants: Set[Variant]) -> List[str]:
+ # This is written out explicitly to ensure that Tensor and
+ # namespace are put into the list in the right order
+ method_of = ["Type"]
+ if Variant.method in variants:
+ method_of.append("Tensor")
+ if Variant.function in variants:
+ method_of.append("namespace")
+ return method_of
+
+
+def compute_returns_yaml(
+ f: NativeFunction,
+) -> Tuple[List[Dict[str, str]], Dict[str, str]]:
+ # Note [name and field_name]
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~
+ # To understand name_to_field_name, we must first talk about this
+ # schema:
+ #
+ # lstsq.X(Tensor self, Tensor A, *, Tensor(a!) X, Tensor(b!) qr) -> (Tensor(a!) solution, Tensor(b!) QR)
+ #
+ # There is something very odd about this schema: it is an out
+ # variant of the function (that is to say, it will convert into
+ # at::lstsq_out() in the C++ API), but the names of the output
+ # return arguments don't match the keyword argument names of
+ # the inputs. It TURNS OUT that in this situation, the historical
+ # Declarations.yaml we want to output is this (abbreviated to
+ # only show relevant fields):
+ #
+ # arguments:
+ # ...
+ # - field_name: solution
+ # name: X
+ # - field_name: QR
+ # name: qr
+ # ...
+ #
+ # returns:
+ # - field_name: solution
+ # name: X
+ # - field_name: QR
+ # name: qr
+ #
+ # The name of the return fields is stored in 'field_name', and the
+ # name of the arguments is stored in 'name'. So when we process
+ # arguments, we need a way to get at the corresponding return. At
+ # the moment, this is most conveniently done by constructing a
+ # mapping from name (the argument concept) to field_name (the
+ # return concept) while processing return arguments, since we don't
+ # directly maintain this correspondence in the modeling of function
+ # schema itself.
+ #
+ # See also https://github.com/pytorch/pytorch/issues/43114
+ name_to_field_name: Dict[str, str] = {}
+
+ # Compute the returns field of the YAML entry
+ names = cpp.return_names(f)
+ returns = []
+ for i, (r, name) in enumerate(zip(f.func.returns, names)):
+ ret = {
+ "dynamic_type": dynamic_type(r.type),
+ "name": name,
+ # legacy, report ints
+ "type": cpp.return_type(r, symint=False).cpp_type(),
+ }
+
+ if r.name:
+ # See Note [name and field_name]
+ ret["field_name"] = r.name
+ if f.func.is_out_fn():
+ name_to_field_name[f.func.arguments.out[i].name] = r.name
+
+ returns.append(ret)
+
+ return returns, name_to_field_name
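For the lstsq.X schema discussed in the note above, the computed structures come out roughly as follows (only the name/field_name pairs are shown):

    expected_returns = [
        {"name": "X", "field_name": "solution"},    # plus "dynamic_type" / "type" entries
        {"name": "qr", "field_name": "QR"},
    ]
    expected_name_to_field_name = {"X": "solution", "qr": "QR"}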
+
+
+# arguments in yaml roughly corresponds to the public C++ API
+def compute_cpp_argument_yaml(
+ cpp_a: Binding,
+ *,
+ schema_order: bool,
+ kwarg_only_set: Set[str],
+ out_arg_set: Set[str],
+ name_to_field_name: Dict[str, str],
+) -> object:
+ if isinstance(cpp_a.argument, TensorOptionsArguments):
+ arg: Dict[str, object] = {
+ "annotation": None,
+ "dynamic_type": "at::TensorOptions",
+ "is_nullable": False,
+ "name": cpp_a.name,
+ "type": cpp_a.type,
+ "kwarg_only": True,
+ }
+ if cpp_a.default is not None:
+ arg["default"] = cpp_a.default
+ return arg
+ elif isinstance(cpp_a.argument, SelfArgument):
+ raise AssertionError()
+ elif isinstance(cpp_a.argument, Argument):
+ return compute_argument_yaml(
+ cpp_a.argument,
+ schema_order=schema_order,
+ kwarg_only_set=kwarg_only_set,
+ out_arg_set=out_arg_set,
+ name_to_field_name=name_to_field_name,
+ )
+
+
+def compute_argument_yaml(
+ a: Argument,
+ *,
+ schema_order: bool,
+ kwarg_only_set: Set[str],
+ out_arg_set: Set[str],
+ name_to_field_name: Dict[str, str],
+) -> object:
+ arg: Dict[str, object] = {
+ "annotation": str(a.annotation) if a.annotation else None,
+ "dynamic_type": dynamic_type(a.type),
+ "is_nullable": a.type.is_nullable(),
+ "name": a.name,
+ # legacy, report ints
+ "type": cpp.argument_type(a, binds="__placeholder__", symint=False).cpp_type(),
+ }
+ if a.default is not None:
+ arg["default"] = pythonify_default(
+ cpp.default_expr(a.default, a.type, symint=False)
+ )
+ if a.name in kwarg_only_set:
+ arg["kwarg_only"] = True
+ if a.name in out_arg_set:
+ arg["output"] = True
+ arg["allocate"] = True
+ # See Note [name and field_name]
+ if a.name in name_to_field_name:
+ arg["field_name"] = name_to_field_name[a.name]
+ # Historically, booleans don't get their size recorded, because it
+ # is already built into the cpp type (e.g., std::array<bool, N>)
+ l = a.type.is_list_like()
+ if l is not None and l.size is not None and str(l.elem) != "bool":
+ arg["size"] = l.size
+ return arg
+
+
+@with_native_function
+def compute_declaration_yaml(f: NativeFunction) -> object:
+ returns, name_to_field_name = compute_returns_yaml(f)
+
+ # These sets are used to conveniently test if an argument is a
+ # kwarg-only or out argument
+ kwarg_only_set = {a.name for a in f.func.arguments.flat_kwarg_only}
+ out_arg_set = {a.name for a in f.func.arguments.out}
+
+ sig_group = CppSignatureGroup.from_native_function(
+ f, method=False, fallback_binding=False
+ )
+ cpp_args = sig_group.signature.arguments()
+ arguments = [
+ compute_cpp_argument_yaml(
+ cpp_a,
+ schema_order=False,
+ kwarg_only_set=kwarg_only_set,
+ out_arg_set=out_arg_set,
+ name_to_field_name=name_to_field_name,
+ )
+ for cpp_a in cpp_args
+ ]
+
+ schema_order_jit_arguments = list(f.func.schema_order_arguments())
+
+ schema_order_arguments = [
+ compute_argument_yaml(
+ a,
+ schema_order=True,
+ kwarg_only_set=kwarg_only_set,
+ out_arg_set=out_arg_set,
+ name_to_field_name=name_to_field_name,
+ )
+ for a in schema_order_jit_arguments
+ ]
+
+ cpp_schema_order_types = [
+ # NB: method here doesn't matter
+ r.type
+ for a in schema_order_jit_arguments
+ for r in cpp.argument(
+ a,
+ method=False,
+ cpp_no_default_args=set(),
+ faithful=False,
+ symint=False,
+ has_tensor_options=False,
+ )
+ ]
+
+ # legacy, report ints
+ cpp_returns = cpp.returns_type(f.func.returns, symint=False).cpp_type()
+ schema_order_cpp_signature = f"{cpp_returns} ({', '.join(cpp_schema_order_types)})"
+
+ is_factory_method = (
+ any(isinstance(a.argument, TensorOptionsArguments) for a in cpp_args)
+ and Variant.method not in f.variants
+ )
+
+ return OrderedDict(
+ [
+ ("name", cpp.name(f.func)),
+ ("operator_name", str(f.func.name.name)),
+ ("overload_name", str(f.func.name.overload_name)),
+ ("manual_kernel_registration", f.manual_kernel_registration),
+ (
+ "category_override",
+ f.category_override if f.category_override is not None else "",
+ ),
+ ("schema_string", f"aten::{f.func}"),
+ ("arguments", arguments),
+ ("schema_order_cpp_signature", schema_order_cpp_signature),
+ ("schema_order_arguments", schema_order_arguments),
+ ("method_of", compute_method_of_yaml(f.variants)),
+ ("mode", "native"),
+ ("python_module", "" if f.python_module is None else f.python_module),
+ ("returns", returns),
+ ("inplace", f.func.name.name.inplace),
+ ("is_factory_method", is_factory_method),
+ ("abstract", f.is_abstract),
+ ("device_guard", f.device_guard),
+ ("with_gil", False),
+ ("deprecated", False),
+ ("has_math_kernel", f.has_composite_implicit_autograd_kernel),
+ ]
+ )
+
+
+# See Note [Auto generated composite kernels]
+def has_autogenerated_composite_kernel(f: NativeFunction) -> bool:
+ return (f.structured or f.structured_delegate is not None) and (
+ f.func.kind() == SchemaKind.functional or f.func.kind() == SchemaKind.inplace
+ )
+
+
+@with_native_function_and_indices
+def compute_registration_declarations(
+ f: NativeFunction, backend_indices: Dict[DispatchKey, BackendIndex]
+) -> str:
+ name = dispatcher.name(f.func)
+ returns_type = dispatcher.returns_type(
+ f.func.returns
+ ).cpp_type_registration_declarations()
+ args = dispatcher.arguments(f.func)
+ args_str = ", ".join(a.no_default().decl_registration_declarations() for a in args)
+ comment_data: Dict[str, str] = {
+ "schema": f"aten::{f.func}",
+ # TODO: What exactly are the semantics of the 'dispatch' field?
+ "dispatch": str(
+ {k for k, v in backend_indices.items() if v.has_kernel(f)}
+ != {DispatchKey.CompositeImplicitAutograd}
+ and {k for k, v in backend_indices.items() if v.has_kernel(f)}
+ != {
+ DispatchKey.CompositeImplicitAutograd,
+ DispatchKey.CompositeImplicitAutogradNestedTensor,
+ }
+ ),
+ "default": str(f.has_composite_kernel or has_autogenerated_composite_kernel(f)),
+ }
+ return f"""{returns_type} {name}({args_str}); // {json.dumps(comment_data)}
+"""
+
+
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
+#
+# RUN IT ALL
+#
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
+
+
+def get_custom_build_selector(
+ provided_op_registration_allowlist: Optional[List[str]],
+ op_selection_yaml_path: Optional[str],
+) -> SelectiveBuilder:
+ assert not (
+ provided_op_registration_allowlist is not None
+ and op_selection_yaml_path is not None
+ ), (
+ "Both provided_op_registration_allowlist and "
+ + "op_selection_yaml_path can NOT be provided at the "
+ + "same time."
+ )
+
+ op_registration_allowlist: Optional[Set[str]] = None
+ if provided_op_registration_allowlist is not None:
+ op_registration_allowlist = set(provided_op_registration_allowlist)
+
+ if op_registration_allowlist is not None:
+ selector = SelectiveBuilder.from_legacy_op_registration_allow_list(
+ op_registration_allowlist,
+ True,
+ False,
+ )
+ elif op_selection_yaml_path is not None:
+ selector = SelectiveBuilder.from_yaml_path(op_selection_yaml_path)
+ else:
+ selector = SelectiveBuilder.get_nop_selector()
+
+ return selector
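+
+# Usage sketch (argument values are hypothetical):
+#   selector = get_custom_build_selector(["aten::add.Tensor"], None)
+# builds a selector from an explicit allowlist, while
+#   selector = get_custom_build_selector(None, "path/to/selected_operators.yaml")
+# loads one from a selective-build YAML file; passing both trips the assert
+# above, and passing neither yields the no-op selector.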
+
+
+def get_grouped_by_view_native_functions(
+ native_functions: Sequence[NativeFunction],
+) -> Sequence[Union[NativeFunction, NativeFunctionsViewGroup]]:
+ def maybe_create_view_group(
+ d: Dict[Union[ViewSchemaKind, SchemaKind], NativeFunction]
+ ) -> List[Union[NativeFunction, NativeFunctionsViewGroup]]:
+ funcs: List[Union[NativeFunction, NativeFunctionsViewGroup]] = []
+ if ViewSchemaKind.aliasing in d:
+ view = d.pop(ViewSchemaKind.aliasing)
+ view_inplace = d.pop(ViewSchemaKind.aliasing_inplace, None)
+ view_copy = d.pop(SchemaKind.functional, None)
+
+ funcs.append(
+ NativeFunctionsViewGroup(
+ view=view,
+ view_copy=view_copy,
+ view_inplace=view_inplace,
+ )
+ )
+ # Take the remaining functions that weren't part of the view group
+ # and emit them separately
+ funcs.extend(d.values())
+ return funcs
+
+ grouped_by_views: Dict[
+ FunctionSchema, Dict[Union[SchemaKind, ViewSchemaKind], NativeFunction]
+ ] = defaultdict(dict)
+ for f in native_functions:
+ schema = f.func.view_signature()
+ view_kind: ViewSchemaKind = f.view_schema_kind
+ # We need to group up ops relevant to the same "view", consisting of:
+ # view op (ViewSchemaKind.aliasing)
+ # view_inplace op (ViewSchemaKind.aliasing_inplace)
+ # view_copy op (SchemaKind.functional)
+ if view_kind == ViewSchemaKind.non_aliasing:
+ kind = f.func.kind()
+ assert kind not in grouped_by_views[schema]
+ grouped_by_views[schema][kind] = f
+ else:
+ assert view_kind not in grouped_by_views[schema]
+ grouped_by_views[schema][view_kind] = f
+
+ return list(concatMap(maybe_create_view_group, grouped_by_views.values()))
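+
+# Example of the grouping (operator names only illustrate the pattern): a view
+# op such as `transpose`, its aliasing in-place variant `transpose_`, and its
+# functional `transpose_copy` counterpart share a view signature and collapse
+# into a single NativeFunctionsViewGroup; non-aliasing ops with the same
+# signature are passed through as plain NativeFunctions.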
+
+
+def get_grouped_native_functions(
+ native_functions: Sequence[NativeFunction],
+) -> Sequence[Union[NativeFunction, NativeFunctionsGroup]]:
+ def flatten_pre_group(
+ d: Dict[SchemaKind, NativeFunction]
+ ) -> Sequence[Union[NativeFunction, NativeFunctionsGroup]]:
+ r = NativeFunctionsGroup.from_dict(d)
+ if r is None:
+ # Invariant: any NativeFunctions that are code-generated
+ # should have been grouped into NativeFunctionsGroup objects
+ assert not any("generated" in f.tags for f in d.values())
+ return list(d.values())
+ else:
+ return [r]
+
+ # TODO: why isn't ValuesView a Sequence?
+ pre_grouped_native_functions = pre_group_native_functions(native_functions)
+ return list(
+ concatMap(flatten_pre_group, list(pre_grouped_native_functions.values()))
+ )
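+
+# Example of the grouping (operator names only illustrate the pattern): the
+# functional `add.Tensor`, the in-place `add_.Tensor`, and the out variant
+# `add.out` share a signature and collapse into one NativeFunctionsGroup;
+# functions that cannot be grouped are passed through unchanged.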
+
+
+def get_ns_grouped_kernels(
+ *,
+ grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]],
+ backend_indices: Dict[DispatchKey, BackendIndex],
+ native_function_decl_gen: Callable[
+ [Union[NativeFunctionsGroup, NativeFunction], BackendIndex], List[str]
+ ] = dest.compute_native_function_declaration,
+) -> Dict[str, List[str]]:
+ ns_grouped_kernels: Dict[str, List[str]] = defaultdict(list)
+ for f in grouped_native_functions:
+ native_function_namespaces = set()
+ dispatch_keys = set()
+ for dispatch_key, backend_idx in backend_indices.items():
+ backend_metadata = backend_idx.get_kernel(f)
+ if backend_metadata:
+ namespace = backend_metadata.cpp_namespace
+ dispatch_keys.add(dispatch_key)
+ native_function_namespaces.add(namespace)
+ else:
+ namespace = DEFAULT_KERNEL_NAMESPACE
+ assert (
+ len(native_function_namespaces) <= 1
+ ), f"Codegen only supports one namespace per operator, got {native_function_namespaces} from {dispatch_keys}"
+ ns_grouped_kernels[namespace].extend(
+ native_function_decl_gen(f, backend_idx)
+ )
+ return ns_grouped_kernels
+
+
+def get_native_function_declarations_from_ns_grouped_kernels(
+ *,
+ ns_grouped_kernels: Dict[str, List[str]],
+) -> List[str]:
+ declarations: List[str] = []
+ newline = "\n"
+ for namespace, kernels in ns_grouped_kernels.items():
+ ns_helper = NamespaceHelper(
+ namespace_str=namespace,
+ entity_name="",
+ max_level=4,
+ )
+ # De-duplicate kernel names while preserving their order. Backends are
+ # allowed to repeat kernel names; only generate the declaration once!
+ ordered_kernels = list(OrderedDict.fromkeys(kernels))
+ declarations.extend(
+ f"""
+{ns_helper.prologue}
+{newline.join(ordered_kernels)}
+{ns_helper.epilogue}
+ """.split(
+ newline
+ )
+ )
+ return declarations
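+
+# Illustrative shape of the result (namespace and declaration are hypothetical):
+# the de-duplicated declarations of each namespace, e.g.
+#   TORCH_API at::Tensor my_kernel(const at::Tensor & self);
+# are wrapped between that namespace's prologue and epilogue (the lines that
+# open and close something like `namespace at::native { ... }`) and the whole
+# block is split into one string per line.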
+
+
+# Return native function declarations grouped by their namespaces.
+def get_native_function_declarations(
+ *,
+ grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]],
+ backend_indices: Dict[DispatchKey, BackendIndex],
+ native_function_decl_gen: Callable[
+ [Union[NativeFunctionsGroup, NativeFunction], BackendIndex], List[str]
+ ] = dest.compute_native_function_declaration,
+) -> List[str]:
+ """
+ Generate kernel declarations in `NativeFunction(s).h`.
+ :param grouped_native_functions: a sequence of `NativeFunction` or `NativeFunctionsGroup`.
+ :param backend_indices: kernel collections grouped by dispatch key.
+ :param native_function_decl_gen: callable that generates the kernel declaration for each `NativeFunction`.
+ :return: a list of strings containing all declarations, grouped by namespace and split by newline.
+ """
+
+ ns_grouped_kernels = get_ns_grouped_kernels(
+ grouped_native_functions=grouped_native_functions,
+ backend_indices=backend_indices,
+ native_function_decl_gen=native_function_decl_gen,
+ )
+ return get_native_function_declarations_from_ns_grouped_kernels(
+ ns_grouped_kernels=ns_grouped_kernels
+ )
+
+
+def get_kernel_namespace(
+ *, f: Union[NativeFunction, NativeFunctionsGroup], backend_idx: BackendIndex
+) -> str:
+ backend_metadata = backend_idx.get_kernel(f)
+ assert not backend_metadata or "::native" in backend_metadata.cpp_namespace, (
+ f"The kernel for function {f.func.name if isinstance(f, NativeFunction) else f.functional.func.name} "
+ f"with dispatch key {backend_idx.dispatch_key}"
+ f" has namespace {backend_metadata.cpp_namespace}, which does not end with '::native'."
+ )
+ return (
+ backend_metadata.cpp_namespace if backend_metadata else DEFAULT_KERNEL_NAMESPACE
+ )
+
+
+# Return native function definitions grouped by dispatch key and custom namespace.
+# Used in RegisterDispatchKey.cpp, etc.
+def get_native_function_definitions(
+ *,
+ fm: FileManager,
+ grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]],
+ dispatch_key: DispatchKey,
+ backend_idx: BackendIndex,
+ selector: SelectiveBuilder,
+ rocm: bool,
+ symint: bool,
+ skip_dispatcher_op_registration: bool,
+ gen_dispatch_helpers: bool,
+) -> List[str]:
+ definitions: List[str] = []
+ ns_definitions: Dict[str, List[str]] = defaultdict(list)
+ anonymous_definitions: Dict[str, List[str]] = defaultdict(list)
+ registrations: Dict[str, Dict[str, List[str]]] = defaultdict(dict)
+ newline = "\n"
+ ns_gen = dest.RegisterDispatchKey(
+ backend_idx,
+ Target.NAMESPACED_DEFINITION,
+ selector,
+ rocm=rocm,
+ symint=symint,
+ class_method_name=None,
+ skip_dispatcher_op_registration=skip_dispatcher_op_registration,
+ )
+ anonymous_gen = dest.RegisterDispatchKey(
+ backend_idx,
+ Target.ANONYMOUS_DEFINITION,
+ selector,
+ rocm=rocm,
+ symint=symint,
+ class_method_name=None,
+ skip_dispatcher_op_registration=skip_dispatcher_op_registration,
+ )
+ reg_gen = dest.RegisterDispatchKey(
+ backend_idx,
+ Target.REGISTRATION,
+ selector,
+ rocm=rocm,
+ symint=symint,
+ class_method_name=None,
+ skip_dispatcher_op_registration=skip_dispatcher_op_registration,
+ )
+ for f in grouped_native_functions:
+ kernel_namespace = get_kernel_namespace(f=f, backend_idx=backend_idx).replace(
+ "::native", ""
+ )
+
+ ns_definitions[kernel_namespace].extend(
+ ns_gen(f),
+ )
+ anonymous_definitions[kernel_namespace].extend(
+ anonymous_gen(f),
+ )
+ namespace = (
+ f.namespace if isinstance(f, NativeFunction) else f.functional.namespace
+ )
+ if namespace not in registrations[kernel_namespace]:
+ registrations[kernel_namespace] = defaultdict(list)
+ registrations[kernel_namespace][namespace].extend(
+ reg_gen(f),
+ )
+
+ for kernel_namespace in ns_definitions:
+ if len(ns_definitions[kernel_namespace]) == 0:
+ continue
+ ns_helper = NamespaceHelper(namespace_str=kernel_namespace)
+ registration_body = ""
+ for namespace in registrations[kernel_namespace]:
+ if not registrations[kernel_namespace][namespace]:
+ continue
+ registration_body += f"""
+TORCH_LIBRARY_IMPL({namespace}, {dispatch_key}, m) {{
+ {newline.join(registrations[kernel_namespace][namespace])}
+}};"""
+ definitions.extend(
+ fm.substitute_with_template(
+ "RegisterDispatchDefinitions.ini",
+ lambda: {
+ "ns_prologue": ns_helper.prologue,
+ "ns_epilogue": ns_helper.epilogue,
+ "dispatch_helpers": dest.gen_registration_helpers(backend_idx)
+ if gen_dispatch_helpers
+ else [],
+ "dispatch_anonymous_definitions": anonymous_definitions[
+ kernel_namespace
+ ],
+ "static_init_dispatch_registrations": ""
+ if skip_dispatcher_op_registration
+ else registration_body,
+ "deferred_dispatch_registrations": "",
+ "dispatch_namespace": dispatch_key.lower(),
+ "dispatch_namespaced_definitions": ns_definitions[kernel_namespace],
+ },
+ ).split(newline)
+ )
+
+ return definitions
+
+
+# Return native function declarations grouped by dispatch key and custom namespace.
+# Used in CPUFunctions_inl.h, etc.
+def get_namespaced_declaration(
+ *,
+ grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]],
+ dispatch_key: DispatchKey,
+ backend_idx: BackendIndex,
+ selector: SelectiveBuilder,
+ rocm: bool,
+ symint: bool,
+) -> List[str]:
+ declarations: List[str] = []
+ ns_grouped_kernels: Dict[str, List[str]] = defaultdict(list)
+ newline = "\n"
+ func = dest.RegisterDispatchKey(
+ backend_idx,
+ Target.NAMESPACED_DECLARATION,
+ selector,
+ rocm=rocm,
+ class_method_name=None,
+ skip_dispatcher_op_registration=False,
+ symint=symint,
+ )
+ for f in grouped_native_functions:
+ namespace = get_kernel_namespace(f=f, backend_idx=backend_idx).replace(
+ "native", dispatch_key.lower()
+ )
+
+ ns_grouped_kernels[namespace].extend(
+ func(f),
+ )
+
+ for namespace, kernels in ns_grouped_kernels.items():
+ if len(kernels) == 0:
+ continue
+ ns_helper = NamespaceHelper(
+ namespace_str=namespace, entity_name="", max_level=3
+ )
+ ordered_kernels = list(OrderedDict.fromkeys(kernels))
+ declarations.extend(
+ f"""
+{ns_helper.prologue}
+{newline.join(ordered_kernels)}
+{ns_helper.epilogue}
+ """.split(
+ newline
+ )
+ )
+ return declarations
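+
+# Illustrative result for dispatch_key == DispatchKey.CPU (the kernel name is
+# hypothetical): the `native` portion of the kernel namespace is replaced by
+# the lower-cased dispatch key, so a declaration like
+#   TORCH_API at::Tensor my_op(const at::Tensor & self);
+# ends up under `at::cpu` rather than `at::native`.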
+
+
+# Return native function schema registration code for aten and other namespaces.
+def get_native_function_schema_registrations(
+ *,
+ native_functions: Sequence[NativeFunction],
+ schema_selector: SelectiveBuilder,
+) -> Tuple[List[str], str]:
+ ns_native_functions: Dict[str, List[NativeFunction]] = defaultdict(list)
+ for native_function in native_functions:
+ ns_native_functions[native_function.namespace].append(native_function)
+ schema_registrations = ""
+ aten_schema_registrations = []
+ custom_namespace = None
+ for namespace, funcs in ns_native_functions.items():
+ schema_registrations_body = list(
+ mapMaybe(RegisterSchema(schema_selector), funcs)
+ )
+ # NB: we have to separate the aten namespace registrations from those of
+ # other namespaces, because the template already hardcodes the aten block.
+ if namespace == "aten":
+ aten_schema_registrations = schema_registrations_body
+ else:
+ custom_namespace = namespace
+ tab = "\t"
+ # If the namespace is predefined, define a library fragment
+ # instead of a new library.
+ torch_library_macro = (
+ "TORCH_LIBRARY_FRAGMENT"
+ if namespace in FRAGMENT_NAMESPACES
+ else "TORCH_LIBRARY"
+ )
+ schema_registrations += f"""
+{torch_library_macro}({custom_namespace}, m) {{
+ {tab.join(schema_registrations_body)}
+}};"""
+ return (aten_schema_registrations, schema_registrations)
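+
+# Illustrative sketch of the non-aten part of the returned string (namespace
+# and schema are hypothetical; the macro depends on FRAGMENT_NAMESPACES):
+#   TORCH_LIBRARY(custom, m) {
+#     m.def("my_op(Tensor self) -> Tensor");
+#   };
+# The aten registrations are returned separately because the template already
+# provides the surrounding registration block for the aten namespace.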
+
+
+def gen_aggregated_headers(
+ *,
+ native_functions: Sequence[NativeFunction],
+ grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]],
+ structured_native_functions: Sequence[NativeFunctionsGroup],
+ static_dispatch_idx: List[BackendIndex],
+ selector: SelectiveBuilder,
+ backend_indices: Dict[DispatchKey, BackendIndex],
+ cpu_fm: FileManager,
+ cuda_fm: FileManager,
+ functions_keys: Set[DispatchKey],
+ dispatch_keys: Sequence[DispatchKey],
+ rocm: bool,
+) -> None:
+ # Buck doesn't support dynamic output files, so we aggregate all operator
+ # headers into a single file
+ cpu_fm.write(
+ "NativeMetaFunctions.h",
+ lambda: {
+ "NativeMetaFunctions_includes": [],
+ "NativeMetaFunctions_declarations": list(
+ mapMaybe(compute_meta_function_declaration, structured_native_functions)
+ ),
+ },
+ )
+ method_native_functions = [
+ fn for fn in native_functions if Variant.method in fn.variants
+ ]
+ non_method_native_functions = [
+ fn for fn in native_functions if fn not in method_native_functions
+ ]
+ cpu_fm.write(
+ "MethodOperators.h",
+ lambda: {
+ "MethodOperators_includes": [],
+ "MethodOperators_declarations": list(
+ mapMaybe(
+ ComputeOperators(
+ Target.DECLARATION,
+ static_dispatch_backend_indices=static_dispatch_idx,
+ ),
+ method_native_functions,
+ )
+ ),
+ },
+ )
+ cpu_fm.write(
+ "Operators.h",
+ lambda: {
+ "Operators_includes": ["#include "],
+ "Operators_declarations": list(
+ mapMaybe(
+ ComputeOperators(
+ Target.DECLARATION,
+ static_dispatch_backend_indices=static_dispatch_idx,
+ ),
+ non_method_native_functions,
+ )
+ ),
+ },
+ )
+ cpu_fm.write(
+ "Functions.h",
+ lambda: {
+ "static_dispatch_extra_headers": static_dispatch_extra_headers(
+ static_dispatch_idx
+ ),
+ "Functions_includes": ["#include