# FART / app.py
# Last commit 7d60e00 ("Update app.py") by yzimmermann; file size 1.36 kB.
import tempfile

from rdkit import Chem
from rdkit.Chem import Draw
from transformers import pipeline
import gradio as gr
# Hugging Face Hub checkpoint used for taste classification.
model_checkpoint = "FartLabs/FART_Augmented"

# Text-classification pipeline over SMILES strings; top_k=None asks the
# pipeline for a score for every label rather than only the best one.
classifier = pipeline(
    "text-classification",
    model=model_checkpoint,
    top_k=None,
)
def process_smiles(smiles):
    """Classify the taste of a molecule given as a SMILES string.

    The input is parsed and canonicalized with RDKit. The canonical SMILES is
    scored by the module-level ``classifier`` pipeline, and a 2D depiction of
    the molecule is rendered to a PNG file.

    Args:
        smiles: SMILES string entered by the user. Surrounding whitespace is
            ignored. ``None`` or a blank string is treated as invalid.

    Returns:
        A 3-tuple ``(predictions, image_path, canonical_smiles)`` matching the
        three Gradio outputs. ``predictions`` maps each label to its score,
        and ``image_path`` is the path of the rendered PNG. For unparseable
        or empty input the tuple is ``("Invalid SMILES", None, "Invalid SMILES")``.
    """
    invalid = ("Invalid SMILES", None, "Invalid SMILES")

    smiles = (smiles or "").strip()
    if not smiles:
        return invalid

    # Validate and canonicalize. MolFromSmiles returns None on a parse error.
    # An atom-less molecule is rejected too, because it would otherwise be
    # classified and drawn as if it were valid.
    mol = Chem.MolFromSmiles(smiles)
    if mol is None or mol.GetNumAtoms() == 0:
        return invalid
    canonical_smiles = Chem.MolToSmiles(mol)

    # Predict using the pipeline. For a single string input the pipeline may
    # wrap the per-label list in an outer list, so unwrap only when it is there.
    predictions = classifier(canonical_smiles)
    if predictions and isinstance(predictions[0], list):
        predictions = predictions[0]
    prediction_dict = {pred["label"]: pred["score"] for pred in predictions}

    # Render to a unique temp file. A fixed "molecule.png" in the working
    # directory is shared by all concurrent requests, so one user could be
    # shown another user's molecule. delete=False keeps the file on disk
    # after this function returns, because Gradio reads the path afterwards.
    img = Draw.MolToImage(mol)
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
        img.save(tmp, format="PNG")
        img_path = tmp.name

    return prediction_dict, img_path, canonical_smiles
# Input box, pre-filled with an example SMILES so the demo runs immediately.
smiles_input = gr.Textbox(
    label="Input SMILES",
    value="O1[C@H](CO)[C@@H](O)[C@H](O)[C@@H](O)[C@H]1O[C@@]2(O[C@@H]([C@@H](O)[C@@H]2O)CO)CO",
)

# One output component per element of the tuple returned by process_smiles:
# label scores, path to the rendered image, and the canonical SMILES text.
result_outputs = [
    gr.Label(num_top_classes=3, label="Classification Probabilities"),
    gr.Image(type="filepath", label="Molecule Image"),
    gr.Textbox(label="Canonical SMILES"),
]

iface = gr.Interface(
    fn=process_smiles,
    inputs=smiles_input,
    outputs=result_outputs,
    title="FART",
    description="Enter a SMILES string to get the taste classification probabilities.",
)
iface.launch()