tgritsaev's picture
Upload 198 files
affcd23 verified
raw
history blame
1.73 kB
import torch
import logging
from typing import List
logger = logging.getLogger(__name__)
def collate_fn(dataset_items: List[dict]) -> dict:
    """
    Collate and pad fields in dataset items.

    Every item must provide the keys ``"spectrogram"``, ``"text_encoded"``,
    ``"text"``, ``"audio_path"`` and ``"audio"``.

    * ``item["spectrogram"]`` is indexed as ``[0]`` and its ``shape[1]`` /
      ``shape[2]`` are read as the feature and time sizes, so it is assumed
      to be laid out as ``(1, n_features, time)`` -- TODO confirm against the
      dataset class. The feature size is taken from the first item only and
      is assumed to be the same for all items.
    * ``item["text_encoded"]`` is indexed as ``[0]`` and its ``shape[1]`` is
      read as the token count, so it is assumed to be ``(1, text_length)``.

    Returns a dict with:

    * ``"spectrogram"``: ``(batch, n_features, max_time)`` tensor,
      zero-padded on the right along the time axis.
    * ``"spectrogram_length"``: int32 tensor with the unpadded time length
      of every item.
    * ``"text_encoded"``: ``(batch, max_text_length)`` tensor, zero-padded
      on the right. It is created with ``torch.zeros`` and therefore has the
      default floating dtype, whatever dtype the items carry.
    * ``"text_encoded_length"``: int32 tensor with the unpadded token count
      of every item.
    * ``"text"``, ``"audio_path"``, ``"audio"``: plain lists with the
      per-item values, in batch order, left untouched.

    Raises ``IndexError`` if ``dataset_items`` is empty.
    """
    batch_size = len(dataset_items)
    n_features = dataset_items[0]["spectrogram"].shape[1]

    # Gather the lengths once; they drive both the padded sizes and the
    # returned "*_length" tensors.
    spectrogram_length = [item["spectrogram"].shape[2] for item in dataset_items]
    text_encoded_length = [item["text_encoded"].shape[1] for item in dataset_items]

    spectrogram = torch.zeros((batch_size, n_features, max(spectrogram_length)))
    text_encoded = torch.zeros((batch_size, max(text_encoded_length)))

    # The batch tensors are already zero-filled, so copying each item into
    # its leading slice leaves the tail as padding without building a
    # temporary pad tensor per item.
    for i, item in enumerate(dataset_items):
        spectrogram[i, :, : spectrogram_length[i]] = item["spectrogram"][0]
        text_encoded[i, : text_encoded_length[i]] = item["text_encoded"][0]

    return {
        "spectrogram": spectrogram,
        # Build the integer tensors directly instead of via a float32
        # torch.Tensor, which cannot represent every integer above 2**24.
        "spectrogram_length": torch.tensor(spectrogram_length, dtype=torch.int32),
        "text_encoded": text_encoded,
        "text_encoded_length": torch.tensor(text_encoded_length, dtype=torch.int32),
        "text": [item["text"] for item in dataset_items],
        "audio_path": [item["audio_path"] for item in dataset_items],
        "audio": [item["audio"] for item in dataset_items],
    }