# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch

from .task import Task
class VLMTask(Task):
    """A VLM task for reproducibility.

    The collator splits subsamples into two sub-batches.
    This should involve no logic changes, but it changes the
    randomness in frame masking.
    """

    def flat_subsample(self, tensor):
        """Regroup consecutive pairs of entries into two stacked halves.

        A tensor with fewer than two dimensions is returned unchanged.
        Otherwise the tensor is reshaped to
        ``(size[0] * (size[1] // 2), 2, *size[2:])`` so that consecutive
        entries form pairs; the first member of every pair is then
        concatenated ahead of the second member of every pair along dim 0.

        Args:
            tensor: a ``torch.Tensor``.

        Returns:
            The input itself when it has fewer than two dimensions; otherwise
            a new tensor whose leading dimension is twice the number of pairs
            and whose trailing dimensions are ``size[2:]``.

        Raises:
            RuntimeError: if the element count does not match the paired
                shape, e.g. a non-empty tensor with an odd ``size[1]``.
        """
        size = tensor.size()
        if len(size) < 2:
            return tensor

        num_pairs = size[0] * (size[1] // 2)
        # ``size[2:]`` is empty for 2-D input, so a single expression covers
        # both the 2-D and the higher-rank case.
        paired_shape = (num_pairs, 2) + tuple(size[2:])
        # ``reshape`` rather than ``view``: ``view`` raises on non-contiguous
        # input, while ``reshape`` only copies when a view is impossible.
        paired = tensor.reshape(paired_shape)
        return torch.cat([paired[:, 0], paired[:, 1]], dim=0)