yzimmermann commited on
Commit
0a18df7
·
verified ·
1 Parent(s): 59b4b1e

Delete interpretability.py

Browse files
Files changed (1) hide show
  1. interpretability.py +0 -67
interpretability.py DELETED
@@ -1,67 +0,0 @@
1
- from rdkit import Chem
2
- from rdkit.Chem import Draw
3
- from rdkit.Chem.Draw import SimilarityMaps
4
- from IPython.display import SVG
5
- import io
6
- from PIL import Image
7
- import numpy as np
8
- import rdkit
9
-
10
- from transformers import AutoModelForSequenceClassification, AutoTokenizer
11
- from transformers_interpret import SequenceClassificationExplainer
12
- from transformers import pipeline
13
-
14
- model_name = "FartLabs/FART_Augmented"
15
- model = AutoModelForSequenceClassification.from_pretrained(model_name)
16
- tokenizer = AutoTokenizer.from_pretrained(model_name)
17
-
18
- cls_explainer = SequenceClassificationExplainer(model, tokenizer)
19
-
20
- def save_high_quality_png(smiles, title, bw=True, padding=0.05):
21
- """
22
- Generates a high-quality PNG of atom-wise gradients or importance scores for a molecule.
23
-
24
- Parameters:
25
- - smiles (str): The SMILES string of the molecule to visualize.
26
- - token_importance (list): List of importance scores for each atom.
27
- - bw (bool): If True, renders the molecule in black and white.
28
- - padding (float): Padding for molecule drawing.
29
- - output_file (str): Path to save the high-quality PNG file.
30
-
31
- Returns:
32
- - None
33
- """
34
-
35
- # Convert SMILES string to RDKit molecule object
36
- molecule = Chem.MolFromSmiles(smiles)
37
- Chem.rdDepictor.Compute2DCoords(molecule)
38
-
39
- # Get token importance scores and map to atoms
40
- token_importance = cls_explainer(smiles)
41
- atom_importance = [c[1] for c in token_importance if c[0].isalpha()]
42
- num_atoms = molecule.GetNumAtoms()
43
- atom_importance = atom_importance[:num_atoms]
44
-
45
- # Set a large canvas size for high resolution
46
- d = Draw.MolDraw2DCairo(1500, 1500)
47
-
48
- dopts = d.drawOptions()
49
- dopts.padding = padding
50
- dopts.maxFontSize = 2000
51
- dopts.bondLineWidth = 5
52
-
53
- # Optionally set black and white palette
54
- if bw:
55
- d.drawOptions().useBWAtomPalette()
56
-
57
- # Generate and display a similarity map based on atom importance scores
58
- SimilarityMaps.GetSimilarityMapFromWeights(molecule, atom_importance, draw2d=d)
59
-
60
- # Draw molecule with color highlights
61
- d.FinishDrawing()
62
-
63
- # Save to PNG file with high quality
64
- with open(f"{title}.png", "wb") as png_file:
65
- png_file.write(d.GetDrawingText())
66
-
67
- return None