FupBERT_Space / smiles.py
c-dunlap's picture
Initial app file upload
6e8698e
raw
history blame
3.58 kB
import re
import torch
import pandas as pd
from rdkit import Chem
from rdkit.Chem.SaltRemover import SaltRemover
class InvalidSmile(Exception):
    """Raised when a SMILES string cannot be processed, e.g. when no
    molecule can be constructed from it."""
def load_vocab(vocab_file_name):
    """
    Load an existing vocabulary from a file. Assumes a single
    token definition per line of the file.

    Every token is read as a literal string: pandas' type inference and
    default NA handling are disabled so that tokens such as ``NA`` or
    ``1`` are kept verbatim instead of being converted to NaN or to
    numbers.

    Parameters
    ----------
    vocab_file_name : str
        The file name of the vocabulary to load.

    Returns
    -------
    vocab_dict : dict
        A dict of tokens as the keys and the corresponding
        token index (zero-based line position) as the items.
    """
    # dtype=str + keep_default_na=False: keys must stay strings because the
    # tokenizer output that is looked up in this dict is always str.
    vocab = pd.read_csv(
        vocab_file_name,
        header=None,
        dtype=str,
        keep_default_na=False,
    )[0].to_list()
    return {token: index for index, token in enumerate(vocab)}
# Compiled once at import time. Written as a raw string so the regex escapes
# (\[, \(, \., ...) are not interpreted as (invalid) Python string escapes.
_SMILES_TOKEN_RE = re.compile(
    r"(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\|\/|:|~|@|\?|>|\*|\$|\%[0-9]{2}|[0-9])"
)


def smiles_tokenizer(smiles):
    """
    Tokenize a SMILES string.

    Parameters
    ----------
    smiles : str
        A SMILES string to turn into tokens.

    Returns
    -------
    tokens : list
        A list of tokens after tokenizing the input string.

    Raises
    ------
    InvalidSmile
        If the string contains characters that are not matched by the
        token pattern.
    """
    tokens = _SMILES_TOKEN_RE.findall(smiles)

    # findall silently skips characters that match no token, so a shorter
    # total token length means the input had unrecognized characters.
    # solution based on https://stackoverflow.com/a/3879574
    if sum(map(len, tokens)) < len(smiles):
        raise InvalidSmile(
            "Input smiles string contained invalid characters."
        )

    return tokens
def smiles_to_tensor(
    smiles, vocab_dict, max_seq_len, desalt=True, canonical=True, isomeric=True
):
    """
    Converts a SMILES string to a tensor using the provided vocabulary.

    The encoded sequence is ``[CLS] [EDGE] <tokens...> [EDGE]`` followed by
    ``[PAD]`` up to ``max_seq_len``.

    Parameters
    ----------
    smiles : str
        A SMILES string to convert to a tensor.
    vocab_dict : dict
        A dictionary of SMILES tokens and integer value as the dictionary key
        and item, respectively. Must contain the special tokens ``[CLS]``,
        ``[EDGE]`` and ``[PAD]``.
    max_seq_len : int
        The maximum sequence length allowed for SMILES strings. Smaller
        strings are padded to the maximum length using the [PAD] token
        from the vocabulary provided.
    desalt : bool, optional
        Flag for removing salts and solvents from SMILES string, by default True.
    canonical : bool, optional
        Flag enabling the conversion of the SMILES to canonical form, by default True.
    isomeric : bool, optional
        Flag enabling the conversion of the SMILES to isomeric form, by default True.

    Returns
    -------
    smiles_ten_long : tensor
        A tensor representing the converted SMILES string based on the provided
        vocabulary with shape (1, max_seq_len).

    Raises
    ------
    InvalidSmile
        If no molecule could be constructed from the SMILES string.
    KeyError
        If a token (or one of the special tokens) is missing from
        ``vocab_dict``.
    RuntimeError
        If the encoded sequence, including the three special tokens, is
        longer than ``max_seq_len``.
    """
    # Convert the SMILES to molecule
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        raise InvalidSmile('Molecule could not be constructed from smile string')

    # Remove the salts/solvents; the remover is only built when requested
    if desalt:
        mol = SaltRemover().StripMol(mol, dontRemoveEverything=True)

    # Convert back to SMILES
    smiles = Chem.MolToSmiles(mol, canonical=canonical, isomericSmiles=isomeric)

    # Tokenize the SMILES and map every token to its vocabulary index
    edge_idx = vocab_dict["[EDGE]"]
    tok = [vocab_dict["[CLS]"], edge_idx]
    tok += [vocab_dict[x] for x in smiles_tokenizer(smiles)]
    tok.append(edge_idx)

    pad_idx = vocab_dict["[PAD]"]

    # Fail with a readable message instead of torch's shape-mismatch error.
    # RuntimeError is kept so callers handling the previous failure still work.
    if len(tok) > max_seq_len:
        raise RuntimeError(
            f"Tokenized SMILES has length {len(tok)}, which exceeds "
            f"max_seq_len={max_seq_len}."
        )

    # Start from an all-[PAD] row and overwrite the leading positions
    smiles_ten_long = torch.full((1, max_seq_len), pad_idx, dtype=torch.long)
    smiles_ten_long[0, : len(tok)] = torch.tensor(tok, dtype=torch.long)

    return smiles_ten_long