import random

import selfies as sf
import torch
from rdkit import Chem
from datasets import load_dataset
from torch.utils.data import DataLoader, Dataset, DistributedSampler
from transformers import T5EncoderModel


def get_dataloader(dataset, batchsize, rank, world_size):
    """Return an infinite iterator over `dataset`, sharded across distributed ranks."""
    sampler = DistributedSampler(
        dataset,
        num_replicas=world_size,
        rank=rank,
        shuffle=True,
    )

    def collate(batch):
        selfies_ids = [i["selfies_ids"] for i in batch]
        caption_state = [i["caption_state"] for i in batch]
        caption_mask = [i["caption_mask"] for i in batch]
        corrupted_selfies_ids = [i["corrupted_selfies_ids"] for i in batch]
        return (
            torch.concat(selfies_ids, dim=0),
            torch.concat(caption_state, dim=0),
            torch.concat(caption_mask, dim=0),
            torch.concat(corrupted_selfies_ids, dim=0),
        )

    dataloader = DataLoader(
        dataset,
        batch_size=batchsize,
        shuffle=False,  # shuffling is delegated to the DistributedSampler
        collate_fn=collate,
        sampler=sampler,
    )

    def cycle():
        # Advance the sampler's epoch on every full pass so each epoch reshuffles.
        ec = 0
        while True:
            dataloader.sampler.set_epoch(ec)
            for i in dataloader:
                yield i
            ec += 1

    return iter(cycle())


class Lang2molDataset_train(Dataset):
    def __init__(
        self,
        dir,
        tokenizer,
        split,
        dataset_name,
        pre=None,
        prob=0,
        load_state=True,
        corrupt_prob=0.4,
        token_max_length=256,
    ):
        super().__init__()
        self.dir = dir
        self.tokenizer = tokenizer
        self.split = split
        self.pre = pre
        self.prob = prob
        self.corrupt_prob = corrupt_prob
        self.token_max_length = token_max_length
        self.dataset_name = dataset_name
        self.ori_data = self.create_data()
        self.load_state = load_state
        # Frozen text encoder used to pre-compute caption embeddings on the fly.
        self.model = T5EncoderModel.from_pretrained("QizhiPei/biot5-base-text2mol")
        self.model.to("cuda")
        self.model.eval()

    def create_data(self):
        # `token=True` is the current datasets keyword; fall back to the
        # deprecated `use_auth_token=True` on older library versions.
        try:
            dataset = load_dataset(
                self.dataset_name,
                token=True,
                split=self.split,
            ).sort("id")
        except TypeError:
            dataset = load_dataset(
                self.dataset_name,
                use_auth_token=True,
                split=self.split,
            ).sort("id")
        return [
            (int(sample_id), sample_selfies, sample_caption, sample_smiles)
            for (sample_id, sample_selfies, sample_caption, sample_smiles) in zip(
                dataset["id"],
                dataset["selfies"],
                dataset["caption"],
                dataset["smiles"],
            )
        ]

    def __len__(self):
        return len(self.ori_data)

    def permute(self, selfies):
        # With probability `prob`, re-serialize the molecule with a shuffled
        # atom ordering as a data augmentation.
        if random.random() < self.prob:
            return changeorder(selfies, shuffle=True)
        return selfies

    def __getitem__(self, idx):
        data = self.ori_data[idx]
        sample = {
            "id": data[0],
            "selfies": self.permute(data[1]),
            "caption": data[2],
            "smiles": data[3],
        }

        # Molecules
        output_molecule = self.tokenizer(
            sample["selfies"],
            max_length=self.token_max_length,
            truncation=True,
            padding="max_length",
            add_special_tokens=True,
            return_tensors="pt",
            return_attention_mask=True,
        )
        sample["selfies_ids"] = output_molecule["input_ids"]
        # Corruption is currently a no-op placeholder: the "corrupted" ids are
        # the clean ids, and `corrupt_prob` is unused.
        sample["corrupted_selfies_ids"] = sample["selfies_ids"]

        # Captions
        output_caption = self.tokenizer(
            sample["caption"],
            max_length=self.token_max_length,
            truncation=True,
            padding="max_length",
            add_special_tokens=True,
            return_tensors="pt",
            return_attention_mask=True,
        )
        # The encoder is frozen, so skip autograd bookkeeping.
        with torch.no_grad():
            sample["caption_state"] = self.model(
                input_ids=output_caption["input_ids"].to("cuda"),
                attention_mask=output_caption["attention_mask"].to("cuda"),
            ).last_hidden_state
        sample["caption_mask"] = output_caption["attention_mask"]
        return sample


class Lang2molDataset_eval(Dataset):
    def __init__(
        self,
        dir,
        tokenizer,
        split,
        dataset_name,
        pre=None,
        prob=0,
        load_state=True,
        corrupt_prob=0.4,
        token_max_length=512,
    ):
        super().__init__()
        self.dir = dir
        self.tokenizer = tokenizer
        self.split = split
        self.pre = pre
        self.prob = prob
        self.corrupt_prob = corrupt_prob
        self.token_max_length = token_max_length
        self.dataset_name = dataset_name
        self.ori_data = self.create_data()
        self.load_state = load_state
        self.model = T5EncoderModel.from_pretrained("QizhiPei/biot5-base-text2mol")
        self.model.to("cuda")
        self.model.eval()

    def create_data(self):
        try:
            dataset = load_dataset(
                self.dataset_name,
                token=True,
                split=self.split,
            ).sort("id")
        except TypeError:
            dataset = load_dataset(
                self.dataset_name,
                use_auth_token=True,
                split=self.split,
            ).sort("id")
        return [
            (int(sample_id), sample_selfies, sample_caption, sample_smiles)
            for (sample_id, sample_selfies, sample_caption, sample_smiles) in zip(
                dataset["id"],
                dataset["selfies"],
                dataset["caption"],
                dataset["smiles"],
            )
        ]

    def __len__(self):
        return len(self.ori_data)

    def permute(self, selfies):
        if random.random() < self.prob:
            return changeorder(selfies, shuffle=True)
        return selfies

    def __getitem__(self, idx):
        data = self.ori_data[idx]
        sample = {
            "id": data[0],
            "selfies": self.permute(data[1]),
            "caption": data[2],
            "smiles": data[3],
        }

        # Captions only: evaluation does not tokenize the target molecule.
        output_caption = self.tokenizer(
            sample["caption"],
            max_length=self.token_max_length,
            truncation=True,
            padding="max_length",
            add_special_tokens=True,
            return_tensors="pt",
            return_attention_mask=True,
        )
        with torch.no_grad():
            sample["caption_state"] = self.model(
                input_ids=output_caption["input_ids"].to("cuda"),
                attention_mask=output_caption["attention_mask"].to("cuda"),
            ).last_hidden_state
        sample["caption_mask"] = output_caption["attention_mask"]
        return sample


class Lang2molDataset_submission(Dataset):
    def __init__(
        self,
        dir,
        tokenizer,
        split,
        dataset_name,
        pre=None,
        prob=0,
        load_state=True,
        corrupt_prob=0.4,
        token_max_length=256,
    ):
        super().__init__()
        self.dir = dir
        self.tokenizer = tokenizer
        self.split = split
        self.pre = pre
        self.prob = prob
        self.corrupt_prob = corrupt_prob
        self.token_max_length = token_max_length
        self.dataset_name = dataset_name
        self.ori_data = self.create_data()
        self.load_state = load_state
        self.model = T5EncoderModel.from_pretrained("QizhiPei/biot5-base-text2mol")
        self.model.to("cuda")
        self.model.eval()

    def create_data(self):
        try:
            dataset = load_dataset(
                self.dataset_name,
                token=True,
                split=self.split,
            )
        except TypeError:
            dataset = load_dataset(
                self.dataset_name,
                use_auth_token=True,
                split=self.split,
            )
        # Submission data carries captions only; there is no ground-truth molecule.
        return [sample_caption for sample_caption in dataset["caption"]]

    def __len__(self):
        return len(self.ori_data)

    def permute(self, selfies):
        if random.random() < self.prob:
            return changeorder(selfies, shuffle=True)
        return selfies

    def __getitem__(self, idx):
        sample = {"caption": self.ori_data[idx]}

        # Captions
        output_caption = self.tokenizer(
            sample["caption"],
            max_length=self.token_max_length,
            truncation=True,
            padding="max_length",
            add_special_tokens=True,
            return_tensors="pt",
            return_attention_mask=True,
        )
        with torch.no_grad():
            sample["caption_state"] = self.model(
                input_ids=output_caption["input_ids"].to("cuda"),
                attention_mask=output_caption["attention_mask"].to("cuda"),
            ).last_hidden_state
        sample["caption_mask"] = output_caption["attention_mask"]
        return sample


def changeorder(selfies, shuffle):
    """Return an equivalent SELFIES string with a (optionally shuffled) atom order."""
    # sf.decoder maps SELFIES -> SMILES; sf.encoder maps SMILES -> SELFIES.
    smiles = sf.decoder(selfies)
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return selfies
    # Clear aromatic flags so kekulized SMILES can be written below.
    Chem.Kekulize(mol, clearAromaticFlags=True)
    atom_indices = [atom.GetIdx() for atom in mol.GetAtoms()]
    if shuffle:
        random.shuffle(atom_indices)
    reordered_mol = Chem.RenumberAtoms(mol, atom_indices)
    new_smiles = Chem.MolToSmiles(reordered_mol, kekuleSmiles=True)
    new_selfies = sf.encoder(new_smiles)
    return new_selfies
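

# Minimal usage sketch, not part of the training pipeline. Assumptions: the
# BioT5 tokenizer checkpoint below matches the encoder used above, the dataset
# id "your-org/your-lang2mol-dataset" is a hypothetical placeholder for a hub
# dataset with "id"/"selfies"/"caption"/"smiles" columns, and a CUDA device is
# available (the datasets move the caption encoder to GPU). rank=0 /
# world_size=1 makes the DistributedSampler usable without init_process_group.
if __name__ == "__main__":
    from transformers import T5Tokenizer

    tokenizer = T5Tokenizer.from_pretrained("QizhiPei/biot5-base-text2mol")
    train_set = Lang2molDataset_train(
        dir=".",
        tokenizer=tokenizer,
        split="train",
        dataset_name="your-org/your-lang2mol-dataset",  # hypothetical dataset id
        prob=0.5,  # permute atom order for half of the training samples
    )
    loader = get_dataloader(train_set, batchsize=4, rank=0, world_size=1)
    # The loader cycles forever; pull one collated batch.
    selfies_ids, caption_state, caption_mask, corrupted_ids = next(loader)
    print(selfies_ids.shape, caption_state.shape, caption_mask.shape)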