yzimmermann committed on
Commit
59b4b1e
·
verified ·
1 Parent(s): ee0536e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +76 -6
app.py CHANGED
@@ -2,12 +2,74 @@ from rdkit import Chem
2
  from rdkit.Chem import Draw
3
  from transformers import pipeline
4
  import gradio as gr
5
- from interpretability import save_high_quality_png
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
 
 
 
 
 
 
 
 
 
7
  model_checkpoint = "FartLabs/FART_Augmented"
8
  classifier = pipeline("text-classification", model=model_checkpoint, top_k=None)
9
 
10
- def process_smiles(smiles):
11
  # Validate and canonicalize SMILES
12
  mol = Chem.MolFromSmiles(smiles)
13
  if mol is None:
@@ -18,9 +80,14 @@ def process_smiles(smiles):
18
  predictions = classifier(canonical_smiles)
19
 
20
  # Generate molecule image
21
- img_path = "molecule"
22
- filepath= "molecule.png"
23
- save_high_quality_png(smiles, img_path)
 
 
 
 
 
24
 
25
  # Convert predictions to a friendly format
26
  prediction_dict = {pred["label"]: pred["score"] for pred in predictions[0]}
@@ -29,7 +96,10 @@ def process_smiles(smiles):
29
 
30
  iface = gr.Interface(
31
  fn=process_smiles,
32
- inputs=gr.Textbox(label="Input SMILES", value="O1[C@H](CO)[C@@H](O)[C@H](O)[C@@H](O)[C@H]1O[C@@]2(O[C@@H]([C@@H](O)[C@@H]2O)CO)CO"),
 
 
 
33
  outputs=[
34
  gr.Label(num_top_classes=3, label="Classification Probabilities"),
35
  gr.Image(type="filepath", label="Molecule Image"),
 
2
  from rdkit.Chem import Draw
3
  from transformers import pipeline
4
  import gradio as gr
5
+ from rdkit import Chem
6
+ from rdkit.Chem import Draw
7
+ from rdkit.Chem.Draw import SimilarityMaps
8
+ import io
9
+ from PIL import Image
10
+ import numpy as np
11
+ import rdkit
12
+
13
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
14
+ from transformers_interpret import SequenceClassificationExplainer
15
+
16
+ model_name = "FartLabs/FART_Augmented"
17
+ model = AutoModelForSequenceClassification.from_pretrained(model_name)
18
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
19
+
20
+ cls_explainer = SequenceClassificationExplainer(model, tokenizer)
21
+
22
def save_high_quality_png(smiles: str, title: str, bw: bool = True, padding: float = 0.05) -> None:
    """
    Render a molecule with a per-atom importance overlay and save it as a PNG.

    Token attributions are obtained by calling the module-level
    ``cls_explainer`` on the SMILES string; the scores of purely alphabetic
    tokens are used, in order, as per-atom weights for an RDKit similarity map.

    Parameters:
    - smiles (str): The SMILES string of the molecule to visualize.
    - title (str): Output path without extension; the image is written to
      ``f"{title}.png"``.
    - bw (bool): If True, renders the molecule with the black-and-white atom palette.
    - padding (float): Padding for molecule drawing.
    Returns:
    - None
    Raises:
    - ValueError: If ``smiles`` cannot be parsed by RDKit.
    """

    # Convert SMILES string to RDKit molecule object; MolFromSmiles returns
    # None on a parse failure, so fail early with a readable message.
    molecule = Chem.MolFromSmiles(smiles)
    if molecule is None:
        raise ValueError(f"Invalid SMILES string: {smiles!r}")
    Chem.rdDepictor.Compute2DCoords(molecule)

    # Get token importance scores and map to atoms.
    # NOTE(review): assumes alphabetic tokens correspond one-to-one, in order,
    # with the molecule's atoms -- confirm against the tokenizer's vocabulary.
    token_importance = cls_explainer(smiles)
    atom_importance = [entry[1] for entry in token_importance if entry[0].isalpha()]

    # Make the weight list exactly one entry per atom: drop any surplus and
    # pad a shortfall with neutral (zero) weights.
    num_atoms = molecule.GetNumAtoms()
    atom_importance = atom_importance[:num_atoms]
    atom_importance.extend([0.0] * (num_atoms - len(atom_importance)))

    # Set a large canvas size for high resolution
    d = Draw.MolDraw2DCairo(1500, 1500)

    dopts = d.drawOptions()
    dopts.padding = padding
    dopts.maxFontSize = 2000
    dopts.bondLineWidth = 5

    # Optionally set black and white palette
    if bw:
        dopts.useBWAtomPalette()

    # Draw the similarity map based on atom importance scores onto the canvas
    SimilarityMaps.GetSimilarityMapFromWeights(molecule, atom_importance, draw2d=d)

    # Finalize the drawing so the PNG bytes can be retrieved
    d.FinishDrawing()

    # Save to PNG file with high quality
    with open(f"{title}.png", "wb") as png_file:
        png_file.write(d.GetDrawingText())
68
+
69
  model_checkpoint = "FartLabs/FART_Augmented"
70
  classifier = pipeline("text-classification", model=model_checkpoint, top_k=None)
71
 
72
+ def process_smiles(smiles, compute_explanation):
73
  # Validate and canonicalize SMILES
74
  mol = Chem.MolFromSmiles(smiles)
75
  if mol is None:
 
80
  predictions = classifier(canonical_smiles)
81
 
82
  # Generate molecule image
83
+ if compute_explanation:
84
+ img_path = "molecule"
85
+ filepath= "molecule.png"
86
+ save_high_quality_png(smiles, img_path)
87
+ else:
88
+ filepath = "molecule.png"
89
+ img = Draw.MolToImage(mol)
90
+ img.save(filepath)
91
 
92
  # Convert predictions to a friendly format
93
  prediction_dict = {pred["label"]: pred["score"] for pred in predictions[0]}
 
96
 
97
  iface = gr.Interface(
98
  fn=process_smiles,
99
+ inputs=[
100
+ gr.Textbox(label="Input SMILES", value="O1[C@H](CO)[C@@H](O)[C@H](O)[C@@H](O)[C@H]1O[C@@]2(O[C@@H]([C@@H](O)[C@@H]2O)CO)CO"),
101
+ gr.Checkbox(label="Compute Explanation (Takes 60s)", value=False),
102
+ ],
103
  outputs=[
104
  gr.Label(num_top_classes=3, label="Classification Probabilities"),
105
  gr.Image(type="filepath", label="Molecule Image"),