from dataclasses import dataclass
from typing import Dict, List, Optional, Union

import torch
from transformers import Wav2Vec2FeatureExtractor


@dataclass
class DataCollatorCTCWithPadding:
    """
    Data collator that will dynamically pad the inputs received.

    Args:
        feature_extractor (:class:`~transformers.Wav2Vec2FeatureExtractor`):
            The feature extractor used for processing the data.
        padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
            Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
            among:

            * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a
              single sequence is provided).
            * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
              maximum acceptable input length for the model if that argument is not provided.
            * :obj:`False` or :obj:`'do_not_pad'`: No padding (i.e., can output a batch with sequences of
              different lengths).
        max_length (:obj:`int`, `optional`):
            Maximum length of the ``input_values`` of the returned list and optionally padding length (see above).
        max_length_labels (:obj:`int`, `optional`):
            Maximum length of the ``labels`` of the returned list and optionally padding length (see above).
        pad_to_multiple_of (:obj:`int`, `optional`):
            If set, will pad the sequence to a multiple of the provided value.
            This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute
            capability >= 7.0 (Volta).
    """

    feature_extractor: Wav2Vec2FeatureExtractor
    padding: Union[bool, str] = True
    max_length: Optional[int] = None
    max_length_labels: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    pad_to_multiple_of_labels: Optional[int] = None
    # Note: max_length_labels and pad_to_multiple_of_labels are not used in
    # __call__ below, since each example's label is expected to be a scalar
    # (class id or regression target) rather than a sequence.

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # Split inputs and labels, since they need to be handled differently.
        input_features = [{"input_values": feature["input_values"]} for feature in features]
        label_features = [feature["labels"] for feature in features]

        # Integer labels (classification) become torch.long; float labels
        # (regression) become torch.float.
        d_type = torch.long if isinstance(label_features[0], int) else torch.float

        # Pad the raw waveforms to a common length and return PyTorch tensors.
        batch = self.feature_extractor.pad(
            input_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )

        batch["labels"] = torch.tensor(label_features, dtype=d_type)

        return batch
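

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, not part of the original module).
    # It assumes a pretrained wav2vec2 checkpoint; "facebook/wav2vec2-base" is a
    # placeholder name, and the short waveforms below are dummy values chosen
    # purely for demonstration.
    feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-base")
    collator = DataCollatorCTCWithPadding(feature_extractor=feature_extractor, padding=True)

    features = [
        {"input_values": [0.1, -0.2, 0.3, 0.5], "labels": 0},
        {"input_values": [0.05, 0.4], "labels": 1},
    ]
    batch = collator(features)
    print(batch["input_values"].shape)  # (2, 4): padded to the longest sequence
    print(batch["labels"])              # tensor([0, 1]) with dtype torch.long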