|
import torch |
|
import logging |
|
from typing import List |
|
|
|
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
|
|
|
|
|
def collate_fn(dataset_items: List[dict]) -> dict:
    """
    Collate and pad fields in dataset items into a single batch.

    Each item is indexed as follows:
      * "spectrogram": tensor of shape (1, n_features, time). The leading dim
        is dropped via ``[0]``; it is assumed to be a singleton (TODO confirm
        against the dataset).
      * "text_encoded": tensor of shape (1, text_length), dropped likewise.
      * "text", "audio_path", "audio": passed through unchanged as lists.

    Spectrograms are right-padded with zeros along the time axis to the
    longest one in the batch. Encoded texts are right-padded with zeros to
    the longest encoded text.

    :param dataset_items: non-empty list of dataset item dicts.
    :return: dict with keys
        "spectrogram" (batch, n_features, max_time), default float dtype;
        "spectrogram_length" (int32, unpadded time lengths);
        "text_encoded" (batch, max_text_length), same dtype as the input
        "text_encoded" tensors;
        "text_encoded_length" (int32, unpadded text lengths);
        "text", "audio_path", "audio" (lists in batch order).
    :raises ValueError: if ``dataset_items`` is empty.
    """
    if not dataset_items:
        raise ValueError("collate_fn received an empty batch")

    first = dataset_items[0]
    batch_size = len(dataset_items)

    spectrogram_length = [item["spectrogram"].shape[2] for item in dataset_items]
    text_encoded_length = [item["text_encoded"].shape[1] for item in dataset_items]

    feature_length_dim = first["spectrogram"].shape[1]
    # The spectrogram keeps the default float dtype, as before. It is
    # allocated on the inputs' device so padding also works for non-CPU
    # tensors; the old torch.cat against CPU zeros did not.
    spectrogram = torch.zeros(
        (batch_size, feature_length_dim, max(spectrogram_length)),
        device=first["spectrogram"].device,
    )
    # Token ids keep their input dtype. Previously they were always cast to
    # float32, which breaks consumers that expect integer targets or indices.
    text_encoded = torch.zeros(
        (batch_size, max(text_encoded_length)),
        dtype=first["text_encoded"].dtype,
        device=first["text_encoded"].device,
    )

    for i, item in enumerate(dataset_items):
        # Copy into the pre-zeroed buffers; the untouched tail stays zero
        # (right padding).
        spectrogram[i, :, : spectrogram_length[i]] = item["spectrogram"][0]
        text_encoded[i, : text_encoded_length[i]] = item["text_encoded"][0]

    return {
        "spectrogram": spectrogram,
        "spectrogram_length": torch.tensor(spectrogram_length, dtype=torch.int32),
        "text_encoded": text_encoded,
        "text_encoded_length": torch.tensor(text_encoded_length, dtype=torch.int32),
        "text": [item["text"] for item in dataset_items],
        "audio_path": [item["audio_path"] for item in dataset_items],
        "audio": [item["audio"] for item in dataset_items],
    }
|
|