# NOTE(review): removed Hugging Face Spaces page chrome ("Spaces: Running on
# CPU Upgrade") that was left at the top of this file by web scraping — it was
# not Python and broke parsing.
import torch
from typing import List, Union
from torch.utils import data
# FIXME Test and fix these split functions later
def random_split(dataset: data.Dataset, lengths, seed):
    """Split a dataset into non-overlapping random subsets.

    Reproducibility-focused wrapper around ``torch.utils.data.random_split``:
    the shuffle is driven by a dedicated generator seeded with ``seed``, so the
    same arguments always produce the same partition.

    Args:
        dataset (Dataset): PyTorch dataset object to split.
        lengths: Lengths of the splits, forwarded unchanged to
            ``torch.utils.data.random_split``.
        seed: Integer seed for the RNG that drives the shuffle.

    Returns:
        list: ``torch.utils.data.Subset`` objects, one per entry in ``lengths``.
    """
    generator = torch.Generator().manual_seed(seed)
    return data.random_split(dataset, lengths, generator=generator)
def cold_start(dataset: data.Dataset, frac: List[float], entities: Union[str, List[str]], seed=None):
    """Create cold-start splits for PyTorch datasets.

    An entity instance (e.g. a particular user) is "cold" when it is held out
    of training entirely: test instances are sampled first, and a sample lands
    in the test set only if *all* of its cold entities were sampled; the
    validation set is then drawn the same way from the remaining samples.

    Args:
        dataset (Dataset): PyTorch dataset whose samples expose each entity as
            an attribute (read via ``getattr(sample, entity)``); attribute
            values must be hashable.
        frac (list): A list of train/valid/test fractions. ``test_frac`` must
            be < 1, otherwise the validation rescaling below divides by zero.
        entities (Union[str, List[str]]): Either a single "cold" entity or a
            list of "cold" entities on which the split is done.
        seed (int, optional): Seed for the RNG driving the instance sampling.
            Defaults to None, which uses torch's global RNG (the original
            behavior).

    Returns:
        dict: A dictionary of split datasets, where keys are 'train', 'valid',
        and 'test', and values correspond to each dataset.

    Raises:
        ValueError: If the test or validation split comes out empty.
    """
    if isinstance(entities, str):
        entities = [entities]
    train_frac, val_frac, test_frac = frac
    generator = None if seed is None else torch.Generator().manual_seed(seed)

    # Collect the unique instances of each entity across the whole dataset.
    entity_instances = {
        entity: list({getattr(sample, entity) for sample in dataset})
        for entity in entities
    }

    def _sample_instances(entity, fraction):
        # Randomly pick `fraction` of the entity's unique instances; return a
        # set for O(1) membership tests below.
        instances = entity_instances[entity]
        count = int(len(instances) * fraction)
        picked = torch.randperm(len(instances), generator=generator)[:count]
        # BUGFIX: the original indexed a Python list with the whole index
        # *tensor* (`instances[index_tensor]`), which raises TypeError whenever
        # more than one index is drawn; convert indices to ints explicitly.
        return {instances[i] for i in picked.tolist()}

    # Sample instances belonging to the test split, then keep the samples
    # whose entities are *all* cold (i.e. in the sampled sets).
    test_instances = {entity: _sample_instances(entity, test_frac) for entity in entities}
    test_indices = [
        i for i, sample in enumerate(dataset)
        if all(getattr(sample, entity) in test_instances[entity] for entity in entities)
    ]
    if len(test_indices) == 0:
        raise ValueError('No test samples found. Try increasing the test frac or a less stringent splitting strategy.')

    # Proceed with validation data, drawn only from the non-test samples.
    train_val_indices = list(set(range(len(dataset))) - set(test_indices))
    # val_frac is rescaled because instances are sampled from the full pool
    # while candidate samples are restricted to the non-test remainder.
    val_instances = {
        entity: _sample_instances(entity, val_frac / (1 - test_frac))
        for entity in entities
    }
    val_indices = [
        i for i in train_val_indices
        if all(getattr(dataset[i], entity) in val_instances[entity] for entity in entities)
    ]
    if len(val_indices) == 0:
        raise ValueError('No validation samples found. Try increasing the test frac or a less stringent splitting strategy.')

    train_indices = list(set(train_val_indices) - set(val_indices))
    return {
        'train': data.Subset(dataset, train_indices),
        'valid': data.Subset(dataset, val_indices),
        'test': data.Subset(dataset, test_indices),
    }