File size: 1,261 Bytes
8087713
 
 
5b5221b
 
8087713
a117e93
8087713
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a117e93
8087713
a117e93
e8168dc
a117e93
8087713
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import os
import tempfile

import gradio as gr
from rdkit import Chem
from rdkit.Chem import Draw
from transformers import pipeline

# Hugging Face Hub identifier of the FART taste-classification model.
model_checkpoint = "yzimmermann/FART"

# Text-classification pipeline over SMILES strings. top_k=None asks the
# pipeline for the score of every label rather than only the best one, so
# the UI can show the full probability distribution.
classifier = pipeline(
    task="text-classification",
    model=model_checkpoint,
    top_k=None,
)

def process_smiles(smiles):
    """Validate a SMILES string, classify it, and render the molecule.

    Args:
        smiles: SMILES string entered by the user. ``None`` and blank
            (empty or whitespace-only) strings are treated as invalid.

    Returns:
        A 3-tuple ``(predictions, image_path, canonical_smiles)``:

        * ``predictions`` -- dict mapping each label produced by
          ``classifier`` to its score, or the string ``"Invalid SMILES"``.
        * ``image_path`` -- path of a PNG rendering of the molecule, or
          ``None`` for invalid input.
        * ``canonical_smiles`` -- RDKit canonical SMILES, or
          ``"Invalid SMILES"`` for invalid input.
    """
    invalid = ("Invalid SMILES", None, "Invalid SMILES")

    # A cleared textbox may arrive as None; surrounding whitespace is never
    # part of a single SMILES string.
    smiles = (smiles or "").strip()
    if not smiles:
        return invalid

    # Validate and canonicalize SMILES. MolFromSmiles returns None on parse
    # failure; also reject a molecule with no atoms, which has nothing to
    # classify or draw.
    mol = Chem.MolFromSmiles(smiles)
    if mol is None or mol.GetNumAtoms() == 0:
        return invalid
    canonical_smiles = Chem.MolToSmiles(mol)

    # Predict using the pipeline. For a single input, some transformers
    # versions wrap the per-label list in an outer list ([[{...}, ...]]),
    # others return the per-label list directly ([{...}, ...]).
    predictions = classifier(canonical_smiles)
    if predictions and isinstance(predictions[0], list):
        scores = predictions[0]
    else:
        scores = predictions

    # Render to a unique temporary PNG instead of a fixed "molecule.png" in
    # the working directory: requests cannot overwrite each other's image,
    # and the app keeps working when the CWD is not writable.
    # NOTE(review): these temp files are not deleted here; add cleanup if
    # the app is long-running.
    img = Draw.MolToImage(mol)
    fd, img_path = tempfile.mkstemp(suffix=".png")
    os.close(fd)  # PIL opens the path itself; don't leak the descriptor.
    img.save(img_path)

    # Convert predictions to the {label: score} mapping gr.Label expects.
    prediction_dict = {pred["label"]: pred["score"] for pred in scores}

    return prediction_dict, img_path, canonical_smiles

# Gradio UI: a single SMILES textbox in, and three outputs that line up
# positionally with the tuple returned by process_smiles
# (label scores, image file path, canonical SMILES).
iface = gr.Interface(
    fn=process_smiles,
    inputs=gr.Textbox(label="Input SMILES"),
    outputs=[
        # Shows at most the three highest-scoring labels.
        gr.Label(num_top_classes=3, label="Classification Probabilities"),
        gr.Image(type="filepath", label="Molecule Image"),
        gr.Textbox(label="Canonical SMILES"),
    ],
    title="FART",
    description="Enter a SMILES string to get the taste classification probabilities.",
)

# Start the web server only when this file is run as a script, so importing
# the module (e.g. from tests or tooling) does not launch it as a side effect.
if __name__ == "__main__":
    iface.launch()