import pandas as pd
import torch
import pickle
import os
from torch.utils.data import DataLoader, Dataset
from fuson_plm.utils.logging import log_update

# Dataset and DataLoader utilities for per-residue disorder prediction.
# Embeddings are loaded from a cached pickle file; computing them on the fly is not yet supported.
def custom_collate_fn(batch):
    """
    Custom collate function to handle batches with strings and tensors.
    
    Args:
        batch (list): List of tuples returned by __getitem__.
        
    Returns:
        tuple: (sequences, embeddings, labels)
            - sequences: List of strings
            - embeddings: Tensor of shape (batch_size, embedding_dim)
            - labels: Tensor of shape (batch_size, sequence_length)
    """
    sequences, embeddings, labels = zip(*batch)  # Unzip the batch into separate tuples

    # Stack embeddings and labels into tensors
    embeddings = torch.stack(embeddings, dim=0)  # Shape: (batch_size, embedding_dim)
    labels = torch.stack(labels, dim=0)          # Shape: (batch_size, sequence_length)

    # Convert sequences from tuple to list
    sequences = list(sequences)
    
    return sequences, embeddings, labels
    
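# Expected inputs for DisorderDataset (inferred from the code below):
#   - csv_file_path: CSV with a 'Sequence' column (amino acid strings) and a 'Label'
#     column containing one 0/1 character per residue (e.g. "00011100").
#   - cached_embeddings_path: pickle file holding a dict that maps each sequence string
#     to its precomputed embedding array.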
class DisorderDataset(Dataset):
    def __init__(self, csv_file_path, cached_embeddings_path=None, max_length=4405):
        super().__init__()
        self.dataset = pd.read_csv(csv_file_path)
        self.cached_embeddings_path = cached_embeddings_path
        self.max_length = max_length  # kept for interface consistency; not enforced here
        # Load the cached embeddings for the sequences in this dataset
        self.embeddings = self._retrieve_embeddings()

    def __len__(self):
        return len(self.dataset)

    def _retrieve_embeddings(self):
        try:
            with open(self.cached_embeddings_path, "rb") as f:
                # Load all cached embeddings (dict mapping sequence -> embedding)
                embeddings = pickle.load(f)
        except Exception as e:
            raise RuntimeError(
                f"Failed to load cached embeddings from {self.cached_embeddings_path}"
            ) from e
        
        # Keep only the embeddings for sequences present in self.dataset
        seqs = set(self.dataset['Sequence'].tolist())
        embeddings = {k: v for k, v in embeddings.items() if k in seqs}
        return embeddings
        
    def __getitem__(self, idx):
        sequence = self.dataset.iloc[idx]['Sequence']
        embedding = self.embeddings[sequence]
        embedding = torch.tensor(embedding, dtype=torch.float32)
        
        # Labels are stored as a string of per-residue 0/1 characters; convert to a float tensor
        label_str = self.dataset.iloc[idx]['Label']
        labels = list(map(int, label_str))
        labels = torch.tensor(labels, dtype=torch.float)
        assert len(labels) == len(sequence), "Label length must match sequence length"
        
        return sequence, embedding, labels
    
def get_dataloader(data_path, cached_embeddings_path, max_length=4405, batch_size=1, shuffle=True):
    """
    Creates a DataLoader for the disorder dataset.
    Args:
        data_path (str): Path to the CSV file (train, val, or test).
        cached_embeddings_path (str): Path to the pickle file of precomputed embeddings.
        max_length (int): Maximum sequence length expected in the data.
        batch_size (int): Batch size.
        shuffle (bool): Whether to shuffle the data.
    Returns:
        DataLoader: DataLoader object.
    """
    dataset = DisorderDataset(data_path, cached_embeddings_path=cached_embeddings_path, max_length=max_length)
    return DataLoader(dataset, batch_size=batch_size, collate_fn=custom_collate_fn, shuffle=shuffle)

def check_dataloaders(train_loader, test_loader, max_length=512, checkpoint_dir=''):
    """Log basic sanity checks on the train and test dataloaders (overlap and max length)."""
    log_update('\nBuilt train and test dataloaders')
    log_update(f"\tNumber of sequences in the Training DataLoader: {len(train_loader.dataset)}")
    log_update(f"\tNumber of sequences in the Testing DataLoader: {len(test_loader.dataset)}")
    dataloader_overlaps = check_dataloader_overlap(train_loader, test_loader)
    if len(dataloader_overlaps)==0: log_update("\tDataloaders are clean (no overlaps)")
    else: log_update(f"\tWARNING! sequence overlap found: {','.join(dataloader_overlaps)}")
    
    # Make sure the batch diversity output directory exists
    os.makedirs(f'{checkpoint_dir}/batch_diversity', exist_ok=True)
    
    # Check that no sequence exceeds the user-specified maximum length
    max_length_violators = []
    for name, dataloader in {'train': train_loader, 'test': test_loader}.items():
        if not check_max_length(dataloader, max_length):
            max_length_violators.append(name)
                
    if len(max_length_violators)==0: log_update(f"\tDataloaders follow the max length limit set by user: {max_length}")
    else: log_update(f"\tWARNING! these loaders have sequences longer than max length={max_length}: {','.join(max_length_violators)}")
    
def check_dataloader_overlap(train_loader, test_loader):
    """Return the set of sequences that appear in both the train and test loaders."""
    train_seqs = set()
    test_seqs = set()
    for sequences, _, _ in train_loader:
        train_seqs.update(sequences)  # add every sequence in the batch, not just the first
    for sequences, _, _ in test_loader:
        test_seqs.update(sequences)
        
    return train_seqs.intersection(test_seqs)
    
def check_max_length(dataloader, max_length):
    """Return True if every sequence in the dataloader is at most max_length residues long."""
    for sequences, _, _ in dataloader:
        if any(len(seq) > max_length for seq in sequences):
            return False
    
    return True
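
# Example usage sketch: the CSV and pickle paths below are hypothetical placeholders,
# and the cache is assumed to map each sequence string to its precomputed embedding
# as described above.
if __name__ == "__main__":
    train_loader = get_dataloader(
        data_path="splits/train_df.csv",                           # hypothetical path
        cached_embeddings_path="embeddings/train_embeddings.pkl",  # hypothetical path
        batch_size=1,
        shuffle=True,
    )
    test_loader = get_dataloader(
        data_path="splits/test_df.csv",                            # hypothetical path
        cached_embeddings_path="embeddings/test_embeddings.pkl",   # hypothetical path
        batch_size=1,
        shuffle=False,
    )
    check_dataloaders(train_loader, test_loader, max_length=4405, checkpoint_dir=".")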