import os
import pickle

import pandas as pd
import torch
from torch.utils.data import DataLoader, Dataset

from fuson_plm.utils.logging import log_update


def custom_collate_fn(batch):
    """
    Custom collate function to handle batches with strings and tensors.

    Args:
        batch (list): List of tuples returned by __getitem__.

    Returns:
        tuple: (sequences, embeddings, labels)
            - sequences: List of strings
            - embeddings: Tensor of shape (batch_size, embedding_dim)
            - labels: Tensor of shape (batch_size, sequence_length)
    """
    sequences, embeddings, labels = zip(*batch)

    # Stack the per-item tensors into batch tensors.
    embeddings = torch.stack(embeddings, dim=0)
    labels = torch.stack(labels, dim=0)

    # Keep the raw amino-acid sequences as a plain list of strings.
    sequences = list(sequences)

    return sequences, embeddings, labels
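
# Example (illustrative only; the embedding dimension and sequences below are
# hypothetical). Collating two items whose labels have equal length yields a
# list of strings, a stacked embedding tensor, and a stacked label tensor:
#
#   batch = [
#       ("MKV", torch.randn(1280), torch.tensor([0., 1., 1.])),
#       ("AST", torch.randn(1280), torch.tensor([1., 0., 0.])),
#   ]
#   seqs, embs, labs = custom_collate_fn(batch)
#   # seqs == ["MKV", "AST"]; embs.shape == (2, 1280); labs.shape == (2, 3)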


class DisorderDataset(Dataset):
    """Dataset of protein sequences, cached embeddings, and per-residue disorder labels."""

    def __init__(self, csv_file_path, cached_embeddings_path=None, max_length=4405):
        super().__init__()
        self.dataset = pd.read_csv(csv_file_path)
        self.cached_embeddings_path = cached_embeddings_path
        self.max_length = max_length

        self.embeddings = self._retrieve_embeddings()

    def __len__(self):
        return len(self.dataset)

    def _retrieve_embeddings(self):
        # Load the pickled {sequence: embedding} dictionary and keep only the
        # entries whose sequences appear in this split's CSV.
        try:
            with open(self.cached_embeddings_path, "rb") as f:
                embeddings = pickle.load(f)
        except Exception as e:
            raise RuntimeError(f"Failed to load cached embeddings from {self.cached_embeddings_path}") from e

        seqs = set(self.dataset['Sequence'])
        embeddings = {k: v for k, v in embeddings.items() if k in seqs}
        return embeddings

    def __getitem__(self, idx):
        sequence = self.dataset.iloc[idx]['Sequence']
        embedding = torch.as_tensor(self.embeddings[sequence], dtype=torch.float32)

        # Labels are stored as a string of per-residue 0/1 characters.
        label_str = self.dataset.iloc[idx]['Label']
        labels = torch.tensor([int(c) for c in label_str], dtype=torch.float)
        assert len(labels) == len(sequence), "Label string and sequence must be the same length"

        return sequence, embedding, labels
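
# Expected inputs (a sketch inferred from the code above; the column names are
# the ones the dataset reads, but the example values are hypothetical):
#
#   CSV, one row per protein:
#       Sequence,Label
#       MKV...,011...     # 'Label' is a string of per-residue 0/1 characters,
#                         # the same length as 'Sequence'
#
#   cached_embeddings_path: a pickle of {sequence (str): embedding (array or tensor)}
#   covering every sequence in the CSV.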


def get_dataloader(data_path, cached_embeddings_path, max_length=4405, batch_size=1, shuffle=True):
    """
    Creates a DataLoader for the dataset.

    Args:
        data_path (str): Path to the CSV file (train, val, or test).
        cached_embeddings_path (str): Path to the pickled sequence-to-embedding dictionary.
        max_length (int): Maximum sequence length allowed in the dataset.
        batch_size (int): Batch size.
        shuffle (bool): Whether to shuffle the data.

    Returns:
        DataLoader: DataLoader object.
    """
    dataset = DisorderDataset(data_path, cached_embeddings_path=cached_embeddings_path, max_length=max_length)
    return DataLoader(dataset, batch_size=batch_size, collate_fn=custom_collate_fn, shuffle=shuffle)
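
# Example usage (file paths are hypothetical):
#   train_loader = get_dataloader("splits/train.csv", "embeddings.pkl", batch_size=1, shuffle=True)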


def check_dataloaders(train_loader, test_loader, max_length=512, checkpoint_dir=''):
    """Log dataloader sizes and check for train/test overlap and max-length violations."""
    log_update('\nBuilt train and test dataloaders')
    log_update(f"\tNumber of sequences in the Training DataLoader: {len(train_loader.dataset)}")
    log_update(f"\tNumber of sequences in the Testing DataLoader: {len(test_loader.dataset)}")

    # Make sure no sequence appears in both the train and test loaders.
    dataloader_overlaps = check_dataloader_overlap(train_loader, test_loader)
    if len(dataloader_overlaps) == 0:
        log_update("\tDataloaders are clean (no overlaps)")
    else:
        log_update(f"\tWARNING! sequence overlap found: {','.join(dataloader_overlaps)}")

    # Ensure the batch_diversity output directory exists.
    os.makedirs(f'{checkpoint_dir}/batch_diversity', exist_ok=True)

    # Flag any loader containing sequences longer than the user-specified limit.
    max_length_violators = []
    for name, dataloader in {'train': train_loader, 'test': test_loader}.items():
        if not check_max_length(dataloader, max_length):
            max_length_violators.append(name)

    if len(max_length_violators) == 0:
        log_update(f"\tDataloaders follow the max length limit set by user: {max_length}")
    else:
        log_update(f"\tWARNING! these loaders have sequences longer than max length={max_length}: {','.join(max_length_violators)}")


def check_dataloader_overlap(train_loader, test_loader):
    """Return the set of sequences that appear in both the train and test loaders."""
    train_seqs = set()
    test_seqs = set()
    # Collect every sequence from every batch (not just the first item),
    # so the check also holds for batch sizes greater than 1.
    for sequences, _, _ in train_loader:
        train_seqs.update(sequences)
    for sequences, _, _ in test_loader:
        test_seqs.update(sequences)

    return train_seqs.intersection(test_seqs)


def check_max_length(dataloader, max_length):
    """Return True if every sequence in the dataloader is within max_length, else False."""
    for sequences, _, _ in dataloader:
        if any(len(seq) > max_length for seq in sequences):
            return False

    return True
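

# Minimal smoke-test sketch (file paths and checkpoint directory are hypothetical;
# adjust to your own splits and cached-embedding pickle before running):
#
#   if __name__ == "__main__":
#       train_loader = get_dataloader("splits/train.csv", "embeddings.pkl")
#       test_loader = get_dataloader("splits/test.csv", "embeddings.pkl", shuffle=False)
#       check_dataloaders(train_loader, test_loader, max_length=4405, checkpoint_dir="checkpoints")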