import os
from typing import Optional, Tuple

import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader, Dataset, SubsetRandomSampler

from utils.chem import *
from utils.parallel import *
from utils.sequence import *


class Preprocessor:
    """Load a CSV of SMILES/sequence pairs and turn it into model-ready tensors."""

    def __init__(
        self,
        path: str,
        radius: int = 2,
        n_bits: int = 1024,
        aa_embedding: str = "prottrans_t5_xl_u50",
        num_workers: int = 1,
    ):
        self.path = path
        self.radius = radius
        self.n_bits = n_bits
        self.aa_embedding = aa_embedding
        self.num_workers = num_workers

        self.data = None
        self.fp = None
        self.aa = None
        self.split = None
        self.label = None

        self.load_data()
        self.process_data()

    def load_data(self):
        if os.path.isfile(self.path):
            self.data = pd.read_csv(self.path, low_memory=False)
        else:
            raise FileNotFoundError(f"No data file found at {self.path!r}")

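    # A minimal sketch of the expected CSV layout. Only the column names are
    # required by this class; "split" and "label" are optional, and the example
    # values below are illustrative assumptions, not real data:
    #
    #   smiles,sequence,split,label
    #   CCO,MKTAYIAKQRQISFVK,train,1
    #   c1ccccc1,MVLSPADKTNVKAAWG,val,0
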
    def process_data(self):
        if "smiles" not in self.data.columns:
            raise ValueError("No smiles column found in the data")
        if "sequence" not in self.data.columns:
            raise ValueError("No sequence column found in the data")

        smiles = self.data.smiles.tolist()
        seq = self.data.sequence.tolist()

        if "split" in self.data.columns:
            self.split = self.data.split.tolist()
        if "label" in self.data.columns:
            self.label = self.data.label.tolist()

        # Molecule parsing and fingerprinting are independent per row,
        # so they can be fanned out across workers.
        if self.num_workers > 1:
            mols = parallel(get_mols, self.num_workers, smiles)
            fps = parallel(get_fp, self.num_workers, mols, self.radius, self.n_bits)
        else:
            mols = get_mols(smiles)
            fps = get_fp(mols, self.radius, self.n_bits)

        self.fp = store_fp(fps, self.n_bits)
        self.aa = encode_sequences(seq, self.aa_embedding)

    def return_generator(
        self,
        device,
        batch_size: int = 512,
        include_negatives: bool = False,
        shuffle: bool = True,
        validation_split: Optional[float] = None,
    ) -> Tuple[DataLoader, Optional[DataLoader]]:
        if self.split is not None:
            if self.label is None:
                print("Splitting data into train and validation sets without considering labels")
            else:
                print("Splitting data into train and validation sets using labels")

            train_fp, train_aa, val_fp, val_aa = [], [], [], []
            for i in range(len(self.fp)):
                if self.label is None:
                    fp, aa = self.fp[i], self.aa[i]
                elif include_negatives and self.label[i] == 0:
                    # Negative pairs are kept only on request; negating the
                    # sequence embedding marks them as repulsive targets.
                    fp, aa = self.fp[i], self.aa[i] * -1
                elif self.label[i] == 1:
                    fp, aa = self.fp[i], self.aa[i]
                else:
                    continue

                if self.split[i] == "train":
                    train_fp.append(fp)
                    train_aa.append(aa)
                elif self.split[i] == "val":
                    val_fp.append(fp)
                    val_aa.append(aa)

            train_dataset = MolAADataset(device, train_fp, train_aa)
            val_dataset = MolAADataset(device, val_fp, val_aa)
            print(f"Train: {len(train_fp)}, Validation: {len(val_fp)}")

            train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle)
            validation_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=shuffle)
            return train_loader, validation_loader

        # No split column: use the whole dataset (labels, if any, are ignored here).
        print("No split column found in the dataset")
        dataset = MolAADataset(device, self.fp, self.aa)

        if validation_split is not None:
            print("Splitting data into train and validation sets by random fraction")
            dataset_size = len(dataset)
            indices = list(range(dataset_size))
            split = int(np.floor(validation_split * dataset_size))
            if shuffle:
                np.random.shuffle(indices)
            train_indices, val_indices = indices[split:], indices[:split]

            train_sampler = SubsetRandomSampler(train_indices)
            valid_sampler = SubsetRandomSampler(val_indices)

            train_loader = DataLoader(dataset, batch_size=batch_size, sampler=train_sampler)
            validation_loader = DataLoader(dataset, batch_size=batch_size, sampler=valid_sampler)
            return train_loader, validation_loader

        train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
        return train_loader, None


class MolAADataset(Dataset):
    """Paired molecular fingerprints and amino-acid sequence embeddings."""

    def __init__(self, device, mol, aa):
        self.mol = mol
        self.aa = aa
        self.device = device

    def __len__(self):
        """Number of samples; required by the PyTorch Dataset interface."""
        return len(self.mol)

    def __getitem__(self, idx):
        """Return one (fingerprint, embedding) pair; required by the PyTorch Dataset interface."""
        mol_sample = torch.tensor(self.mol[idx], dtype=torch.float32)
        aa_sample = torch.tensor(self.aa[idx], dtype=torch.float32)

        # Samples are moved to the target device here, which assumes the
        # DataLoader runs with num_workers=0 (CUDA tensors cannot be created
        # in forked worker processes).
        return mol_sample.to(self.device), aa_sample.to(self.device)
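

# A minimal usage sketch, assuming a CSV at "data/interactions.csv" with the
# columns described above; the path, batch size, and device choice are
# illustrative assumptions, not part of this module.
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    pre = Preprocessor(path="data/interactions.csv", num_workers=4)
    train_loader, val_loader = pre.return_generator(
        device, batch_size=256, validation_split=0.1
    )

    for mol_batch, aa_batch in train_loader:
        # Each batch is a (fingerprint, sequence-embedding) tensor pair.
        print(mol_batch.shape, aa_batch.shape)
        break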