PyTorch
ssl-aasist
custom_code
ash56's picture
Add files using upload-large-folder tool
d28af7f verified
raw
history blame
954 Bytes
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
from .task import Task
class MILNCETask(Task):
def reshape_subsample(self, sample):
if (
hasattr(self.config.dataset, "subsampling")
and self.config.dataset.subsampling is not None
and self.config.dataset.subsampling > 1
):
for key in sample:
if torch.is_tensor(sample[key]):
tensor = self.flat_subsample(sample[key])
if key in ["caps", "cmasks"]:
size = tensor.size()
batch_size = size[0] * size[1]
expanded_size = (batch_size,) + size[2:]
tensor = tensor.view(expanded_size)
sample[key] = tensor
return sample