from typing import List, Union

import torch
from torch.utils import data

# FIXME Test and fix these split functions later


def random_split(dataset: data.Dataset, lengths, seed: int):
    """Randomly split a dataset into non-overlapping subsets of the given lengths,
    using a fixed seed so that the split is reproducible."""
    return data.random_split(dataset, lengths, generator=torch.Generator().manual_seed(seed))
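
# Illustrative usage sketch (an assumption, not part of the original module): `my_dataset`
# stands in for any map-style Dataset; with torch >= 1.13 fractional lengths also work.
#
#     train_set, val_set, test_set = random_split(my_dataset, lengths=[800, 100, 100], seed=42)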


def cold_start(dataset: data.Dataset, frac: List[float], entities: Union[str, List[str]]):
    """Create cold-start splits for PyTorch datasets.

    Args:
        dataset (Dataset): PyTorch dataset object whose samples expose each cold entity
            as an attribute.
        frac (List[float]): A list of train/valid/test fractions.
        entities (Union[str, List[str]]): Either a single "cold" entity or a list of "cold"
            entities on which the split is done.

    Returns:
        dict: A dictionary of split datasets, where keys are 'train', 'valid', and 'test',
            and values are the corresponding subsets.
    """
    if isinstance(entities, str):
        entities = [entities]
    train_frac, val_frac, test_frac = frac

    # Collect the unique instances of each entity across the dataset
    entity_instances = {}
    for entity in entities:
        entity_instances[entity] = list({getattr(sample, entity) for sample in dataset})

    # Randomly hold out a fraction of each entity's instances for the test set
    test_entity_instances = [
        {entity_instances[entity][i]
         for i in torch.randperm(len(entity_instances[entity]))[:int(len(entity_instances[entity]) * test_frac)].tolist()}
        for entity in entities
    ]
    # A sample is assigned to the test set only if every cold entity is a held-out instance
    test_indices = []
    for i, sample in enumerate(dataset):
        if all(getattr(sample, entity) in test_entity_instances[j] for j, entity in enumerate(entities)):
            test_indices.append(i)
    if len(test_indices) == 0:
        raise ValueError('No test samples found. Try increasing the test frac or using a less stringent splitting strategy.')

    # Repeat the procedure on the remaining samples to carve out the validation set
    train_val_indices = list(set(range(len(dataset))) - set(test_indices))
    val_entity_instances = [
        {entity_instances[entity][i]
         for i in torch.randperm(len(entity_instances[entity]))[:int(len(entity_instances[entity]) * val_frac / (1 - test_frac))].tolist()}
        for entity in entities
    ]
    val_indices = []
    for i in train_val_indices:
        if all(getattr(dataset[i], entity) in val_entity_instances[j] for j, entity in enumerate(entities)):
            val_indices.append(i)
    if len(val_indices) == 0:
        raise ValueError('No validation samples found. Try increasing the valid frac or using a less stringent splitting strategy.')

    # Everything that is neither test nor validation becomes training data
    train_indices = list(set(train_val_indices) - set(val_indices))
    train_dataset = data.Subset(dataset, train_indices)
    val_dataset = data.Subset(dataset, val_indices)
    test_dataset = data.Subset(dataset, test_indices)
    return {'train': train_dataset, 'valid': val_dataset, 'test': test_dataset}
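

# Illustrative usage sketch (assumptions: `InteractionDataset` and its `drug`/`target`
# sample attributes are hypothetical placeholders, not part of this module):
#
#     dataset = InteractionDataset(...)  # each sample exposes .drug and .target
#     splits = cold_start(dataset, frac=[0.7, 0.1, 0.2], entities=['drug', 'target'])
#     train_loader = data.DataLoader(splits['train'], batch_size=32, shuffle=True)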