yzimmermann commited on
Commit
9c5a50f
·
verified ·
1 Parent(s): 4ad1832

Update interpretability.py

Browse files
Files changed (1) hide show
  1. interpretability.py +0 -92
interpretability.py CHANGED
@@ -17,98 +17,6 @@ tokenizer = AutoTokenizer.from_pretrained(model_name)
17
 
18
  cls_explainer = SequenceClassificationExplainer(model, tokenizer)
19
 
20
- pipe = pipeline("text-classification", model=model_name)
21
-
22
- def get_taste_from_smiles(smiles):
23
- # Original output
24
- output = pipe(smiles)
25
-
26
- # Mapping of labels to tastes
27
- taste_labels = ['BITTER', 'SOUR', 'SWEET', 'UMAMI', 'UNDEFINED']
28
-
29
- # Extract label and score
30
- label_info = output[0]
31
- label_index = int(label_info['label'].split('_')[1]) # Get the numeric part of the label
32
- score = label_info['score']
33
-
34
- # Reassign label
35
- new_label = taste_labels[label_index]
36
-
37
- # Format the title string
38
- title_string = f"{new_label} score: {score:.2f}"
39
-
40
- # Output the title string
41
- return title_string
42
-
43
- def calculate_aspect_ratio(molecule, base_size):
44
- """
45
- Calculates the canvas width and height based on the molecule's aspect ratio.
46
-
47
- Parameters:
48
- - molecule (Mol): RDKit molecule object.
49
- - base_size (int): The base size of the canvas, typically 400.
50
-
51
- Returns:
52
- - (int, int): Calculated width and height for the canvas.
53
- """
54
- conf = molecule.GetConformer()
55
- atom_positions = [conf.GetAtomPosition(i) for i in range(molecule.GetNumAtoms())]
56
- x_coords = [pos.x for pos in atom_positions]
57
- y_coords = [pos.y for pos in atom_positions]
58
- width = max(x_coords) - min(x_coords)
59
- height = max(y_coords) - min(y_coords)
60
- aspect_ratio = width / height if height > 0 else 1
61
-
62
- canvas_width = max(base_size, int(base_size * aspect_ratio)) if aspect_ratio > 1 else base_size
63
- canvas_height = max(base_size, int(base_size / aspect_ratio)) if aspect_ratio < 1 else base_size
64
-
65
- return canvas_width, canvas_height
66
-
67
- def visualize_gradients(smiles, bw=True, padding=0.05):
68
- """
69
- Visualizes atom-wise gradients or importance scores for a given molecule
70
- based on the SMILES representation as a similarity map.
71
-
72
- Parameters:
73
- - smiles (str): The SMILES string of the molecule to visualize.
74
- - bw (bool): If True, renders the molecule in black and white (default is False).
75
-
76
- Returns:
77
- - None: Displays the generated similarity map in the notebook.
78
- """
79
-
80
- print(get_taste_from_smiles(smiles))
81
-
82
- # Convert SMILES string to RDKit molecule object
83
- molecule = Chem.MolFromSmiles(smiles)
84
- Chem.rdDepictor.Compute2DCoords(molecule)
85
-
86
- # Set up canvas size based on aspect ratio
87
- base_size = 400
88
- width, height = calculate_aspect_ratio(molecule, base_size)
89
- d = Draw.MolDraw2DCairo(width, height)
90
- #Draw.SetACS1996Mode(d.drawOptions(),Draw.MeanBondLength(molecule))
91
- d.drawOptions().padding = padding
92
-
93
- # Optionally set black and white palette
94
- if bw:
95
- d.drawOptions().useBWAtomPalette()
96
-
97
- # Get token importance scores and map to atoms
98
- token_importance = cls_explainer(smiles)
99
- atom_importance = [c[1] for c in token_importance if c[0].isalpha()]
100
- num_atoms = molecule.GetNumAtoms()
101
- atom_importance = atom_importance[:num_atoms]
102
-
103
- # Generate and display a similarity map based on atom importance scores
104
- SimilarityMaps.GetSimilarityMapFromWeights(molecule, atom_importance, draw2d=d)
105
-
106
- # Convert drawing to image and display
107
- d.FinishDrawing()
108
- png_data = d.GetDrawingText()
109
- img = Image(data=png_data)
110
- return img
111
-
112
  def save_high_quality_png(smiles, title, bw=True, padding=0.05):
113
  """
114
  Generates a high-quality PNG of atom-wise gradients or importance scores for a molecule.
 
17
 
18
  cls_explainer = SequenceClassificationExplainer(model, tokenizer)
19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  def save_high_quality_png(smiles, title, bw=True, padding=0.05):
21
  """
22
  Generates a high-quality PNG of atom-wise gradients or importance scores for a molecule.