{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# granite.materials.smi-TED - Encoder & Decoder" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import sys\n", "sys.path.append('../inference')" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# materials.smi-ted (smi-ted)\n", "from smi_ted_light.load import load_smi_ted\n", "\n", "# Data\n", "import pandas as pd\n", "import numpy as np\n", "import torch\n", "\n", "# Chemistry\n", "from rdkit import Chem\n", "from rdkit.Chem import PandasTools\n", "from rdkit.Chem import Descriptors\n", "from rdkit.Chem import AllChem\n", "from rdkit.DataStructs import FingerprintSimilarity\n", "from rdkit.DataStructs import TanimotoSimilarity" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "# function to canonicalize SMILES\n", "def normalize_smiles(smi, canonical=True, isomeric=False):\n", " try:\n", " normalized = Chem.MolToSmiles(\n", " Chem.MolFromSmiles(smi), canonical=canonical, isomericSmiles=isomeric\n", " )\n", " except:\n", " normalized = None\n", " return normalized\n", "\n", "# function to calculate pairwise Tanimoto similarity\n", "def calculate_tanimoto_similarities(fps1, fps2):\n", " similarities = []\n", " for i in range(len(fps1)):\n", " sim = TanimotoSimilarity(fps1[i], fps2[i])\n", " similarities.append(sim)\n", " return similarities" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Load smi-ted" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Random Seed: 12345\n", "Using Rotation Embedding\n", "Using Rotation Embedding\n", "Using Rotation Embedding\n", "Using Rotation Embedding\n", "Using Rotation Embedding\n", "Using Rotation Embedding\n", "Using Rotation Embedding\n", "Using Rotation Embedding\n", "Using Rotation Embedding\n", "Using Rotation Embedding\n", "Using Rotation Embedding\n", "Using Rotation Embedding\n", "Vocab size: 2393\n", "[INFERENCE MODE - smi-ted-Light]\n" ] } ], "source": [ "model_smi_ted = load_smi_ted(\n", " folder='../inference/smi_ted_light',\n", " ckpt_filename='smi-ted-Light_40.pt'\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Load Dataset" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "df_moses = pd.read_csv(\"./data/moses_test.csv\", nrows=1000)" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(1000, 1)\n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
SMILES
0CC1C2CCC(C2)C1CN(CCO)C(=O)c1ccc(Cl)cc1
1COc1ccc(-c2cc(=O)c3c(O)c(OC)c(OC)cc3o2)cc1O
2CCOC(=O)c1ncn2c1CN(C)C(=O)c1cc(F)ccc1-2
3Clc1ccccc1-c1nc(-c2ccncc2)no1
4CC(C)(Oc1ccc(Cl)cc1)C(=O)OCc1cccc(CO)n1
\n", "
" ], "text/plain": [ " SMILES\n", "0 CC1C2CCC(C2)C1CN(CCO)C(=O)c1ccc(Cl)cc1\n", "1 COc1ccc(-c2cc(=O)c3c(O)c(OC)c(OC)cc3o2)cc1O\n", "2 CCOC(=O)c1ncn2c1CN(C)C(=O)c1cc(F)ccc1-2\n", "3 Clc1ccccc1-c1nc(-c2ccncc2)no1\n", "4 CC(C)(Oc1ccc(Cl)cc1)C(=O)OCc1cccc(CO)n1" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df_moses['SMILES'] = df_moses['SMILES'].apply(normalize_smiles)\n", "df_test_normalized = df_moses.dropna()\n", "print(df_test_normalized.shape)\n", "df_test_normalized.head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Encode SMILES - smi-ted" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 10/10 [00:07<00:00, 1.42it/s]\n" ] } ], "source": [ "with torch.no_grad():\n", " encode_embeddings = model_smi_ted.encode(df_moses['SMILES'], return_torch=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Decode smi-ted embeddings into SMILES" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "with torch.no_grad():\n", " decoded_smiles = model_smi_ted.decode(encode_embeddings)" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['CC1C2CCC(C2)C1CN(CCO)C(=O)c1ccc(Cl)cc1',\n", " 'COc1ccc(-c2cc(=O)c3c(O)c(OC)c(OC)cc3o2)cc1O',\n", " 'CCOC(=O)c1ncn2c1CN(C)C(=O)c1cc(F)ccc1-2',\n", " 'Clc1ccccc1-c1nc(-c2ccncc2)no1',\n", " 'CC(C)(Oc1ccc(Cl)cc1)C(=O)OCc1cccc(CO)n1']" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "decoded_smiles[0:5]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Compare similarities" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Mean Tanimoto Similarity: 1.00\n" ] } ], "source": [ "# Convert SMILES to RDKit molecule objects\n", "mols1 = [Chem.MolFromSmiles(smiles) for smiles in df_moses['SMILES'].to_list()]\n", "mols2 = [Chem.MolFromSmiles(smiles) for smiles in decoded_smiles]\n", "\n", "# Compute fingerprints for each molecule\n", "fps1 = [AllChem.GetMorganFingerprintAsBitVect(mol, 2) for mol in mols1]\n", "fps2 = [AllChem.GetMorganFingerprintAsBitVect(mol, 2) for mol in mols2]\n", "\n", "# Calculate Tanimoto similarities\n", "tanimoto_similarities = calculate_tanimoto_similarities(fps1, fps2)\n", "\n", "# Calculate the mean similarity\n", "mean_similarity = np.mean(tanimoto_similarities)\n", "\n", "# Print the mean similarity\n", "print(f\"Mean Tanimoto Similarity: {mean_similarity:.2f}\")" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 2 }