# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch

from .task import Task


class MILNCETask(Task):
    """Task whose ``reshape_subsample`` additionally merges the first two
    dimensions of the ``caps`` and ``cmasks`` tensors."""

    def reshape_subsample(self, sample):
        """Flatten the tensors of ``sample`` when dataset subsampling is on.

        Nothing is done unless ``self.config.dataset.subsampling`` exists,
        is not ``None`` and is greater than 1.  Otherwise every tensor-valued
        entry of ``sample`` is passed through ``self.flat_subsample`` (not
        defined in this file -- presumably provided by ``Task``; confirm
        there), and the entries stored under ``"caps"`` and ``"cmasks"`` then
        have their first two dimensions collapsed into one.

        Args:
            sample: mapping from keys to values; non-tensor values are left
                untouched.

        Returns:
            The same ``sample`` object, updated in place.
        """
        factor = getattr(self.config.dataset, "subsampling", None)
        # Guard clause: subsampling absent, unset, or not greater than 1.
        if factor is None or not factor > 1:
            return sample

        merged_keys = ("caps", "cmasks")
        # Only values of existing keys are reassigned, so iterating over
        # ``sample`` while writing back is safe.
        for name in sample:
            value = sample[name]
            if not torch.is_tensor(value):
                continue
            flat = self.flat_subsample(value)
            if name in merged_keys:
                dims = flat.size()
                # (d0, d1, *rest) -> (d0 * d1, *rest)
                flat = flat.view((dims[0] * dims[1],) + dims[2:])
            sample[name] = flat
        return sample