|
from rdkit import Chem |
|
from rdkit.Chem import Draw |
|
from rdkit.Chem.Draw import SimilarityMaps |
|
from IPython.display import SVG |
|
import io |
|
from PIL import Image |
|
import numpy as np |
|
import rdkit |
|
|
|
from transformers import AutoModelForSequenceClassification, AutoTokenizer |
|
from transformers_interpret import SequenceClassificationExplainer |
|
from transformers import pipeline |
|
|
|
# Hugging Face Hub identifier of the sequence-classification checkpoint.
model_name = "FartLabs/FART_Augmented"

# The tokenizer and the model are both loaded from the same checkpoint name.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)

# Module-level explainer; called with an input string to obtain
# (token, score) attribution pairs for that input.
cls_explainer = SequenceClassificationExplainer(model=model, tokenizer=tokenizer)
|
|
|
def save_high_quality_png(smiles, title, bw=True, padding=0.05):
    """
    Generate a high-quality PNG of atom-wise importance scores for a molecule.

    The SMILES string is passed to the module-level ``cls_explainer`` to obtain
    (token, score) pairs. The scores of purely alphabetic tokens are used, in
    order, as per-atom weights and drawn as a similarity map on a 1500x1500
    Cairo canvas, which is then written to disk.

    Parameters:
    - smiles (str): The SMILES string of the molecule to visualize.
    - title (str): Output path without extension; the image is written to
      ``f"{title}.png"``.
    - bw (bool): If True, renders the molecule with the black-and-white
      atom palette.
    - padding (float): Value assigned to the drawer's ``padding`` option.

    Returns:
    - None

    Raises:
    - ValueError: If ``smiles`` cannot be parsed into a molecule, or if fewer
      atom scores than atoms are obtained from the explainer output.
    """
    molecule = Chem.MolFromSmiles(smiles)
    if molecule is None:
        # MolFromSmiles signals a parse failure by returning None; fail here
        # with a clear message instead of deep inside the depiction code.
        raise ValueError(f"Could not parse SMILES string: {smiles!r}")
    Chem.rdDepictor.Compute2DCoords(molecule)

    # Keep only the scores attached to purely alphabetic tokens.
    # NOTE(review): this assumes one alphabetic token per atom, in atom order;
    # tokens for bracket atoms or other non-alphabetic atom tokens would be
    # skipped and shift the alignment - confirm against the tokenizer.
    token_importance = cls_explainer(smiles)
    atom_importance = [
        score for token, score in token_importance if token.isalpha()
    ]

    num_atoms = molecule.GetNumAtoms()
    if len(atom_importance) < num_atoms:
        raise ValueError(
            f"Got {len(atom_importance)} atom scores for {num_atoms} atoms "
            f"in SMILES {smiles!r}"
        )
    # Drop any surplus scores so there is exactly one weight per atom.
    atom_importance = atom_importance[:num_atoms]

    d = Draw.MolDraw2DCairo(1500, 1500)

    dopts = d.drawOptions()
    dopts.padding = padding
    dopts.maxFontSize = 2000
    dopts.bondLineWidth = 5
    if bw:
        dopts.useBWAtomPalette()

    SimilarityMaps.GetSimilarityMapFromWeights(
        molecule, atom_importance, draw2d=d
    )
    d.FinishDrawing()

    # GetDrawingText() yields the PNG data, hence the binary write mode.
    with open(f"{title}.png", "wb") as png_file:
        png_file.write(d.GetDrawingText())