|
from rdkit import Chem |
|
from rdkit.Chem import Draw |
|
from rdkit.Chem.Draw import SimilarityMaps |
|
from IPython.display import SVG |
|
import io |
|
from PIL import Image |
|
import numpy as np |
|
import rdkit |
|
|
|
from transformers import AutoModelForSequenceClassification, AutoTokenizer |
|
from transformers_interpret import SequenceClassificationExplainer |
|
from transformers import pipeline |
|
|
|
# Checkpoint identifier handed to every ``from_pretrained`` call below.
model_name = "FartLabs/FART_Augmented"

# Load the classifier weights and the tokenizer once; both the explainer and
# the prediction pipeline below share these objects.
model = AutoModelForSequenceClassification.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Per-token attribution explainer used by the visualisation functions.
cls_explainer = SequenceClassificationExplainer(model, tokenizer)

# Pass the already-loaded model and tokenizer instead of the checkpoint name,
# so the pipeline does not load a second copy of the same weights.
pipe = pipeline("text-classification", model=model, tokenizer=tokenizer)
|
|
|
def get_taste_from_smiles(smiles):
    """
    Runs the module-level ``pipe`` on a SMILES string and summarises the
    first prediction as a short, printable line.

    The prediction's ``label`` is split on ``'_'`` and the second piece is
    parsed as an integer index into the fixed list of taste names.

    Parameters:
    - smiles (str): The SMILES string handed to the classification pipeline.

    Returns:
    - str: ``"<TASTE> score: <score>"`` with the score shown to two decimals.
    """
    taste_names = ('BITTER', 'SOUR', 'SWEET', 'UMAMI', 'UNDEFINED')

    # Only the first entry of the pipeline output is used.
    prediction = pipe(smiles)[0]

    # NOTE(review): presumably labels look like "LABEL_<n>" — confirm against
    # the checkpoint's config; any other shape fails on the next line.
    class_index = int(prediction['label'].split('_')[1])
    confidence = prediction['score']

    return f"{taste_names[class_index]} score: {confidence:.2f}"
|
|
|
def calculate_aspect_ratio(molecule, base_size):
    """
    Calculates the canvas width and height based on the molecule's aspect ratio.

    The aspect ratio is the x-extent of the atom positions divided by their
    y-extent. The shorter canvas side is ``base_size``; the longer side is
    ``base_size`` scaled by that ratio. A molecule with no atoms, or whose
    atoms all share the same x or the same y coordinate, has no usable ratio
    and gets a square ``base_size`` x ``base_size`` canvas.

    Parameters:
    - molecule (Mol): RDKit molecule object. Its conformer is read via
      ``GetConformer()`` whenever it has at least one atom.
    - base_size (int): The base size of the canvas, typically 400.

    Returns:
    - (int, int): Calculated width and height for the canvas.
    """
    num_atoms = molecule.GetNumAtoms()
    if num_atoms == 0:
        # Nothing to measure; max()/min() on empty lists would raise.
        return base_size, base_size

    conformer = molecule.GetConformer()
    positions = [conformer.GetAtomPosition(i) for i in range(num_atoms)]
    x_coords = [pos.x for pos in positions]
    y_coords = [pos.y for pos in positions]
    extent_x = max(x_coords) - min(x_coords)
    extent_y = max(y_coords) - min(y_coords)

    # A zero extent on either axis gives a ratio of 0 or an undefined one.
    # Previously only a zero height was guarded, so a zero width produced
    # ratio 0 and then a division by zero when sizing the canvas height.
    if extent_x <= 0 or extent_y <= 0:
        return base_size, base_size

    aspect_ratio = extent_x / extent_y

    # max() keeps the scaled side from dropping below base_size after int()
    # truncation (only possible when base_size is not an integer).
    if aspect_ratio > 1:
        return max(base_size, int(base_size * aspect_ratio)), base_size
    if aspect_ratio < 1:
        return base_size, max(base_size, int(base_size / aspect_ratio))
    return base_size, base_size
|
|
|
def visualize_gradients(smiles, bw=True, padding=0.05):
    """
    Visualizes atom-wise gradients or importance scores for a given molecule
    based on the SMILES representation as a similarity map.

    The line produced by ``get_taste_from_smiles`` is printed before drawing.
    Importance scores come from the module-level ``cls_explainer``: the scores
    of alphabetic tokens are taken, in order, as per-atom weights.

    Parameters:
    - smiles (str): The SMILES string of the molecule to visualize.
    - bw (bool): If True, renders the molecule in black and white (default is True).
    - padding (float): Value assigned to the drawer's ``padding`` option
      (default is 0.05).

    Returns:
    - IPython.display.Image: The similarity map as PNG data. A notebook shows
      it when it is the value of the cell's last expression.

    Raises:
    - ValueError: If RDKit cannot parse ``smiles``.
    """
    # The file-level name ``Image`` is the PIL *module* (``from PIL import
    # Image``), which cannot be called with ``data=``. Import the notebook
    # image class locally under a distinct name instead.
    from IPython.display import Image as NotebookImage

    molecule = Chem.MolFromSmiles(smiles)
    if molecule is None:
        # MolFromSmiles signals a parse failure by returning None.
        raise ValueError(f"Could not parse SMILES string: {smiles!r}")

    print(get_taste_from_smiles(smiles))

    Chem.rdDepictor.Compute2DCoords(molecule)

    # Size the canvas to the molecule's 2D extent.
    base_size = 400
    width, height = calculate_aspect_ratio(molecule, base_size)
    d = Draw.MolDraw2DCairo(width, height)

    draw_options = d.drawOptions()
    draw_options.padding = padding
    if bw:
        draw_options.useBWAtomPalette()

    # Best-effort token-to-atom alignment: keep the scores of alphabetic
    # tokens, cut to the atom count, and pad with zeros if there are too few
    # so there is always exactly one weight per atom.
    # NOTE(review): this assumes alphabetic tokens appear in atom order —
    # confirm against the tokenizer's vocabulary.
    token_importance = cls_explainer(smiles)
    num_atoms = molecule.GetNumAtoms()
    atom_importance = [
        score for token, score in token_importance if token.isalpha()
    ][:num_atoms]
    atom_importance += [0.0] * (num_atoms - len(atom_importance))

    SimilarityMaps.GetSimilarityMapFromWeights(molecule, atom_importance, draw2d=d)

    d.FinishDrawing()
    return NotebookImage(data=d.GetDrawingText())
|
|
|
def save_high_quality_png(smiles, title, bw=True, padding=0.05):
    """
    Generates a high-quality PNG of atom-wise gradients or importance scores for a molecule.

    The map is drawn on a 1500 x 1500 canvas and written to ``<title>.png``.
    Importance scores come from the module-level ``cls_explainer``: the scores
    of alphabetic tokens are taken, in order, as per-atom weights.

    Parameters:
    - smiles (str): The SMILES string of the molecule to visualize.
    - title (str): Output path without extension; ``.png`` is appended.
    - bw (bool): If True, renders the molecule in black and white.
    - padding (float): Padding for molecule drawing.

    Returns:
    - IPython.display.Image: The same PNG data that was written to disk.

    Raises:
    - ValueError: If RDKit cannot parse ``smiles``.
    """
    # The file-level name ``Image`` is the PIL *module* (``from PIL import
    # Image``), which cannot be called with ``data=``. Import the notebook
    # image class locally under a distinct name instead.
    from IPython.display import Image as NotebookImage

    molecule = Chem.MolFromSmiles(smiles)
    if molecule is None:
        # MolFromSmiles signals a parse failure by returning None.
        raise ValueError(f"Could not parse SMILES string: {smiles!r}")
    Chem.rdDepictor.Compute2DCoords(molecule)

    # Best-effort token-to-atom alignment: keep the scores of alphabetic
    # tokens, cut to the atom count, and pad with zeros if there are too few
    # so there is always exactly one weight per atom.
    # NOTE(review): this assumes alphabetic tokens appear in atom order —
    # confirm against the tokenizer's vocabulary.
    token_importance = cls_explainer(smiles)
    num_atoms = molecule.GetNumAtoms()
    atom_importance = [
        score for token, score in token_importance if token.isalpha()
    ][:num_atoms]
    atom_importance += [0.0] * (num_atoms - len(atom_importance))

    d = Draw.MolDraw2DCairo(1500, 1500)

    dopts = d.drawOptions()
    dopts.padding = padding
    dopts.maxFontSize = 2000
    dopts.bondLineWidth = 5
    if bw:
        dopts.useBWAtomPalette()

    SimilarityMaps.GetSimilarityMapFromWeights(molecule, atom_importance, draw2d=d)

    # Finish and serialise once; the same bytes go to disk and to the caller.
    d.FinishDrawing()
    png_data = d.GetDrawingText()

    with open(f"{title}.png", "wb") as png_file:
        png_file.write(png_data)

    print(f"High-quality PNG file saved as {title}.png")

    return NotebookImage(data=png_data)