File size: 4,880 Bytes
8087713
 
 
5b5221b
59b4b1e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5b5221b
59b4b1e
 
 
 
 
 
 
 
 
7d60e00
a117e93
8087713
59b4b1e
8087713
 
 
 
 
 
 
 
 
 
59b4b1e
 
 
 
 
 
 
 
8087713
 
 
 
b6c02bb
8087713
 
 
59b4b1e
 
023c5c8
59b4b1e
8087713
a117e93
e8168dc
a117e93
8087713
ca6b54a
 
 
 
6e03cfc
ca6b54a
 
25bd3fd
ca6b54a
 
6e03cfc
ca6b54a
 
 
8087713
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
from rdkit import Chem
from rdkit.Chem import Draw
from transformers import pipeline
import gradio as gr
from rdkit import Chem
from rdkit.Chem import Draw
from rdkit.Chem.Draw import SimilarityMaps
import io
from PIL import Image
import numpy as np
import rdkit

from transformers import AutoModelForSequenceClassification, AutoTokenizer
from transformers_interpret import SequenceClassificationExplainer

# Hugging Face checkpoint identifier for the model and tokenizer used by the explainer.
model_name = "FartLabs/FART_Augmented"
# Both objects are loaded once, at module import time.
model = AutoModelForSequenceClassification.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Explainer built from the model/tokenizer pair above; save_high_quality_png
# calls it with a SMILES string and reads (token, score) entries from the result.
cls_explainer = SequenceClassificationExplainer(model, tokenizer)

def save_high_quality_png(smiles, title, bw=True, padding=0.05):
    """
    Render a per-atom importance map for a molecule and save it as a PNG.

    The module-level ``cls_explainer`` is run on the SMILES string, the scores
    of its purely alphabetic tokens are used as per-atom weights, and the
    weights are drawn as a similarity map on a 1500x1500 Cairo canvas.

    Parameters:
    - smiles (str): The SMILES string of the molecule to visualize.
    - title (str): Output path without extension; the image is written to
      ``f"{title}.png"``.
    - bw (bool): If True, renders the atoms with the black-and-white palette.
    - padding (float): Padding for molecule drawing.
    Returns:
    - None
    Raises:
    - ValueError: If ``smiles`` cannot be parsed into an RDKit molecule.
    """

    # Convert SMILES string to RDKit molecule object; fail with a clear message
    # instead of passing None on to the depiction code.
    molecule = Chem.MolFromSmiles(smiles)
    if molecule is None:
        raise ValueError(f"Invalid SMILES: {smiles!r}")
    Chem.rdDepictor.Compute2DCoords(molecule)

    # Get token importance scores and map to atoms.
    # NOTE(review): only tokens for which str.isalpha() is true are kept, so a
    # token such as "[C@H]" would be skipped if the tokenizer emits bracket
    # atoms as single tokens - confirm the token/atom alignment.
    token_importance = cls_explainer(smiles)
    num_atoms = molecule.GetNumAtoms()
    atom_importance = [
        entry[1] for entry in token_importance if entry[0].isalpha()
    ][:num_atoms]
    # Pad with neutral weights so there is always exactly one weight per atom.
    atom_importance += [0.0] * (num_atoms - len(atom_importance))

    # Set a large canvas size for high resolution
    d = Draw.MolDraw2DCairo(1500, 1500)

    dopts = d.drawOptions()
    dopts.padding = padding
    dopts.maxFontSize = 2000
    dopts.bondLineWidth = 5

    # Optionally set black and white palette
    if bw:
        dopts.useBWAtomPalette()

    # Generate a similarity map based on atom importance scores
    SimilarityMaps.GetSimilarityMapFromWeights(molecule, atom_importance, draw2d=d)

    # Finalize the drawing so the PNG bytes can be retrieved
    d.FinishDrawing()

    # Save to PNG file with high quality
    with open(f"{title}.png", "wb") as png_file:
        png_file.write(d.GetDrawingText())
    
# Text-classification pipeline used by process_smiles for the predictions.
# top_k=None is passed; process_smiles reads a label/score entry per class
# from the first element of the pipeline output.
# NOTE(review): this is the same checkpoint string as model_name, so the
# weights are presumably loaded a second time here - confirm whether the
# already-loaded model/tokenizer could be reused.
model_checkpoint = "FartLabs/FART_Augmented"
classifier = pipeline("text-classification", model=model_checkpoint, top_k=None)

def process_smiles(smiles, compute_explanation):
    """
    Classify a SMILES string and render an image of the molecule.

    Parameters:
    - smiles (str): SMILES string entered by the user.
    - compute_explanation (bool): If True, draw the importance map produced by
      ``save_high_quality_png``; otherwise draw the plain molecule.
    Returns:
    - tuple: ``(predictions, image_path, canonical_smiles)`` where
      ``predictions`` maps each label to its score. For blank or unparsable
      input the tuple is ``("Invalid SMILES", None, "Invalid SMILES")``.
    """
    # Reject blank input up front: RDKit parses an empty string into an empty
    # molecule rather than returning None, so it would slip past the check below.
    if smiles is None or not smiles.strip():
        return "Invalid SMILES", None, "Invalid SMILES"

    # Validate and canonicalize SMILES
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return "Invalid SMILES", None, "Invalid SMILES"
    canonical_smiles = Chem.MolToSmiles(mol)

    # Predict using the pipeline
    predictions = classifier(canonical_smiles)

    # Generate molecule image
    filepath = "molecule.png"
    # Similarity maps need at least two atoms; smaller molecules fall back to
    # the plain drawing instead of failing.
    if compute_explanation and mol.GetNumAtoms() >= 2:
        # Explain the same string that was classified so the picture matches
        # the reported probabilities. save_high_quality_png appends ".png".
        save_high_quality_png(canonical_smiles, "molecule")
    else:
        Draw.MolToImage(mol).save(filepath)

    # Convert predictions to a friendly format
    prediction_dict = {pred["label"]: pred["score"] for pred in predictions[0]}

    return prediction_dict, filepath, canonical_smiles

# Gradio UI: wires process_smiles to two inputs (SMILES text, explanation
# toggle) and three outputs (label scores, image file path, canonical SMILES).
iface = gr.Interface(
    fn=process_smiles,
    inputs=[
        # Pre-filled example SMILES; passed to process_smiles as `smiles`.
        gr.Textbox(label="Input SMILES", value="O1[C@H](CO)[C@@H](O)[C@H](O)[C@@H](O)[C@H]1O[C@@]2(O[C@@H]([C@@H](O)[C@@H]2O)CO)CO"),
        # Off by default; passed to process_smiles as `compute_explanation`.
        gr.Checkbox(label="Display explanation (can take some time)", value=False),
    ],
    outputs=[
        # Receives the label -> score dict (or the "Invalid SMILES" string).
        gr.Label(num_top_classes=3, label="Classification Probabilities"),
        # Receives the path of the PNG written by process_smiles (or None).
        gr.Image(type="filepath", label="Molecule Image"),
        gr.Textbox(label="Canonical SMILES")
    ],
    # HTML shown with the interface; this is a runtime string.
    description="""
<section id="molecular-taste-description">
  <h2>Discover Molecular Taste with FART</h2>
  <p>
    At Kvant AI Labs, we just revolutionized taste chemistry with FART (Flavor Analysis and Recognition Transformer), an AI-powered tool designed to predict molecular taste from chemical structure alone. FART delivers predictions for <strong>sweet</strong>, <strong>bitter</strong>, <strong>sour</strong>, and <strong>umami</strong> with over 91% accuracy.
  </p>
  <p>
    Beyond predictions, FART identifies the molecular features driving taste characteristics, enabling actionable insights for flavor innovation. Powered by the ChemBERTa foundation model and trained on the largest molecular taste dataset to date, FART sets a new standard in food science.
  </p>
  <p>
    Learn more about the science behind FART in our <a href="https://chemrxiv.org/engage/chemrxiv/article-details/673a2a3af9980725cf80503c" target="_blank">Pre-print</a>. To generate SMILES, one possible option is this <a href="https://www.cheminfo.org/flavor/malaria/Utilities/SMILES_generator___checker/index.html" target="_blank">tool</a>.
  </p>
</section>
""",
)

# Start the Gradio server; this runs at module level, so importing the file launches the app.
iface.launch()