Spaces:
Runtime error
Runtime error
# Copyright 2024 EPFL and Apple Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np | |
from torch.utils.data import Dataset | |
class RepeatedDatasetWrapper(Dataset):
    """Expose a dataset as if it were repeated back-to-back ``num_repeats`` times."""

    def __init__(self, original_dataset, num_repeats):
        """
        Wrap ``original_dataset`` so that it appears ``num_repeats`` times longer.

        Args:
            original_dataset (torch.utils.data.Dataset): The dataset being repeated.
            num_repeats (int): How many times the dataset is repeated.
        """
        self.original_dataset = original_dataset
        self.num_repeats = num_repeats

    def __len__(self):
        """
        Report the size of the repeated dataset.

        Returns:
            int: Length of the original dataset multiplied by ``num_repeats``.
        """
        base_length = len(self.original_dataset)
        return base_length * self.num_repeats

    def __getitem__(self, index):
        """
        Fetch the item that ``index`` maps to in the original dataset.

        The index is reduced modulo the original dataset's length, so every
        repetition yields the same items in the same order. No bounds check is
        performed against ``len(self)``; larger indices simply wrap around.

        Args:
            index (int): Position in the repeated dataset.
        """
        base_length = len(self.original_dataset)
        wrapped_index = index % base_length
        return self.original_dataset[wrapped_index]
class SubsampleDatasetWrapper(Dataset):
    """Expose a fixed, randomly chosen subset of another dataset."""

    def __init__(self, original_dataset, dataset_size, seed=0, return_orig_idx=False):
        """
        Dataset wrapper that randomly subsamples the original dataset.

        Args:
            original_dataset (torch.utils.data.Dataset): The original dataset to be subsampled.
            dataset_size (int): The size of the subsampled dataset. A falsy value
                (``None`` or ``0``) selects the full length of the original dataset.
            seed (int): The seed to use for selecting the subset of indices of the original dataset.
            return_orig_idx (bool): Whether to return the original index of the item in the original dataset.
        """
        self.original_dataset = original_dataset
        self.dataset_size = dataset_size or len(original_dataset)
        self.return_orig_idx = return_orig_idx

        # NOTE: this reseeds NumPy's *global* legacy RNG. It is kept as-is so the
        # selected subset (and the side effect on later np.random calls) stays
        # identical to previous behavior.
        np.random.seed(seed)
        # Shuffle all original indices, then keep the first `dataset_size` of them.
        # Slicing means a `dataset_size` larger than the dataset keeps everything.
        self.indices = np.random.permutation(len(self.original_dataset))[:self.dataset_size]

    def __getitem__(self, index):
        """
        Retrieve the item at the given index of the subsampled dataset.

        Args:
            index (int): The index of the item to be retrieved.

        Returns:
            The sample from the original dataset, or a ``(sample, original_index)``
            tuple when ``return_orig_idx`` is set.
        """
        original_index = self.indices[index]
        sample = self.original_dataset[original_index]
        # Explicit branch: the former one-liner
        # `return sample, original_index if ... else sample`
        # parsed as `(sample, <conditional>)` and returned `(sample, sample)`
        # whenever `return_orig_idx` was False.
        if self.return_orig_idx:
            return sample, original_index
        return sample

    def __len__(self):
        """
        Get the length of the dataset after subsampling it.

        Returns:
            int: The number of selected indices.
        """
        return len(self.indices)