File size: 5,603 Bytes
d781678
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
from rdkit import Chem
from rdkit.Chem import Draw
from rdkit.Chem.Draw import SimilarityMaps
from IPython.display import SVG
import io
from PIL import Image
import numpy as np
import rdkit

from transformers import AutoModelForSequenceClassification, AutoTokenizer
from transformers_interpret import SequenceClassificationExplainer
from transformers import pipeline

# Hugging Face Hub identifier of the sequence-classification checkpoint.
model_name = "FartLabs/FART_Augmented"

# Load the weights and tokenizer once; the explainer and the pipeline below
# both share these objects.
model = AutoModelForSequenceClassification.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Produces per-token attribution scores for a given input string.
cls_explainer = SequenceClassificationExplainer(model, tokenizer)

# Reuse the already-loaded model/tokenizer: passing the checkpoint name here
# would load a second, redundant copy of the weights into memory.
pipe = pipeline("text-classification", model=model, tokenizer=tokenizer)

def get_taste_from_smiles(smiles):
    """
    Classifies a SMILES string and summarises the top prediction as text.

    Parameters:
    - smiles (str): The SMILES string handed to the text-classification pipeline.

    Returns:
    - str: A line of the form "<TASTE> score: <score>", with the score
      rendered to two decimal places.
    """
    # Position i is the taste name used for a pipeline label ending in "_i".
    tastes = ('BITTER', 'SOUR', 'SWEET', 'UMAMI', 'UNDEFINED')

    # Only the first entry of the pipeline output is used.
    prediction = pipe(smiles)[0]

    # The label's numeric suffix (text after the first underscore) is the class index.
    class_id = int(prediction['label'].split('_')[1])
    confidence = prediction['score']
    taste = tastes[class_id]

    return f"{taste} score: {confidence:.2f}"

def calculate_aspect_ratio(molecule, base_size):
    """
    Calculates the canvas width and height based on the molecule's aspect ratio.

    The shorter side of the canvas is always ``base_size``; the longer side is
    scaled up by the ratio of the molecule's x-extent to its y-extent. When the
    molecule has no extent along either axis (no atoms, a single atom, or all
    atoms on one horizontal or vertical line) a square canvas is returned.

    Parameters:
    - molecule (Mol): RDKit molecule object with a conformer holding atom coordinates.
    - base_size (int): The base size of the canvas, typically 400.

    Returns:
    - (int, int): Calculated width and height for the canvas.
    """
    conf = molecule.GetConformer()
    positions = [conf.GetAtomPosition(i) for i in range(molecule.GetNumAtoms())]

    # No atoms: max()/min() would raise on empty sequences, so use a square canvas.
    if not positions:
        return base_size, base_size

    x_coords = [pos.x for pos in positions]
    y_coords = [pos.y for pos in positions]
    width = max(x_coords) - min(x_coords)
    height = max(y_coords) - min(y_coords)

    # Degenerate extent on either axis: a ratio is undefined (height == 0) or
    # zero (width == 0, which would divide by zero below), so stay square.
    if width <= 0 or height <= 0:
        return base_size, base_size

    aspect_ratio = width / height

    if aspect_ratio > 1:
        # Wider than tall: stretch the width, keep the height at base_size.
        return max(base_size, int(base_size * aspect_ratio)), base_size
    if aspect_ratio < 1:
        # Taller than wide: stretch the height, keep the width at base_size.
        return base_size, max(base_size, int(base_size / aspect_ratio))
    return base_size, base_size

def visualize_gradients(smiles, bw=True, padding=0.05):
    """
    Visualizes atom-wise gradients or importance scores for a given molecule 
    based on the SMILES representation as a similarity map.

    The predicted taste for the molecule is also printed to stdout.

    Parameters:
    - smiles (str): The SMILES string of the molecule to visualize.
    - bw (bool): If True, renders the molecule in black and white (default is True).
    - padding (float): Padding applied around the molecule drawing (default is 0.05).

    Returns:
    - PIL.Image.Image: The rendered similarity map, decoded from the PNG data.

    Raises:
    - ValueError: If the SMILES string cannot be parsed into a molecule.
    """
    # Convert SMILES string to RDKit molecule object; MolFromSmiles returns
    # None on a parse failure, so fail here with a clear message.
    molecule = Chem.MolFromSmiles(smiles)
    if molecule is None:
        raise ValueError(f"Could not parse SMILES string: {smiles!r}")

    print(get_taste_from_smiles(smiles))

    Chem.rdDepictor.Compute2DCoords(molecule)

    # Set up canvas size based on aspect ratio
    base_size = 400
    width, height = calculate_aspect_ratio(molecule, base_size)
    d = Draw.MolDraw2DCairo(width, height)
    draw_options = d.drawOptions()
    draw_options.padding = padding

    # Optionally set black and white palette
    if bw:
        draw_options.useBWAtomPalette()

    # Get token importance scores and map to atoms: keep the score of every
    # purely alphabetic token, then truncate to the atom count.
    # NOTE(review): this assumes alphabetic tokens appear in atom order, one
    # per atom — confirm against the tokenizer's vocabulary.
    token_importance = cls_explainer(smiles)
    atom_importance = [score for token, score in
                       ((c[0], c[1]) for c in token_importance)
                       if token.isalpha()]
    atom_importance = atom_importance[:molecule.GetNumAtoms()]

    # Generate a similarity map based on atom importance scores
    SimilarityMaps.GetSimilarityMapFromWeights(molecule, atom_importance, draw2d=d)

    # Finalize the drawing and decode the PNG bytes. `Image` is the PIL module
    # in this file, so it must be used via Image.open rather than called.
    d.FinishDrawing()
    png_data = d.GetDrawingText()
    img = Image.open(io.BytesIO(png_data))
    return img

def save_high_quality_png(smiles, title, bw=True, padding=0.05):
    """
    Generates a high-quality PNG of atom-wise gradients or importance scores for a molecule.

    The image is rendered on a 1500x1500 canvas and written to "<title>.png".

    Parameters:
    - smiles (str): The SMILES string of the molecule to visualize.
    - title (str): Output path without extension; ".png" is appended.
    - bw (bool): If True, renders the molecule in black and white.
    - padding (float): Padding for molecule drawing.

    Returns:
    - PIL.Image.Image: The rendered image, decoded from the same PNG data
      that was written to disk.

    Raises:
    - ValueError: If the SMILES string cannot be parsed into a molecule.
    """
    # Convert SMILES string to RDKit molecule object; MolFromSmiles returns
    # None on a parse failure, so fail here with a clear message.
    molecule = Chem.MolFromSmiles(smiles)
    if molecule is None:
        raise ValueError(f"Could not parse SMILES string: {smiles!r}")
    Chem.rdDepictor.Compute2DCoords(molecule)

    # Get token importance scores and map to atoms: keep the score of every
    # purely alphabetic token, then truncate to the atom count.
    # NOTE(review): this assumes alphabetic tokens appear in atom order, one
    # per atom — confirm against the tokenizer's vocabulary.
    token_importance = cls_explainer(smiles)
    atom_importance = [c[1] for c in token_importance if c[0].isalpha()]
    atom_importance = atom_importance[:molecule.GetNumAtoms()]

    # Set a large canvas size for high resolution
    d = Draw.MolDraw2DCairo(1500, 1500)

    dopts = d.drawOptions()
    dopts.padding = padding
    dopts.maxFontSize = 2000
    dopts.bondLineWidth = 5

    # Optionally set black and white palette
    if bw:
        dopts.useBWAtomPalette()

    # Generate a similarity map based on atom importance scores
    SimilarityMaps.GetSimilarityMapFromWeights(molecule, atom_importance, draw2d=d)

    # Finalize the drawing exactly once and fetch the PNG bytes a single time,
    # so the saved file and the returned image come from the same data.
    d.FinishDrawing()
    png_data = d.GetDrawingText()

    # Save to PNG file with high quality
    with open(f"{title}.png", "wb") as png_file:
        png_file.write(png_data)

    print(f"High-quality PNG file saved as {title}.png")

    # `Image` is the PIL module in this file, so decode via Image.open
    # rather than calling the module.
    img = Image.open(io.BytesIO(png_data))
    return img