yzimmermann committed on
Commit
d781678
·
verified ·
1 Parent(s): 7d60e00

Create interpretability.py

Browse files
Files changed (1) hide show
  1. interpretability.py +163 -0
interpretability.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from rdkit import Chem
2
+ from rdkit.Chem import Draw
3
+ from rdkit.Chem.Draw import SimilarityMaps
4
+ from IPython.display import SVG
5
+ import io
6
+ from PIL import Image
7
+ import numpy as np
8
+ import rdkit
9
+
10
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
11
+ from transformers_interpret import SequenceClassificationExplainer
12
+ from transformers import pipeline
13
+
14
# Checkpoint identifier passed to every `from_pretrained` / `pipeline` call below.
model_name = "FartLabs/FART_Augmented"
model = AutoModelForSequenceClassification.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Explainer built on the model and tokenizer loaded above; the visualization
# functions in this file call it with a SMILES string to get per-token scores.
cls_explainer = SequenceClassificationExplainer(model, tokenizer)

# Text-classification pipeline created from the checkpoint name (not from the
# `model` object above); get_taste_from_smiles uses it for the prediction.
pipe = pipeline("text-classification", model=model_name)
21
+
22
def get_taste_from_smiles(smiles):
    """
    Classify a SMILES string and describe the top prediction as text.

    Parameters:
    - smiles (str): The SMILES string to classify with the module-level pipeline.

    Returns:
    - str: A string of the form "<TASTE> score: <score>", with the score
      formatted to two decimal places.
    """
    # Taste names, positioned by the numeric suffix of the pipeline's label.
    taste_by_index = ('BITTER', 'SOUR', 'SWEET', 'UMAMI', 'UNDEFINED')

    # Only the first entry of the pipeline output is used.
    top_prediction = pipe(smiles)[0]

    # The label is split on "_" and its second piece is read as the class index.
    raw_label = top_prediction['label']
    class_index = int(raw_label.split('_')[1])
    confidence = top_prediction['score']

    return f"{taste_by_index[class_index]} score: {confidence:.2f}"
42
+
43
def calculate_aspect_ratio(molecule, base_size):
    """
    Calculates the canvas width and height based on the molecule's aspect ratio.

    The shorter side of the canvas is always `base_size`; the longer side is
    scaled up by the ratio of the molecule's x-extent to its y-extent.

    Parameters:
    - molecule (Mol): RDKit molecule object with a conformer holding its coordinates.
    - base_size (int): The base size of the canvas, typically 400.

    Returns:
    - (int, int): Calculated width and height for the canvas. A square
      (base_size, base_size) canvas is returned when the molecule has no atoms
      or has zero extent along either axis.
    """
    num_atoms = molecule.GetNumAtoms()
    if num_atoms == 0:
        # Nothing to measure; max()/min() on an empty list would raise.
        return base_size, base_size

    conf = molecule.GetConformer()
    atom_positions = [conf.GetAtomPosition(i) for i in range(num_atoms)]
    x_coords = [pos.x for pos in atom_positions]
    y_coords = [pos.y for pos in atom_positions]
    width = max(x_coords) - min(x_coords)
    height = max(y_coords) - min(y_coords)

    # A degenerate extent (single atom, or all atoms on one vertical or
    # horizontal line) has no meaningful ratio. Falling back to a square canvas
    # also avoids dividing by a zero aspect ratio below.
    if width <= 0 or height <= 0:
        return base_size, base_size

    aspect_ratio = width / height
    if aspect_ratio > 1:
        # Wider than tall: stretch the width, keep the height at base_size.
        return max(base_size, int(base_size * aspect_ratio)), base_size
    if aspect_ratio < 1:
        # Taller than wide: stretch the height, keep the width at base_size.
        return base_size, max(base_size, int(base_size / aspect_ratio))
    return base_size, base_size
66
+
67
def visualize_gradients(smiles, bw=True, padding=0.05):
    """
    Visualizes atom-wise gradients or importance scores for a given molecule
    based on the SMILES representation as a similarity map.

    The predicted taste (see get_taste_from_smiles) is printed before drawing.

    Parameters:
    - smiles (str): The SMILES string of the molecule to visualize.
    - bw (bool): If True, renders the molecule in black and white (default is True).
    - padding (float): Padding assigned to the drawing options (default is 0.05).

    Returns:
    - PIL.Image.Image: The rendered similarity map, decoded from the PNG bytes
      produced by the RDKit drawer.

    Raises:
    - ValueError: If `smiles` cannot be parsed into an RDKit molecule.
    """

    print(get_taste_from_smiles(smiles))

    # Convert SMILES string to RDKit molecule object
    molecule = Chem.MolFromSmiles(smiles)
    if molecule is None:
        # MolFromSmiles returns None for unparseable input; fail with a clear
        # message instead of an opaque error further down.
        raise ValueError(f"Could not parse SMILES string: {smiles!r}")
    Chem.rdDepictor.Compute2DCoords(molecule)

    # Set up canvas size based on aspect ratio
    base_size = 400
    width, height = calculate_aspect_ratio(molecule, base_size)
    d = Draw.MolDraw2DCairo(width, height)
    draw_options = d.drawOptions()
    draw_options.padding = padding

    # Optionally set black and white palette
    if bw:
        draw_options.useBWAtomPalette()

    # Get token importance scores and map to atoms: keep the scores of purely
    # alphabetic tokens, in order, and truncate to the number of atoms.
    # NOTE(review): this assumes alphabetic tokens correspond one-to-one, in
    # order, with the molecule's atoms -- confirm against the tokenizer.
    token_importance = cls_explainer(smiles)
    atom_importance = [c[1] for c in token_importance if c[0].isalpha()]
    num_atoms = molecule.GetNumAtoms()
    atom_importance = atom_importance[:num_atoms]

    # Generate a similarity map based on atom importance scores
    SimilarityMaps.GetSimilarityMapFromWeights(molecule, atom_importance, draw2d=d)

    # Convert drawing to an image. `Image` is the PIL.Image module, which is not
    # callable, so the PNG bytes are decoded with Image.open.
    d.FinishDrawing()
    png_data = d.GetDrawingText()
    img = Image.open(io.BytesIO(png_data))
    return img
111
+
112
def save_high_quality_png(smiles, title, bw=True, padding=0.05):
    """
    Generates a high-quality PNG of atom-wise gradients or importance scores for a molecule.

    The image is drawn on a 1500x1500 canvas and written to "<title>.png".

    Parameters:
    - smiles (str): The SMILES string of the molecule to visualize.
    - title (str): Output path without extension; ".png" is appended.
    - bw (bool): If True, renders the molecule in black and white.
    - padding (float): Padding for molecule drawing.

    Returns:
    - PIL.Image.Image: The rendered similarity map, decoded from the same PNG
      bytes that were written to disk.

    Raises:
    - ValueError: If `smiles` cannot be parsed into an RDKit molecule.
    """

    # Convert SMILES string to RDKit molecule object
    molecule = Chem.MolFromSmiles(smiles)
    if molecule is None:
        # MolFromSmiles returns None for unparseable input; fail with a clear
        # message instead of an opaque error further down.
        raise ValueError(f"Could not parse SMILES string: {smiles!r}")
    Chem.rdDepictor.Compute2DCoords(molecule)

    # Get token importance scores and map to atoms: keep the scores of purely
    # alphabetic tokens, in order, and truncate to the number of atoms.
    # NOTE(review): this assumes alphabetic tokens correspond one-to-one, in
    # order, with the molecule's atoms -- confirm against the tokenizer.
    token_importance = cls_explainer(smiles)
    atom_importance = [c[1] for c in token_importance if c[0].isalpha()]
    num_atoms = molecule.GetNumAtoms()
    atom_importance = atom_importance[:num_atoms]

    # Set a large canvas size for high resolution
    d = Draw.MolDraw2DCairo(1500, 1500)

    dopts = d.drawOptions()
    dopts.padding = padding
    dopts.maxFontSize = 2000
    dopts.bondLineWidth = 5

    # Optionally set black and white palette
    if bw:
        dopts.useBWAtomPalette()

    # Generate a similarity map based on atom importance scores
    SimilarityMaps.GetSimilarityMapFromWeights(molecule, atom_importance, draw2d=d)

    # Finish the drawing once and reuse the resulting bytes for both the file
    # and the returned image.
    d.FinishDrawing()
    png_data = d.GetDrawingText()

    # Save to PNG file with high quality
    with open(f"{title}.png", "wb") as png_file:
        png_file.write(png_data)

    print(f"High-quality PNG file saved as {title}.png")

    # `Image` is the PIL.Image module, which is not callable, so the PNG bytes
    # are decoded with Image.open.
    img = Image.open(io.BytesIO(png_data))
    return img