PyTorch
ssl-aasist
custom_code
ash56's picture
Add files using upload-large-folder tool
d28af7f verified
raw
history blame
3.87 kB
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
from collections import OrderedDict
from torch.utils.data import Dataset
from torch.utils.data.dataloader import default_collate
from ..utils import set_seed
class MMDataset(Dataset):
    """
    A generic multi-modal dataset.

    Args:
        `meta_processor`: a meta processor,
            handling loading meta data and return video_id and text_id.
            Its `split` attribute is copied onto the dataset.
        `video_processor`: a video processor,
            handling e.g., decoding, loading .np files.
        `text_processor`: a text processor,
            handling e.g., tokenization.
        `align_processor`: an aligner that combines the video and text
            features into one training example.
    """

    def __init__(
        self,
        meta_processor,
        video_processor,
        text_processor,
        align_processor,
    ):
        # The split name comes from the meta processor; "test" triggers
        # per-index seeding in __getitem__.
        self.split = meta_processor.split
        self.meta_processor = meta_processor
        self.video_processor = video_processor
        self.text_processor = text_processor
        self.align_processor = align_processor

    def __len__(self):
        """Return the number of examples reported by the meta processor."""
        return len(self.meta_processor)

    def __getitem__(self, idx):
        """Build one aligned example for index `idx`.

        Looks up (video_id, text_id) via the meta processor, runs the video
        and text processors, then passes everything to the align processor.

        Returns:
            The dict-like object returned by `align_processor`, with an
            extra "idx" entry set to `idx`.
        """
        if self.split == "test":
            # Seed by index so test-time examples are deterministic
            # (assumes set_seed seeds the RNGs the processors use).
            set_seed(idx)
        video_id, text_id = self.meta_processor[idx]
        video_feature = self.video_processor(video_id)
        text_feature = self.text_processor(text_id)
        output = self.align_processor(video_id, video_feature, text_feature)
        # TODO (huxu): the following is for debug purpose.
        output.update({"idx": idx})
        return output

    def collater(self, samples):
        """Collate a list of samples into a batch (deprecated).

        To use it anyway, set `self.collator = MMDataset.collater`;
        see the collator in FairseqMMDataset.

        Args:
            samples: a list of examples, either dicts or objects accepted
                by `default_collate`.

        Returns:
            `{}` for an empty list. For dict samples, an OrderedDict keyed
            like the first sample, where each value is `default_collate` over
            that key across all samples (keys whose value is None in the
            first sample are skipped). Otherwise `default_collate(samples)`.
        """
        if len(samples) == 0:
            return {}
        if not isinstance(samples[0], dict):
            return default_collate(samples)
        batch = OrderedDict()
        for key, first_value in samples[0].items():
            # None fields cannot be collated; leave them out of the batch.
            if first_value is None:
                continue
            batch[key] = default_collate([sample[key] for sample in samples])
        return batch

    def print_example(self, output):
        """Print one example for debugging, decoding captions when possible.

        Note: if the align processor has `subsampling > 1`, every tensor in
        `output` is replaced IN PLACE by its first element along dim 0, so
        the caller's dict is modified.

        Args:
            output: an example dict containing at least "video_id"; "caps"
                (a tensor of token ids) is read when a tokenizer is found.
        """
        print("[one example]", output["video_id"])
        subsampling = getattr(self.align_processor, "subsampling", None)
        if subsampling is not None and subsampling > 1:
            for key in output:
                if torch.is_tensor(output[key]):
                    output[key] = output[key][0]

        # search tokenizer to translate ids back.
        tokenizer = None
        if hasattr(self.text_processor, "tokenizer"):
            tokenizer = self.text_processor.tokenizer
        elif hasattr(self.align_processor, "tokenizer"):
            tokenizer = self.align_processor.tokenizer
        if tokenizer is not None:
            caps = output["caps"].tolist()
            # A 2-D caps tensor yields a list of lists; show the first row.
            # Guard the empty case so an empty caption does not IndexError.
            if caps and isinstance(caps[0], list):
                caps = caps[0]
            print("caps", tokenizer.decode(caps))
            print("caps", tokenizer.convert_ids_to_tokens(caps))

        for key, value in output.items():
            if torch.is_tensor(value):
                if value.dim() >= 3:  # attention_mask.
                    print(key, value.size())
                    print(key, "first", value[0, :, :])
                    print(key, "last", value[-1, :, :])
                else:
                    print(key, value)
        print("[end of one example]")