|
from rdkit import Chem |
|
from rdkit.Chem import Draw |
|
from transformers import pipeline |
|
import gradio as gr |
|
from rdkit import Chem |
|
from rdkit.Chem import Draw |
|
from rdkit.Chem.Draw import SimilarityMaps |
|
import io |
|
from PIL import Image |
|
import numpy as np |
|
import rdkit |
|
|
|
from transformers import AutoModelForSequenceClassification, AutoTokenizer |
|
from transformers_interpret import SequenceClassificationExplainer |
|
|
|
# Hugging Face checkpoint that provides both the classifier weights and the tokenizer.
model_name = "FartLabs/FART_Augmented"

# The two loads do not depend on each other; both resolve the same checkpoint.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)

# Explainer object; calling it on a text yields (token, score) attribution pairs.
cls_explainer = SequenceClassificationExplainer(model=model, tokenizer=tokenizer)
|
|
|
def save_high_quality_png(smiles, title, bw=True, padding=0.05):
    """
    Render a token-attribution heat map for a molecule and save it as a PNG.

    The SMILES string is passed to the module-level ``cls_explainer``; the
    resulting per-token scores are drawn as a similarity map over a
    1500x1500 depiction of the molecule.

    Parameters:
    - smiles (str): The SMILES string of the molecule to visualize.
    - title (str): Output path without extension; the image is written to
      ``f"{title}.png"``.
    - bw (bool): If True, renders the molecule in black and white.
    - padding (float): Padding for molecule drawing.

    Returns:
    - None

    Raises:
    - ValueError: If ``smiles`` cannot be parsed into a molecule.
    """
    molecule = Chem.MolFromSmiles(smiles)
    if molecule is None:
        # Fail with a clear message instead of handing ``None`` to the
        # depiction code.
        raise ValueError(f"Invalid SMILES: {smiles!r}")
    Chem.rdDepictor.Compute2DCoords(molecule)

    # Keep only the scores of purely alphabetic tokens, which are treated as
    # the atoms of the molecule, in order of appearance.
    # NOTE(review): this assumes one alphabetic token per atom; tokens such as
    # bracket atoms would be skipped and shift the mapping - confirm against
    # the tokenizer's vocabulary.
    token_importance = cls_explainer(smiles)
    atom_importance = [score for token, score in token_importance if token.isalpha()]

    # Force exactly one weight per atom: drop surplus scores and zero-fill
    # missing ones so the weight list never ends up shorter than the atom list.
    num_atoms = molecule.GetNumAtoms()
    atom_importance = atom_importance[:num_atoms]
    atom_importance += [0.0] * (num_atoms - len(atom_importance))

    d = Draw.MolDraw2DCairo(1500, 1500)

    dopts = d.drawOptions()
    dopts.padding = padding
    dopts.maxFontSize = 2000
    dopts.bondLineWidth = 5
    if bw:
        dopts.useBWAtomPalette()

    SimilarityMaps.GetSimilarityMapFromWeights(molecule, atom_importance, draw2d=d)
    d.FinishDrawing()

    with open(f"{title}.png", "wb") as png_file:
        png_file.write(d.GetDrawingText())

    return None
|
|
|
# Checkpoint served by the text-classification pipeline.
model_checkpoint = "FartLabs/FART_Augmented"

# top_k=None asks the pipeline for a score per label instead of only the best one;
# process_smiles turns those label/score entries into a dictionary.
classifier = pipeline(
    task="text-classification",
    model=model_checkpoint,
    top_k=None,
)
|
|
|
def process_smiles(smiles, compute_explanation):
    """
    Classify a SMILES string and produce an image of the molecule.

    Parameters:
    - smiles (str): SMILES string entered by the user.
    - compute_explanation (bool): If True, the image is the attribution map
      written by ``save_high_quality_png``; otherwise a plain depiction.

    Returns:
    - tuple: ``(scores, image_path, canonical_smiles)`` where ``scores`` maps
      each label to its score. For blank or unparseable input the tuple is
      ``("Invalid SMILES", None, "Invalid SMILES")``.
    """
    invalid = ("Invalid SMILES", None, "Invalid SMILES")

    # Reject blank input up front: RDKit parses an empty string into an
    # atom-less molecule rather than None, which would otherwise be classified.
    if not smiles or not smiles.strip():
        return invalid

    mol = Chem.MolFromSmiles(smiles)
    if mol is None or mol.GetNumAtoms() == 0:
        return invalid
    canonical_smiles = Chem.MolToSmiles(mol)

    predictions = classifier(canonical_smiles)
    prediction_dict = {pred["label"]: pred["score"] for pred in predictions[0]}

    filepath = "molecule.png"
    if compute_explanation:
        # Explain the same string the classifier scored, so the attribution
        # map matches the displayed probabilities. The helper appends ".png"
        # to the name it is given.
        save_high_quality_png(canonical_smiles, "molecule")
    else:
        Draw.MolToImage(mol).save(filepath)

    return prediction_dict, filepath, canonical_smiles
|
|
|
# Input widgets, listed in the positional order of process_smiles' parameters.
_smiles_box = gr.Textbox(
    label="Input SMILES",
    value="O1[C@H](CO)[C@@H](O)[C@H](O)[C@@H](O)[C@H]1O[C@@]2(O[C@@H]([C@@H](O)[C@@H]2O)CO)CO",
)
_explain_toggle = gr.Checkbox(label="Compute Explanation (Takes 60s)", value=False)

# Output widgets, listed in the order of process_smiles' return tuple.
_probabilities_view = gr.Label(num_top_classes=3, label="Classification Probabilities")
_image_view = gr.Image(type="filepath", label="Molecule Image")
_canonical_box = gr.Textbox(label="Canonical SMILES")

# Wire the widgets to the processing function.
iface = gr.Interface(
    fn=process_smiles,
    inputs=[_smiles_box, _explain_toggle],
    outputs=[_probabilities_view, _image_view, _canonical_box],
    title="FART",
    description="Enter a SMILES string to get the taste classification probabilities.",
)

# Start the web app as soon as the module is executed.
iface.launch()
|
|