File size: 2,251 Bytes
d781678
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ee0536e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
from rdkit import Chem
from rdkit.Chem import Draw
from rdkit.Chem.Draw import SimilarityMaps
from IPython.display import SVG
import io
from PIL import Image
import numpy as np
import rdkit

from transformers import AutoModelForSequenceClassification, AutoTokenizer
from transformers_interpret import SequenceClassificationExplainer
from transformers import pipeline

# Hugging Face model id; the same id is used for both the model and its tokenizer.
model_name = "FartLabs/FART_Augmented"
# Loaded once at import time (from_pretrained may fetch weights on first use).
model = AutoModelForSequenceClassification.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Module-level explainer used by save_high_quality_png: calling it on a string
# yields (token, attribution) pairs for that input.
cls_explainer = SequenceClassificationExplainer(model, tokenizer)

def _atom_importance_from_tokens(token_importance, num_atoms):
    """
    Turn (token, score) pairs into exactly ``num_atoms`` per-atom weights.

    Only purely alphabetic tokens are kept, as a heuristic for atom symbols.
    The result is truncated, or right-padded with 0.0, so its length always
    equals ``num_atoms`` (the similarity-map drawer needs one weight per atom).
    """
    # NOTE(review): this assumes one alphabetic token per atom. Multi-atom
    # tokens or bracket-atom splits (if the tokenizer emits them) would shift
    # the alignment -- confirm against the model's tokenizer.
    weights = [score for token, score in token_importance if token.isalpha()]
    weights = weights[:num_atoms]
    # Pad missing atoms with zero importance rather than letting a short
    # weight list make the similarity-map call fail.
    weights.extend([0.0] * (num_atoms - len(weights)))
    return weights


def save_high_quality_png(smiles, title, bw=True, padding=0.05):
    """
    Render per-atom attribution scores for a molecule and save them as a PNG.

    Token attributions are computed for ``smiles`` with the module-level
    ``cls_explainer``, mapped onto the molecule's atoms, and drawn as an RDKit
    similarity map on a 1500x1500 Cairo canvas.

    Parameters:
    - smiles (str): The SMILES string of the molecule to visualize.
    - title (str): Output file stem; the image is written to ``f"{title}.png"``.
    - bw (bool): If True, draws atom labels with the black-and-white palette.
    - padding (float): Padding around the molecule drawing (drawOptions.padding).

    Returns:
    - None

    Raises:
    - ValueError: If ``smiles`` cannot be parsed by RDKit.
    """
    molecule = Chem.MolFromSmiles(smiles)
    if molecule is None:
        # MolFromSmiles signals a parse failure by returning None; fail early
        # with a clear message instead of an opaque error further down.
        raise ValueError(f"Invalid SMILES string: {smiles!r}")
    Chem.rdDepictor.Compute2DCoords(molecule)

    # Per-token attributions -> one weight per atom.
    num_atoms = molecule.GetNumAtoms()
    atom_importance = _atom_importance_from_tokens(cls_explainer(smiles), num_atoms)

    # Large canvas so the PNG stays sharp when scaled down.
    drawer = Draw.MolDraw2DCairo(1500, 1500)
    dopts = drawer.drawOptions()
    dopts.padding = padding
    dopts.maxFontSize = 2000
    dopts.bondLineWidth = 5
    if bw:
        dopts.useBWAtomPalette()

    # Draw the molecule with a contour map built from the per-atom weights.
    SimilarityMaps.GetSimilarityMapFromWeights(molecule, atom_importance, draw2d=drawer)
    drawer.FinishDrawing()

    # GetDrawingText() on a Cairo drawer returns the PNG bytes.
    with open(f"{title}.png", "wb") as png_file:
        png_file.write(drawer.GetDrawingText())