# FART / interpretability.py -- commit d781678 (verified) by yzimmermann: "Create interpretability.py" (5.6 kB)
from rdkit import Chem
from rdkit.Chem import Draw
from rdkit.Chem.Draw import SimilarityMaps
from IPython.display import SVG
import io
from PIL import Image
import numpy as np
import rdkit
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from transformers_interpret import SequenceClassificationExplainer
from transformers import pipeline
# Model identifier passed to every loader below.
model_name = "FartLabs/FART_Augmented"
# Sequence-classification model and its tokenizer; both are handed to the explainer.
model = AutoModelForSequenceClassification.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Explainer that the plotting functions call on a SMILES string to obtain
# (token, score) pairs.
cls_explainer = SequenceClassificationExplainer(model, tokenizer)
# Text-classification pipeline used by get_taste_from_smiles. It is built from
# the model name rather than from the `model`/`tokenizer` objects above.
# NOTE(review): presumably this loads a second copy of the weights -- confirm
# whether passing the already-loaded objects would be acceptable.
pipe = pipeline("text-classification", model=model_name)
def get_taste_from_smiles(smiles):
    """
    Classifies a SMILES string with the module-level pipeline and formats the result.
    Parameters:
    - smiles (str): The SMILES string to classify.
    Returns:
    - str: A title of the form "<TASTE> score: <score to two decimals>".
    """
    # Taste names, indexed by the numeric part of the pipeline's label
    tastes = ('BITTER', 'SOUR', 'SWEET', 'UMAMI', 'UNDEFINED')
    # Only the first entry of the pipeline output is used
    top_prediction = pipe(smiles)[0]
    # The label is split on '_' and its second piece parsed as the class index
    # (assumes labels look like 'LABEL_<n>' -- confirm against the model config)
    class_id = int(top_prediction['label'].split('_')[1])
    confidence = top_prediction['score']
    return f"{tastes[class_id]} score: {confidence:.2f}"
def calculate_aspect_ratio(molecule, base_size):
    """
    Calculates the canvas width and height based on the molecule's aspect ratio.

    The aspect ratio is the x-extent divided by the y-extent of the atom
    positions in the molecule's conformer. The shorter canvas side is
    ``base_size`` and the longer side is scaled by that ratio. A degenerate
    bounding box (no atoms, or zero extent along either axis) has no
    meaningful ratio and yields a square ``base_size`` canvas.

    Parameters:
    - molecule (Mol): RDKit molecule object; must provide a conformer with
      atom positions unless it has no atoms.
    - base_size (int): The base size of the canvas, typically 400.
    Returns:
    - (int, int): Calculated width and height for the canvas.
    """
    num_atoms = molecule.GetNumAtoms()
    if num_atoms == 0:
        # Nothing to measure; max()/min() on empty lists would raise.
        return base_size, base_size

    conf = molecule.GetConformer()
    atom_positions = [conf.GetAtomPosition(i) for i in range(num_atoms)]
    x_coords = [pos.x for pos in atom_positions]
    y_coords = [pos.y for pos in atom_positions]
    width = max(x_coords) - min(x_coords)
    height = max(y_coords) - min(y_coords)

    # A zero extent on either axis would give a ratio of 0 or a division by
    # zero below, so both cases are treated as a ratio of 1.
    if width <= 0 or height <= 0:
        return base_size, base_size

    aspect_ratio = width / height
    if aspect_ratio > 1:
        # Wider than tall: stretch the width, never below base_size.
        return max(base_size, int(base_size * aspect_ratio)), base_size
    if aspect_ratio < 1:
        # Taller than wide: stretch the height, never below base_size.
        return base_size, max(base_size, int(base_size / aspect_ratio))
    return base_size, base_size
def visualize_gradients(smiles, bw=True, padding=0.05):
    """
    Visualizes atom-wise gradients or importance scores for a given molecule
    based on the SMILES representation as a similarity map.

    The predicted taste title for the SMILES string is printed first, then the
    map is drawn on a canvas whose shape follows the molecule's aspect ratio.

    Parameters:
    - smiles (str): The SMILES string of the molecule to visualize.
    - bw (bool): If True, renders the molecule in black and white (default is True).
    - padding (float): Padding for molecule drawing (default is 0.05).
    Returns:
    - PIL.Image.Image: The rendered similarity map, decoded from the PNG drawing.
    Raises:
    - ValueError: If the SMILES string cannot be parsed by RDKit.
    """
    print(get_taste_from_smiles(smiles))

    # Convert SMILES string to RDKit molecule object
    molecule = Chem.MolFromSmiles(smiles)
    if molecule is None:
        # MolFromSmiles signals a parse failure by returning None
        raise ValueError(f"Could not parse SMILES string: {smiles!r}")
    Chem.rdDepictor.Compute2DCoords(molecule)

    # Set up canvas size based on aspect ratio
    base_size = 400
    width, height = calculate_aspect_ratio(molecule, base_size)
    d = Draw.MolDraw2DCairo(width, height)
    draw_options = d.drawOptions()
    draw_options.padding = padding
    # Optionally set black and white palette
    if bw:
        draw_options.useBWAtomPalette()

    # Get token importance scores and map to atoms: keep the scores of purely
    # alphabetic tokens, in order, and cut the list to the atom count.
    # NOTE(review): assumes alphabetic tokens line up one-to-one with RDKit
    # atom order -- confirm against the tokenizer.
    token_importance = cls_explainer(smiles)
    atom_importance = [score for token, score in token_importance if token.isalpha()]
    atom_importance = atom_importance[:molecule.GetNumAtoms()]

    # Generate a similarity map based on atom importance scores
    SimilarityMaps.GetSimilarityMapFromWeights(molecule, atom_importance, draw2d=d)
    d.FinishDrawing()

    # `Image` is the PIL.Image module, so the PNG bytes must be decoded with
    # Image.open; calling the module itself raises TypeError.
    png_data = d.GetDrawingText()
    return Image.open(io.BytesIO(png_data))
def save_high_quality_png(smiles, title, bw=True, padding=0.05):
    """
    Generates a high-quality PNG of atom-wise gradients or importance scores for a molecule.

    The similarity map is drawn on a 1500x1500 canvas, written to
    "<title>.png", and also returned as an image object.

    Parameters:
    - smiles (str): The SMILES string of the molecule to visualize.
    - title (str): Output path without extension; ".png" is appended.
    - bw (bool): If True, renders the molecule in black and white (default is True).
    - padding (float): Padding for molecule drawing (default is 0.05).
    Returns:
    - PIL.Image.Image: The rendered similarity map, decoded from the saved PNG bytes.
    Raises:
    - ValueError: If the SMILES string cannot be parsed by RDKit.
    """
    # Convert SMILES string to RDKit molecule object
    molecule = Chem.MolFromSmiles(smiles)
    if molecule is None:
        # MolFromSmiles signals a parse failure by returning None
        raise ValueError(f"Could not parse SMILES string: {smiles!r}")
    Chem.rdDepictor.Compute2DCoords(molecule)

    # Get token importance scores and map to atoms: keep the scores of purely
    # alphabetic tokens, in order, and cut the list to the atom count.
    # NOTE(review): assumes alphabetic tokens line up one-to-one with RDKit
    # atom order -- confirm against the tokenizer.
    token_importance = cls_explainer(smiles)
    atom_importance = [score for token, score in token_importance if token.isalpha()]
    atom_importance = atom_importance[:molecule.GetNumAtoms()]

    # Set a large canvas size for high resolution
    d = Draw.MolDraw2DCairo(1500, 1500)
    dopts = d.drawOptions()
    dopts.padding = padding
    dopts.maxFontSize = 2000
    dopts.bondLineWidth = 5
    # Optionally set black and white palette
    if bw:
        dopts.useBWAtomPalette()

    # Generate a similarity map based on atom importance scores
    SimilarityMaps.GetSimilarityMapFromWeights(molecule, atom_importance, draw2d=d)
    # Finish the drawing once and reuse the same PNG bytes for the file and
    # for the returned image.
    d.FinishDrawing()
    png_data = d.GetDrawingText()

    # Save to PNG file with high quality
    with open(f"{title}.png", "wb") as png_file:
        png_file.write(png_data)
    print(f"High-quality PNG file saved as {title}.png")

    # `Image` is the PIL.Image module, so the PNG bytes must be decoded with
    # Image.open; calling the module itself raises TypeError.
    return Image.open(io.BytesIO(png_data))