File size: 1,261 Bytes
8087713 5b5221b 8087713 a117e93 8087713 a117e93 8087713 a117e93 e8168dc a117e93 8087713 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 |
import tempfile

import gradio as gr
from rdkit import Chem
from rdkit.Chem import Draw
from transformers import pipeline
# Hugging Face Hub model id that the classification pipeline loads.
model_checkpoint = "yzimmermann/FART"

# Text-classification pipeline shared by every request. Passing top_k=None
# asks the pipeline for a score for every label rather than only the best one.
classifier = pipeline(
    task="text-classification",
    model=model_checkpoint,
    top_k=None,
)
def process_smiles(smiles):
    """Validate a SMILES string, classify it and render the molecule.

    Args:
        smiles: SMILES text entered by the user. Surrounding whitespace
            (e.g. a trailing newline from copy-paste) is ignored; blank or
            ``None`` input is treated as invalid.

    Returns:
        A 3-tuple ``(predictions, image_path, canonical_smiles)`` matching the
        three Gradio output components:
        ``predictions`` maps each label to its score, ``image_path`` is the
        path of a PNG depiction of the molecule, and ``canonical_smiles`` is
        RDKit's canonical SMILES for the input. For input RDKit cannot parse,
        the tuple is ``("Invalid SMILES", None, "Invalid SMILES")``.
    """
    smiles = (smiles or "").strip()
    # RDKit parses "" into an empty molecule instead of returning None, so
    # blank input has to be rejected explicitly.
    mol = Chem.MolFromSmiles(smiles) if smiles else None
    if mol is None:
        return "Invalid SMILES", None, "Invalid SMILES"
    # Classify the canonical form so equivalent SMILES spellings are fed to
    # the model identically.
    canonical_smiles = Chem.MolToSmiles(mol)

    predictions = classifier(canonical_smiles)
    # Depending on the transformers version, a single-string call may return
    # [[{"label", "score"}, ...]] or [{"label", "score"}, ...]; accept both.
    if predictions and isinstance(predictions[0], list):
        predictions = predictions[0]
    prediction_dict = {pred["label"]: pred["score"] for pred in predictions}

    # Give each request its own image file: a fixed "molecule.png" in the
    # working directory is shared by all requests, so under concurrent
    # processing one user could be served another user's molecule.
    # NOTE(review): delete=False is required because Gradio reads the file
    # after this function returns; the temp files are not cleaned up here.
    img = Draw.MolToImage(mol)
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
        img.save(tmp, format="PNG")
    return prediction_dict, tmp.name, canonical_smiles
def _make_interface():
    """Assemble the Gradio UI wiring a SMILES textbox to ``process_smiles``.

    The output components line up positionally with the tuple returned by
    ``process_smiles``: label scores (top 3 shown), molecule image path,
    canonical SMILES.
    """
    # Create the input component before the outputs, matching the original
    # construction order.
    smiles_box = gr.Textbox(label="Input SMILES")
    result_widgets = [
        gr.Label(num_top_classes=3, label="Classification Probabilities"),
        gr.Image(type="filepath", label="Molecule Image"),
        gr.Textbox(label="Canonical SMILES"),
    ]
    return gr.Interface(
        fn=process_smiles,
        inputs=smiles_box,
        outputs=result_widgets,
        title="FART",
        description="Enter a SMILES string to get the taste classification probabilities.",
    )


iface = _make_interface()
iface.launch()
|