File size: 1,360 Bytes
8087713 5b5221b 7d60e00 a117e93 8087713 b4077e0 8087713 a117e93 e8168dc a117e93 8087713 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 |
import os
import tempfile

import gradio as gr
from rdkit import Chem
from rdkit.Chem import Draw
from transformers import pipeline
# Hugging Face Hub checkpoint used for the taste classification shown in the UI.
model_checkpoint = "FartLabs/FART_Augmented"

# top_k=None asks the pipeline to report a score for every label instead of
# only the single best one, so the UI can display the full distribution.
classifier = pipeline(
    task="text-classification",
    model=model_checkpoint,
    top_k=None,
)
_INVALID_RESULT = ("Invalid SMILES", None, "Invalid SMILES")


def _predictions_to_dict(predictions):
    """Flatten text-classification pipeline output into a ``{label: score}`` dict.

    Depending on the installed ``transformers`` version and on how ``top_k``
    was supplied, a single-string call may return either a flat list of
    ``{"label": ..., "score": ...}`` dicts or that list wrapped in an outer
    list. Both shapes are accepted here.
    """
    if predictions and isinstance(predictions[0], list):
        predictions = predictions[0]
    return {pred["label"]: pred["score"] for pred in predictions}


def _render_molecule(mol):
    """Draw *mol* into a fresh temporary PNG file and return the file path.

    A unique file per call (rather than one fixed ``molecule.png`` in the
    working directory) stops concurrent requests from overwriting each
    other's image and does not require a writable working directory.
    The file is intentionally left on disk so Gradio can serve it.
    """
    fd, img_path = tempfile.mkstemp(suffix=".png")
    os.close(fd)  # PIL reopens the path itself; don't leak the raw descriptor.
    Draw.MolToImage(mol).save(img_path)
    return img_path


def process_smiles(smiles):
    """Classify a SMILES string and render a picture of the molecule.

    Args:
        smiles: SMILES text from the input textbox. Surrounding whitespace is
            ignored; ``None`` or blank input is treated as invalid.

    Returns:
        A 3-tuple ``(predictions, image_path, canonical_smiles)``.
        ``predictions`` maps each classifier label to its score.
        ``image_path`` is the path of a PNG drawing of the molecule.
        ``canonical_smiles`` is RDKit's canonical form of the input.
        For blank or unparseable input the tuple is
        ``("Invalid SMILES", None, "Invalid SMILES")``.
    """
    smiles = (smiles or "").strip()
    # RDKit parses "" into a zero-atom molecule rather than returning None,
    # so blank input has to be rejected explicitly before parsing.
    mol = Chem.MolFromSmiles(smiles) if smiles else None
    if mol is None:
        return _INVALID_RESULT

    # Classify the canonical form so equivalent spellings of the same
    # molecule are fed to the model identically.
    canonical_smiles = Chem.MolToSmiles(mol)
    predictions = classifier(canonical_smiles)

    return _predictions_to_dict(predictions), _render_molecule(mol), canonical_smiles
# Example molecule pre-filled in the input box so the demo works on first load.
_EXAMPLE_SMILES = "O1[C@H](CO)[C@@H](O)[C@H](O)[C@@H](O)[C@H]1O[C@@]2(O[C@@H]([C@@H](O)[C@@H]2O)CO)CO"

# Web UI: one SMILES textbox in; the label scores (top 3 shown), a molecule
# image file path and the canonical SMILES out, matching the 3-tuple that
# process_smiles returns.
iface = gr.Interface(
    fn=process_smiles,
    inputs=gr.Textbox(label="Input SMILES", value=_EXAMPLE_SMILES),
    outputs=[
        gr.Label(num_top_classes=3, label="Classification Probabilities"),
        gr.Image(type="filepath", label="Molecule Image"),
        gr.Textbox(label="Canonical SMILES"),
    ],
    title="FART",
    description="Enter a SMILES string to get the taste classification probabilities.",
)

iface.launch()
|