""" | |
© Battelle Memorial Institute 2023 | |
Made available under the GNU General Public License v 2.0 | |
BECAUSE THE PROGRAM IS LICENSED FREE OF CHARGE, THERE IS NO WARRANTY | |
FOR THE PROGRAM, TO THE EXTENT PERMITTED BY APPLICABLE LAW. EXCEPT WHEN | |
OTHERWISE STATED IN WRITING THE COPYRIGHT HOLDERS AND/OR OTHER PARTIES | |
PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY OF ANY KIND, EITHER EXPRESSED | |
OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF | |
MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE ENTIRE RISK AS | |
TO THE QUALITY AND PERFORMANCE OF THE PROGRAM IS WITH YOU. SHOULD THE | |
PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF ALL NECESSARY SERVICING, | |
REPAIR OR CORRECTION. | |
""" | |
import re

import torch
import pandas as pd
from rdkit import Chem
from rdkit.Chem.SaltRemover import SaltRemover


class InvalidSmile(Exception):
    pass


def load_vocab(vocab_file_name):
    """
    Load an existing vocabulary from a file. Assumes a single
    token definition per line of the file.

    Parameters
    ----------
    vocab_file_name : str
        The file name of the vocabulary to load.

    Returns
    -------
    vocab_dict : dict
        A dict with tokens as the keys and the corresponding
        token indices as the values.
    """
    # Read one token per line and map each token to its line index
    vocab = pd.read_csv(vocab_file_name, header=None)[0].to_list()
    vocab_dict = {v: ind for ind, v in enumerate(vocab)}
    return vocab_dict
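
# Illustrative usage (a sketch; "vocab.txt" is a hypothetical file listing one
# token per line, e.g. [PAD], [CLS], [EDGE], C, c, O, ...):
#
#     vocab_dict = load_vocab("vocab.txt")
#     vocab_dict["[PAD]"]  # -> 0 if [PAD] is the first line of the file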


def smiles_tokenizer(smiles):
    """
    Tokenize a SMILES string.

    Parameters
    ----------
    smiles : str
        A SMILES string to turn into tokens.

    Returns
    -------
    tokens : list
        A list of tokens after tokenizing the input string.
    """
    # Raw string avoids invalid-escape warnings; the pattern covers bracketed
    # atoms, two-letter halogens, organic-subset atoms, bonds, ring closures,
    # and branching/stereo symbols.
    pattern = r"(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\|\/|:|~|@|\?|>|\*|\$|\%[0-9]{2}|[0-9])"
    regex = re.compile(pattern)
    tokens = regex.findall(smiles)
    # Check whether the SMILES string had extra characters not recognized by
    # the regex; solution based on https://stackoverflow.com/a/3879574
    if len("".join(tokens)) < len(smiles):
        raise Exception(
            "Input SMILES string contained invalid characters."
        )
    return tokens
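
# Illustrative examples (sketches):
#
#     smiles_tokenizer("CC(Cl)Br")  # -> ['C', 'C', '(', 'Cl', ')', 'Br']
#     smiles_tokenizer("C[NH3+]")   # -> ['C', '[NH3+]']
#
# Two-letter halogens and bracketed atoms are kept as single tokens.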


def smiles_to_tensor(
    smiles, vocab_dict, max_seq_len, desalt=True, canonical=True, isomeric=True
):
    """
    Convert a SMILES string to a tensor using the provided vocabulary.

    Parameters
    ----------
    smiles : str
        A SMILES string to convert to a tensor.
    vocab_dict : dict
        A dictionary mapping SMILES tokens (keys) to their integer
        indices (values).
    max_seq_len : int
        The maximum sequence length allowed for SMILES strings. Shorter
        strings are padded to the maximum length using the [PAD] token
        from the vocabulary provided.
    desalt : bool, optional
        Flag for removing salts and solvents from the SMILES string, by default True.
    canonical : bool, optional
        Flag enabling conversion of the SMILES to canonical form, by default True.
    isomeric : bool, optional
        Flag enabling conversion of the SMILES to isomeric form, by default True.

    Returns
    -------
    smiles_ten_long : torch.Tensor
        A tensor representing the converted SMILES string based on the provided
        vocabulary, with shape (1, max_seq_len).
    """
    # Initialize the salt/solvent remover
    remover = SaltRemover()
    # Convert the SMILES to a molecule
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        raise InvalidSmile("Molecule could not be constructed from SMILES string")
    # Remove the salts/solvents
    if desalt:
        mol = remover.StripMol(mol, dontRemoveEverything=True)
    # Convert back to SMILES
    smiles = Chem.MolToSmiles(mol, canonical=canonical, isomericSmiles=isomeric)
    # Tokenize the SMILES and map tokens to vocabulary indices, wrapping the
    # sequence with the special [CLS] and [EDGE] tokens
    smiles_tok = smiles_tokenizer(smiles)
    tok = [vocab_dict["[CLS]"], vocab_dict["[EDGE]"]]
    tok += [vocab_dict[x] for x in smiles_tok]
    tok += [vocab_dict["[EDGE]"]]
    smiles_ten = torch.tensor(tok, dtype=torch.long)
    # Pad the token sequence out to max_seq_len with the [PAD] token
    smiles_ten_long = (
        torch.ones((1, max_seq_len), dtype=torch.long) * vocab_dict["[PAD]"]
    )
    smiles_ten_long[0, : smiles_ten.shape[0]] = smiles_ten
    return smiles_ten_long
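

if __name__ == "__main__":
    # Minimal end-to-end sketch (illustrative only): build a small in-memory
    # vocabulary covering the special tokens plus the tokens of a short
    # example SMILES, then convert that SMILES to a padded tensor. A real
    # vocabulary would normally be loaded from a file via load_vocab(); the
    # token list and max_seq_len below are assumptions, not part of this module.
    example_tokens = ["[PAD]", "[CLS]", "[EDGE]", "C", "O", "(", ")", "="]
    example_vocab = {tok: ind for ind, tok in enumerate(example_tokens)}
    example_tensor = smiles_to_tensor("CCO", example_vocab, max_seq_len=16)
    print(example_tensor)        # token indices followed by [PAD] indices
    print(example_tensor.shape)  # torch.Size([1, 16])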