wav2vec2/src/model/modules/processor.py
from typing import Tuple

import torch
from torch import nn


class Wav2Vec2Processor(nn.Module):
    """Batches a tuple of variable-length waveforms by zero-padding them to a common length."""

    def __init__(self):
        super().__init__()

    def forward(
        self, waveforms: Tuple[torch.Tensor, ...]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Convert a tuple of waveforms with different lengths into a single padded batch.

        Args:
            waveforms (Tuple[torch.Tensor, ...]): The waveforms, each of shape (wave_length,).

        Returns:
            batched_waveforms (torch.Tensor): The padded batch. Shape: (batch_size, max_wave_length).
            wave_lengths (torch.Tensor): The original length of each waveform. Shape: (batch_size,).
        """
        device = waveforms[0].device
        # Record each waveform's original length so padding can be masked out downstream.
        wave_lengths = torch.tensor(
            tuple(waveform.size(0) for waveform in waveforms), device=device
        )
        max_length = wave_lengths.max().item()

        # Right-pad every waveform with zeros up to the longest one in the batch.
        padded = []
        for waveform in waveforms:
            padded.append(
                nn.functional.pad(
                    waveform,
                    (0, max_length - waveform.size(0)),
                    mode="constant",
                    value=0.0,
                )
            )

        batched_waveforms = torch.stack(padded, dim=0)
        return batched_waveforms, wave_lengths
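

# Usage sketch (not part of the original module): shows how the processor pads two
# waveforms of unequal length into one batch. The sample lengths (16000 and 12000
# samples, i.e. 1.0 s and 0.75 s at 16 kHz) are illustrative assumptions, not values
# taken from the repository.
if __name__ == "__main__":
    processor = Wav2Vec2Processor()
    waveforms = (torch.randn(16000), torch.randn(12000))
    batched, lengths = processor(waveforms)
    print(batched.shape)  # torch.Size([2, 16000]) -- shorter waveform zero-padded
    print(lengths)        # tensor([16000, 12000]) -- original lengths for masking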