import os
import pickle
import sys

import numpy as np
import pandas as pd
import selfies as sf  # selfies>=2.1.1
import torch
from datasets import Dataset
from rdkit import Chem
from transformers import AutoModel, AutoTokenizer


class SELFIES(torch.nn.Module):
    """Embed SMILES strings with the ``ibm/materials.selfies-ted`` model.

    Typical use is ``load()`` followed by ``encode(smiles_list)``.

    Attributes are set in ``__init__``:
      * ``model`` / ``tokenizer`` -- ``None`` until ``load()`` is called.
      * ``invalid`` -- positions (into the last list given to
        ``get_selfies``) of the entries that could not be converted to
        SELFIES.
    """

    def __init__(self):
        super().__init__()
        self.model = None
        self.tokenizer = None
        self.invalid = []

    def get_selfies(self, smiles_list):
        """Convert SMILES strings to space-separated SELFIES strings.

        Each entry is first encoded as given (trailing whitespace removed).
        If that fails, the SMILES is round-tripped through RDKit and encoded
        again. If that fails too, the placeholder ``"[]"`` is emitted and the
        entry's position is recorded in ``self.invalid``.

        Args:
            smiles_list: iterable of SMILES strings.

        Returns:
            A list with one SELFIES string per input entry, with a space
            inserted between adjacent ``][`` tokens.
        """
        self.invalid = []
        spaced_selfies_batch = []
        for i, smiles in enumerate(smiles_list):
            try:
                selfies = sf.encoder(smiles.rstrip())
            except Exception:
                # `except Exception` (not a bare `except`) so that
                # KeyboardInterrupt / SystemExit still propagate.
                try:
                    smiles = Chem.MolToSmiles(
                        Chem.MolFromSmiles(smiles.rstrip())
                    )
                    selfies = sf.encoder(smiles)
                except Exception:
                    selfies = "[]"
                    self.invalid.append(i)
            spaced_selfies_batch.append(selfies.replace('][', '] ['))
        return spaced_selfies_batch

    def get_embedding(self, selfies):
        """Tokenize a batch of SELFIES and attach a pooled embedding.

        Args:
            selfies: mapping whose ``"selfies"`` entry holds the SELFIES
                strings of the batch (this is what ``Dataset.map`` passes in
                from ``encode``).

        Returns:
            The tokenizer output with an extra ``"embedding"`` entry: the
            encoder's last hidden state averaged over the positions where
            the attention mask is set.
        """
        encoding = self.tokenizer(
            selfies["selfies"],
            return_tensors='pt',
            max_length=128,
            truncation=True,
            padding='max_length',
        )
        input_ids = encoding['input_ids']
        attention_mask = encoding['attention_mask']
        outputs = self.model.encoder(
            input_ids=input_ids, attention_mask=attention_mask
        )
        model_output = outputs.last_hidden_state

        # Mean pooling restricted to attended positions; the clamp avoids a
        # division by zero when a row of the mask is entirely zero.
        input_mask_expanded = (
            attention_mask.unsqueeze(-1).expand(model_output.size()).float()
        )
        sum_embeddings = torch.sum(model_output * input_mask_expanded, 1)
        sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        model_output = sum_embeddings / sum_mask

        encoding["embedding"] = model_output
        return encoding

    def load(self, checkpoint="bart-2908.pickle"):
        """Load the tokenizer and model from ``ibm/materials.selfies-ted``.

        Args:
            checkpoint: currently ignored; kept so existing callers that
                pass it keep working.
        """
        self.tokenizer = AutoTokenizer.from_pretrained("ibm/materials.selfies-ted")
        self.model = AutoModel.from_pretrained("ibm/materials.selfies-ted")
        # NOTE: a disabled legacy path used to look for `checkpoint` in the
        # working directory and in every `sys.path` entry and unpickle a
        # (model, tokenizer) pair from it. If it is ever revived, remember
        # that `pickle.load` must only be used on trusted files.

    # TODO: remove `use_gpu` argument in validation pipeline
    def encode(self, smiles_list=None, use_gpu=False, return_tensor=False):
        """Compute one embedding per SMILES string.

        Args:
            smiles_list: iterable of SMILES strings; ``None`` (the default)
                is treated as an empty list.
            use_gpu: unused by this method.
            return_tensor: if true return a ``torch.Tensor``, otherwise a
                ``pandas.DataFrame``.

        Returns:
            One row per input entry. Rows for entries that could not be
            converted to SELFIES are filled with NaN and a message is
            printed for each of them. Empty input yields an empty
            tensor / DataFrame.
        """
        # Materialise once: avoids a shared mutable default, supports
        # generators, and makes the positional lookup below independent of
        # any custom index (e.g. a pandas Series).
        smiles_list = [] if smiles_list is None else list(smiles_list)

        selfies = self.get_selfies(smiles_list)
        if not selfies:
            # Nothing to embed; return an empty result instead of failing
            # on the missing "embedding" column further down.
            empty = np.empty((0, 0))
            return torch.tensor(empty) if return_tensor else pd.DataFrame(empty)

        selfies_df = pd.DataFrame(selfies, columns=["selfies"])
        data = Dataset.from_pandas(selfies_df)
        # Inference only: the result is converted to NumPy right away, so
        # there is no point in recording an autograd graph.
        with torch.no_grad():
            embedding = data.map(
                self.get_embedding, batched=True, num_proc=1, batch_size=128
            )
        # np.array always builds a fresh, writable array.
        emb = np.array(embedding["embedding"], dtype=float)

        for idx in self.invalid:
            emb[idx] = np.nan
            print("Cannot encode {0} to selfies and embedding replaced by NaN".format(smiles_list[idx]))

        if return_tensor:
            return torch.tensor(emb)
        return pd.DataFrame(emb)