Spaces:
Running
Running
import numpy as np | |
import pandas as pd | |
import re | |
import selfies as sf | |
import torch | |
from rdkit import Chem | |
from rdkit.Chem import DataStructs, AllChem, Descriptors, QED, Draw | |
from rdkit.Chem.Crippen import MolLogP | |
from rdkit.Contrib.SA_Score import sascorer | |
from transformers import BartForConditionalGeneration, AutoTokenizer | |
from transformers.modeling_outputs import BaseModelOutput | |
# Hugging Face checkpoint shared by the SELFIES tokenizer and the BART model.
_GEN_MODEL_ID = "ibm/materials.selfies-ted"

# Loaded once at import time; _encode and _generate use these module globals.
gen_tokenizer = AutoTokenizer.from_pretrained(_GEN_MODEL_ID)
gen_model = BartForConditionalGeneration.from_pretrained(_GEN_MODEL_ID)
def smiles_to_image(smiles):
    """Render a SMILES string with RDKit's ``Draw.MolToImage``.

    Returns None when RDKit cannot parse ``smiles``.
    """
    molecule = Chem.MolFromSmiles(smiles)
    if not molecule:
        return None
    return Draw.MolToImage(molecule)
def calculate_properties(smiles):
    """Compute four RDKit descriptors for a SMILES string.

    Returns:
        A 4-tuple ``(qed, sa, logp, wt)``:
        - the QED score,
        - the synthetic-accessibility score from RDKit Contrib's ``sascorer``,
        - the Crippen logP,
        - the molecular weight.
        All four entries are None when the SMILES cannot be parsed.
    """
    molecule = Chem.MolFromSmiles(smiles)
    if not molecule:
        return None, None, None, None
    return (
        QED.qed(molecule),
        sascorer.calculateScore(molecule),
        MolLogP(molecule),
        Descriptors.MolWt(molecule),
    )
def calculate_tanimoto(smiles1, smiles2):
    """Return the Tanimoto similarity of two molecules, rounded to 2 decimals.

    Both molecules are fingerprinted as 2048-bit Morgan fingerprints of
    radius 2.

    Returns:
        A float in [0, 1], or None if either SMILES cannot be parsed.
    """
    # Local import: the fingerprint-generator API replaces the deprecated
    # AllChem.GetMorganFingerprintAsBitVect, which emits a deprecation
    # warning on every call in recent RDKit releases.
    from rdkit.Chem import rdFingerprintGenerator

    mol1 = Chem.MolFromSmiles(smiles1)
    mol2 = Chem.MolFromSmiles(smiles2)
    if mol1 is None or mol2 is None:
        return None
    # Same settings as the old call's defaults: radius 2, 2048 bits.
    generator = rdFingerprintGenerator.GetMorganGenerator(radius=2, fpSize=2048)
    fp1 = generator.GetFingerprint(mol1)
    fp2 = generator.GetFingerprint(mol2)
    return round(DataStructs.TanimotoSimilarity(fp1, fp2), 2)
def _perturb_latent(latent_vecs, noise_scale=0.5): | |
return ( | |
torch.tensor( | |
np.random.uniform(0, 1, latent_vecs.shape) * noise_scale, | |
dtype=torch.float32, | |
) | |
+ latent_vecs | |
) | |
def _encode(selfies):
    """Run the BART encoder over tokenized SELFIES strings.

    Args:
        selfies: a list of SELFIES strings. ``generate_canonical`` passes
            them with a space between tokens.

    Returns:
        ``(last_hidden_state, attention_mask)``. Inputs are padded or
        truncated to 128 tokens; both tensors live on ``gen_model``'s device.
    """
    encoding = gen_tokenizer(
        selfies,
        return_tensors='pt',
        max_length=128,
        truncation=True,
        padding='max_length',
    )
    device = gen_model.device
    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)
    # Inference only: disable autograd so the returned hidden states do not
    # keep the encoder's activation graph alive.
    with torch.no_grad():
        outputs = gen_model.model.encoder(
            input_ids=input_ids, attention_mask=attention_mask
        )
    return outputs.last_hidden_state, attention_mask
def _generate(latent_vector, mask):
    """Sample SMILES strings from the decoder, conditioned on given hidden states.

    ``latent_vector`` is wrapped as the encoder output, so a (possibly
    perturbed) latent drives BART's decoder directly. Sampling uses
    top-k=5 / top-p=0.95 with at most 64 new tokens and one sequence per
    input.

    Returns:
        A list of SMILES strings, one per decoded sequence. Each is converted
        from SELFIES with ``sf.decoder``.
    """
    sampled_ids = gen_model.generate(
        encoder_outputs=BaseModelOutput(latent_vector),
        attention_mask=mask,
        max_new_tokens=64,
        do_sample=True,
        top_k=5,
        top_p=0.95,
        num_return_sequences=1,
    )
    decoded = gen_tokenizer.batch_decode(sampled_ids, skip_special_tokens=True)
    smiles_list = []
    for text in decoded:
        # Strip whitespace adjacent to "]...[" token boundaries so the
        # SELFIES decoder sees contiguous "[..][..]" tokens.
        joined = re.sub(r'\]\s*(.*?)\s*\[', r']\1[', text)
        smiles_list.append(sf.decoder(joined))
    return smiles_list
def _search_latent_space(selfie, reference_canonical):
    """Sample from increasingly perturbed latents until a new molecule appears.

    Noise scales 0.5, 0.6, ..., 5.0 are tried in order. The search stops at
    the first valid molecule whose canonical SMILES differs from
    ``reference_canonical``. If no attempt differs, the last valid
    (identical) molecule is returned. None is returned when no attempt
    yielded a parsable, non-empty molecule.
    """
    latent_vec, mask = _encode([selfie])
    gen_mol = None
    for step in range(5, 51):
        print("Searching Latent space")
        perturbed_latent = _perturb_latent(latent_vec, noise_scale=step / 10)
        candidate = _generate(perturbed_latent, mask)[0]
        mol = Chem.MolFromSmiles(candidate)
        # An empty decode parses to a zero-atom molecule. Treat it as
        # abnormal instead of reporting "" as the generated molecule.
        if mol is None or mol.GetNumAtoms() == 0:
            print('Abnormal molecule:', candidate)
            continue
        gen_mol = Chem.MolToSmiles(mol)
        if gen_mol != reference_canonical:
            break
    return gen_mol


def _comparison_table(smiles, gen_mol):
    """Build the reference-vs-generated property table as a DataFrame."""
    ref_properties = calculate_properties(smiles)
    gen_properties = calculate_properties(gen_mol)
    tanimoto_similarity = calculate_tanimoto(smiles, gen_mol)
    return pd.DataFrame({
        "Property": ["QED", "SA", "LogP", "Mol Wt", "Tanimoto Similarity"],
        # The similarity is a pairwise value, so it appears once in the
        # reference column and is blank in the generated column.
        "Reference Mol": [*ref_properties, tanimoto_similarity],
        "Generated Mol": [*gen_properties, ""],
    })


def generate_canonical(smiles):
    """Generate a molecule near ``smiles`` by perturbing its latent encoding.

    Returns:
        On success, ``(df, gen_mol, mol_image)``:
        - ``df`` compares QED/SA/LogP/Mol Wt of the input and generated
          molecules and includes their Tanimoto similarity;
        - ``gen_mol`` is the generated canonical SMILES;
        - ``mol_image`` is its rendering.
        Returns ``("Invalid SMILES", None, None)`` when the input cannot be
        parsed or encoded as SELFIES, or when no valid molecule was generated.
    """
    reference = Chem.MolFromSmiles(smiles)
    if reference is None or reference.GetNumAtoms() == 0:
        return "Invalid SMILES", None, None
    try:
        encoded = sf.encoder(smiles)
    except sf.EncoderError:
        return "Invalid SMILES", None, None
    if encoded is None:
        return "Invalid SMILES", None, None
    # Put a space between SELFIES tokens before tokenizing. This mirrors the
    # whitespace _generate strips from decoded output; presumably it is the
    # tokenizer's expected input format.
    selfie = encoded.replace("][", "] [")
    # Canonicalize the reference once rather than on every search step.
    gen_mol = _search_latent_space(selfie, Chem.MolToSmiles(reference))
    if not gen_mol:
        return "Invalid SMILES", None, None
    print("calculating properties")
    df = _comparison_table(smiles, gen_mol)
    print("Getting image")
    mol_image = smiles_to_image(gen_mol)
    return df, gen_mol, mol_image