|
import tempfile

import gradio as gr
from rdkit import Chem
from rdkit.Chem import Draw
from transformers import pipeline
|
|
|
# Hugging Face Hub identifier of the FART taste-classification model.
model_checkpoint = "yzimmermann/FART"

# Built once at import time and reused by every request. top_k=None asks the
# pipeline for a score for every label instead of only the best one.
classifier = pipeline(
    task="text-classification",
    model=model_checkpoint,
    top_k=None,
)
|
|
|
def _render_molecule_png(mol):
    """Draw ``mol`` with RDKit and save it to a fresh temporary PNG file.

    A unique file is created per call so concurrent requests cannot overwrite
    each other's image (a single shared ``molecule.png`` would race).

    Returns:
        The filesystem path of the written PNG.
    """
    image = Draw.MolToImage(mol)
    # delete=False: the file must outlive this function so Gradio can read it.
    # Saving through the open handle (not the name) also works on Windows.
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
        image.save(tmp, format="PNG")
    return tmp.name


def _scores_by_label(predictions):
    """Flatten text-classification pipeline output into ``{label: score}``.

    Depending on the installed ``transformers`` version, a single-string call
    with ``top_k=None`` yields either ``[[{...}, ...]]`` (legacy extra nesting)
    or ``[{...}, ...]``; both shapes are accepted.
    """
    if predictions and isinstance(predictions[0], list):
        predictions = predictions[0]
    return {pred["label"]: pred["score"] for pred in predictions}


def process_smiles(smiles):
    """Validate a SMILES string, classify it, and render the molecule.

    Args:
        smiles: SMILES text entered by the user (may be empty or ``None``).

    Returns:
        A 3-tuple ``(scores, image_path, canonical_smiles)`` in the order the
        Gradio outputs expect: ``scores`` maps each class label to its score,
        ``image_path`` is the path of a PNG depiction of the molecule, and
        ``canonical_smiles`` is RDKit's canonical SMILES. For empty or
        unparseable input, returns ``("Invalid SMILES", None, "Invalid SMILES")``.
    """
    # Pasted input often carries surrounding whitespace; None comes from a cleared box.
    smiles = (smiles or "").strip()
    mol = Chem.MolFromSmiles(smiles) if smiles else None
    # MolFromSmiles returns None on a parse failure, but an empty string would
    # parse to a molecule with zero atoms -- reject both.
    if mol is None or mol.GetNumAtoms() == 0:
        return "Invalid SMILES", None, "Invalid SMILES"

    canonical_smiles = Chem.MolToSmiles(mol)
    # Classify the canonical form so different spellings of the same molecule
    # produce identical model input.
    predictions = classifier(canonical_smiles)
    image_path = _render_molecule_png(mol)

    return _scores_by_label(predictions), image_path, canonical_smiles
|
|
|
# Web UI: one SMILES textbox in; class scores, a depiction of the molecule,
# and its canonical SMILES out -- in the same order process_smiles returns them.
# The label widget shows only the three highest-scoring classes.
iface = gr.Interface(
    title="FART",
    description="Enter a SMILES string to get the taste classification probabilities.",
    fn=process_smiles,
    inputs=gr.Textbox(label="Input SMILES"),
    outputs=[
        gr.Label(num_top_classes=3, label="Classification Probabilities"),
        gr.Image(type="filepath", label="Molecule Image"),
        gr.Textbox(label="Canonical SMILES"),
    ],
)

iface.launch()
|
|