# Amharic-SER / collator.py
# Commit 566ae0a ("Upload 11 files") by Gizachew; original size 1.19 kB.
import numbers
from dataclasses import dataclass
from typing import Dict, List, Optional, Union

import torch
import transformers
from transformers import Wav2Vec2Processor, Wav2Vec2FeatureExtractor
@dataclass
class DataCollatorCTCWithPadding:
    """Collate a list of examples into a padded audio batch plus a labels tensor.

    Each example must provide an ``"input_values"`` entry, which is padded with
    ``feature_extractor.pad``, and a ``"labels"`` entry. Labels are NOT padded:
    they are converted directly with ``torch.tensor``, so they are presumably a
    single scalar per example (e.g. a class id or a regression target) — confirm
    against the dataset preparation code.

    Attributes:
        feature_extractor: Object whose ``pad`` method pads ``input_values``.
        padding: Forwarded to ``feature_extractor.pad`` as ``padding``.
        max_length: Forwarded to ``feature_extractor.pad`` as ``max_length``.
        max_length_labels: Not read by this collator (labels are not padded).
        pad_to_multiple_of: Forwarded to ``feature_extractor.pad`` as
            ``pad_to_multiple_of``.
        pad_to_multiple_of_labels: Not read by this collator.
    """

    feature_extractor: Wav2Vec2FeatureExtractor
    padding: Union[bool, str] = True
    max_length: Optional[int] = None
    max_length_labels: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    pad_to_multiple_of_labels: Optional[int] = None

    @staticmethod
    def _label_dtype(label) -> torch.dtype:
        """Return ``torch.long`` for integer-valued labels, ``torch.float`` otherwise.

        Integer detection covers Python ``int``/``bool``, NumPy integer scalars
        (registered as ``numbers.Integral``) and non-floating, non-complex torch
        tensors. The earlier ``isinstance(label, int)`` check turned NumPy or
        torch integer class ids into float labels.
        """
        if isinstance(label, torch.Tensor):
            if label.is_floating_point() or label.is_complex():
                return torch.float
            return torch.long
        return torch.long if isinstance(label, numbers.Integral) else torch.float

    def __call__(
        self, features: List[Dict[str, Union[List[int], torch.Tensor]]]
    ) -> Dict[str, torch.Tensor]:
        """Pad ``input_values`` across the batch and attach the labels as a tensor.

        Args:
            features: Examples, each with ``"input_values"`` and ``"labels"`` keys.

        Returns:
            The batch returned by ``feature_extractor.pad`` (PyTorch tensors), with
            an added ``"labels"`` tensor whose dtype is ``torch.long`` for integer
            labels and ``torch.float`` otherwise.

        Raises:
            ValueError: If ``features`` is empty.
        """
        if not features:
            raise ValueError(
                "DataCollatorCTCWithPadding received an empty list of features"
            )

        input_features = [{"input_values": feature["input_values"]} for feature in features]
        label_features = [feature["labels"] for feature in features]
        # The dtype is inferred from the first label only; torch.tensor casts the
        # remaining labels to it.
        d_type = self._label_dtype(label_features[0])

        batch = self.feature_extractor.pad(
            input_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )
        batch["labels"] = torch.tensor(label_features, dtype=d_type)
        return batch