import pandas as pd
import torch
import pickle
import os
from torch.utils.data import DataLoader, Dataset
from fuson_plm.utils.logging import log_update
# Dataset class that loads precomputed embeddings and per-residue disorder labels
# Embeddings are read from a cached pickle file; generating them on the spot is not implemented here
def custom_collate_fn(batch):
"""
Custom collate function to handle batches with strings and tensors.
Args:
batch (list): List of tuples returned by __getitem__.
Returns:
tuple: (sequences, embeddings, labels)
- sequences: List of strings
- embeddings: Tensor of shape (batch_size, embedding_dim)
- labels: Tensor of shape (batch_size, sequence_length)
"""
sequences, embeddings, labels = zip(*batch) # Unzip the batch into separate tuples
# Stack embeddings and labels into tensors
embeddings = torch.stack(embeddings, dim=0) # Shape: (batch_size, embedding_dim)
labels = torch.stack(labels, dim=0) # Shape: (batch_size, sequence_length)
# Convert sequences from tuple to list
sequences = list(sequences)
return sequences, embeddings, labels
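# NOTE (assumption): torch.stack requires tensors of identical shape, so for variable-length
# per-residue labels/embeddings this collate function is only safe with batch_size=1
# (the default used below) or with batches of equal-length sequences.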
class DisorderDataset(Dataset):
def __init__(self, csv_file_path, cached_embeddings_path=None, max_length=4405):
super(DisorderDataset, self).__init__()
        self.dataset = pd.read_csv(csv_file_path)
        self.cached_embeddings_path = cached_embeddings_path
        self.max_length = max_length  # kept for reference; sequences are not truncated here
        # Load cached embeddings for the sequences in this dataset
        self.embeddings = self.__retrieve_embeddings__()
def __len__(self):
return len(self.dataset)
    def __retrieve_embeddings__(self):
        # The cached pickle is expected to map each sequence string to its precomputed embedding
        try:
            with open(self.cached_embeddings_path, "rb") as f:
                # Load all cached embeddings
                embeddings = pickle.load(f)
        except Exception as e:
            raise Exception(f"Error: failed to load embeddings from {self.cached_embeddings_path}") from e
        # Keep only embeddings for the sequences in self.dataset
        seqs = set(self.dataset['Sequence'].tolist())
        embeddings = {k: v for k, v in embeddings.items() if k in seqs}
        return embeddings
def __getitem__(self, idx):
sequence = self.dataset.iloc[idx]['Sequence']
embedding = self.embeddings[sequence]
embedding = torch.tensor(embedding, dtype=torch.float32)
        # The 'Label' column holds one 0/1 character per residue (e.g. "00110"); convert it to a float tensor
        label_str = self.dataset.iloc[idx]['Label']
        labels = list(map(int, label_str))
        labels = torch.tensor(labels, dtype=torch.float)
        assert len(labels) == len(sequence), "Label string and sequence must be the same length"
return sequence, embedding, labels
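# Illustrative input layout (inferred from the code above, not an official spec):
#   CSV columns:  Sequence, Label   where Label is one '0'/'1' character per residue, e.g.
#                 Sequence=MKTAYIAK, Label=00011000
#   Pickle file:  {sequence_string: embedding_array, ...} covering every sequence in the CSV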
def get_dataloader(data_path, cached_embeddings_path, max_length=4405, batch_size=1, shuffle=True):
"""
Creates a DataLoader for the dataset.
Args:
        data_path (str): Path to the CSV file (train, val, or test).
        cached_embeddings_path (str): Path to the pickle file of precomputed embeddings.
        max_length (int): Maximum allowed sequence length.
        batch_size (int): Batch size.
        shuffle (bool): Whether to shuffle the data.
Returns:
DataLoader: DataLoader object.
"""
dataset = DisorderDataset(data_path, cached_embeddings_path=cached_embeddings_path, max_length=max_length)
return DataLoader(dataset, batch_size=batch_size, collate_fn=custom_collate_fn, shuffle=shuffle)
def check_dataloaders(train_loader, test_loader, max_length=512, checkpoint_dir=''):
    log_update(f'\nBuilt train and test dataloaders')
log_update(f"\tNumber of sequences in the Training DataLoader: {len(train_loader.dataset)}")
log_update(f"\tNumber of sequences in the Testing DataLoader: {len(test_loader.dataset)}")
dataloader_overlaps = check_dataloader_overlap(train_loader, test_loader)
if len(dataloader_overlaps)==0: log_update("\tDataloaders are clean (no overlaps)")
else: log_update(f"\tWARNING! sequence overlap found: {','.join(dataloader_overlaps)}")
    # Make a directory for batch diversity stats (e.g., sequence length ranges)
    os.makedirs(f'{checkpoint_dir}/batch_diversity', exist_ok=True)
max_length_violators = []
for name, dataloader in {'train':train_loader, 'test':test_loader}.items():
max_length_followed = check_max_length(dataloader, max_length)
        if not max_length_followed:
max_length_violators.append(name)
if len(max_length_violators)==0: log_update(f"\tDataloaders follow the max length limit set by user: {max_length}")
else: log_update(f"\tWARNING! these loaders have sequences longer than max length={max_length}: {','.join(max_length_violators)}")
def check_dataloader_overlap(train_loader, test_loader):
    # Collect every sequence from each loader and return any sequences that appear in both
    train_seqs = set()
    test_seqs = set()
    for sequences, _, _ in train_loader:
        train_seqs.update(sequences)
    for sequences, _, _ in test_loader:
        test_seqs.update(sequences)
    return train_seqs.intersection(test_seqs)
def check_max_length(dataloader, max_length):
    # Return False as soon as any sequence in any batch exceeds max_length
    for sequences, _, _ in dataloader:
        for seq in sequences:
            if len(seq) > max_length:
                return False
    return True
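
# Minimal smoke-test sketch. The CSV and pickle paths below are placeholders (assumptions),
# not files that ship with this module; point them at a real split and embedding cache to run it.
if __name__ == "__main__":
    train_loader = get_dataloader("train.csv", "embeddings.pkl", batch_size=1, shuffle=True)
    test_loader = get_dataloader("test.csv", "embeddings.pkl", batch_size=1, shuffle=False)
    check_dataloaders(train_loader, test_loader, max_length=4405, checkpoint_dir=".")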