"""Helpers for converting SMILES strings into token-index tensors."""

import functools
import re

import pandas as pd
import torch
from rdkit import Chem
from rdkit.Chem.SaltRemover import SaltRemover


class InvalidSmile(Exception):
    """Raised when a SMILES string cannot be parsed or tokenized."""


# Compiled once at import time. A raw string is used so the regex escapes are
# not (mis)interpreted as Python string escapes; the resulting pattern text is
# identical to the previous non-raw literal. ``\\`` matches one backslash.
_SMILES_TOKEN_REGEX = re.compile(
    r"(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\|\/|:|~|@|\?|>|\*|\$|\%[0-9]{2}|[0-9])"
)


@functools.lru_cache(maxsize=1)
def _get_salt_remover():
    """Return a shared SaltRemover, constructed lazily on first use."""
    # Building a SaltRemover parses its salt definitions, so reuse one
    # instance instead of rebuilding it on every conversion.
    return SaltRemover()


def load_vocab(vocab_file_name):
    """
    Load an existing vocabulary from a file.

    Assumes a single token definition per line of the file. The token index
    is the (0-based) row position of the token in the file.

    Parameters
    ----------
    vocab_file_name : str
        The file name of the vocabulary to load.

    Returns
    -------
    vocab_dict : dict
        A dict with tokens as the keys and the corresponding token index as
        the values.
    """
    # Read every token as a literal string: disable NA detection and dtype
    # inference so tokens such as "NA"/"nan" or purely numeric tokens are not
    # converted to NaN / numbers.
    vocab = pd.read_csv(
        vocab_file_name, header=None, dtype=str, keep_default_na=False
    )[0].to_list()
    return {token: index for index, token in enumerate(vocab)}


def smiles_tokenizer(smiles):
    """
    Tokenize a SMILES string.

    Parameters
    ----------
    smiles : str
        A SMILES string to turn into tokens.

    Returns
    -------
    tokens : list
        A list of tokens after tokenizing the input string.

    Raises
    ------
    InvalidSmile
        If the string contains characters that no token pattern matches.
    """
    tokens = _SMILES_TOKEN_REGEX.findall(smiles)
    # findall silently skips characters that match no alternative, so the
    # concatenated tokens only reproduce the input when every character was
    # consumed (idea from https://stackoverflow.com/a/3879574).
    if "".join(tokens) != smiles:
        raise InvalidSmile(
            "Input smiles string contained invalid characters."
        )
    return tokens


def smiles_to_tensor(
    smiles, vocab_dict, max_seq_len, desalt=True, canonical=True, isomeric=True
):
    """
    Convert a SMILES string to a tensor using the provided vocabulary.

    The molecule is parsed with RDKit, optionally desalted, re-serialized
    (optionally canonical / isomeric), tokenized, wrapped as
    ``[CLS] [EDGE] <tokens> [EDGE]`` and right-padded with ``[PAD]``.

    Parameters
    ----------
    smiles : str
        A SMILES string to convert to a tensor.
    vocab_dict : dict
        A dictionary with SMILES tokens as keys and integer indices as
        values. Must contain "[CLS]", "[EDGE]" and "[PAD]".
    max_seq_len : int
        The maximum sequence length allowed for SMILES strings (including the
        three special tokens). Shorter sequences are padded to this length
        with the [PAD] token from the vocabulary provided.
    desalt : bool, optional
        Flag for removing salts and solvents from SMILES string, by default
        True.
    canonical : bool, optional
        Flag enabling the conversion of the SMILES to canonical form, by
        default True.
    isomeric : bool, optional
        Flag enabling the conversion of the SMILES to isomeric form, by
        default True.

    Returns
    -------
    smiles_ten_long : tensor
        A ``torch.long`` tensor representing the converted SMILES string
        based on the provided vocabulary with shape (1, max_seq_len).

    Raises
    ------
    InvalidSmile
        If RDKit cannot build a molecule from ``smiles`` or the re-serialized
        SMILES contains characters the tokenizer does not recognize.
    KeyError
        If a token (or a required special token) is missing from
        ``vocab_dict``.
    RuntimeError
        If the token sequence is longer than ``max_seq_len``.
    """
    # Convert the SMILES to a molecule
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        raise InvalidSmile('Molecule could not be constructed from smile string')

    # Remove the salts/solvents, but never strip the molecule down to nothing
    if desalt:
        mol = _get_salt_remover().StripMol(mol, dontRemoveEverything=True)

    # Convert back to SMILES and tokenize
    smiles = Chem.MolToSmiles(mol, canonical=canonical, isomericSmiles=isomeric)
    smiles_tok = smiles_tokenizer(smiles)

    tok = [vocab_dict["[CLS]"], vocab_dict["[EDGE]"]]
    tok += [vocab_dict[x] for x in smiles_tok]
    tok += [vocab_dict["[EDGE]"]]

    # Fail with a clear message instead of an opaque tensor-shape mismatch;
    # RuntimeError matches the type the slice assignment raised before.
    if len(tok) > max_seq_len:
        raise RuntimeError(
            f"Tokenized SMILES has {len(tok)} tokens (including special "
            f"tokens), which exceeds max_seq_len={max_seq_len}."
        )

    smiles_ten_long = torch.full(
        (1, max_seq_len), vocab_dict["[PAD]"], dtype=torch.long
    )
    smiles_ten_long[0, : len(tok)] = torch.tensor(tok, dtype=torch.long)
    return smiles_ten_long