# Spaces: Runtime error  (page-status banner captured along with the source; not part of the program)
from rdkit import Chem | |
from rdkit.Chem import Draw | |
from rdkit.Chem.Draw import SimilarityMaps | |
from IPython.display import SVG | |
import io | |
from PIL import Image | |
import numpy as np | |
import rdkit | |
from transformers import AutoModelForSequenceClassification, AutoTokenizer | |
from transformers_interpret import SequenceClassificationExplainer | |
from transformers import pipeline | |
# Hub identifier of the sequence-classification checkpoint used throughout this file.
model_name = "FartLabs/FART_Augmented"

# Load the weights and the tokenizer once; the explainer and the pipeline
# below both reuse these objects.
model = AutoModelForSequenceClassification.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Explainer called later with a SMILES string to obtain per-token importance scores.
cls_explainer = SequenceClassificationExplainer(model, tokenizer)

# Build the pipeline from the already-loaded model/tokenizer instead of the hub
# name, which would instantiate a second copy of the same model in memory.
pipe = pipeline("text-classification", model=model, tokenizer=tokenizer)
def get_taste_from_smiles(smiles):
    """
    Classify a SMILES string and summarise the top prediction as a short title.

    Parameters:
    - smiles (str): The SMILES string passed to the text-classification pipeline.

    Returns:
    - str: "<TASTE> score: <score>", with the score formatted to two decimals.
    """
    # Taste names, positioned so that index i corresponds to class index i.
    tastes = ('BITTER', 'SOUR', 'SWEET', 'UMAMI', 'UNDEFINED')

    # Only the first entry of the pipeline output is used.
    top_prediction = pipe(smiles)[0]

    # The class index is the piece following the first underscore of the label.
    class_idx = int(top_prediction['label'].split('_')[1])
    confidence = top_prediction['score']

    taste_name = tastes[class_idx]
    return f"{taste_name} score: {confidence:.2f}"
def calculate_aspect_ratio(molecule, base_size):
    """
    Calculates the canvas width and height based on the molecule's aspect ratio.

    The aspect ratio is the x-extent of the atom positions divided by their
    y-extent. The shorter canvas side is ``base_size`` and the longer side is
    scaled by that ratio. Degenerate layouts (no atoms, a single atom, or all
    atoms on one horizontal or vertical line) have no meaningful ratio and get
    a square ``base_size`` x ``base_size`` canvas.

    Parameters:
    - molecule (Mol): RDKit molecule object; its conformer supplies the atom positions.
    - base_size (int): The base size of the canvas, typically 400.

    Returns:
    - (int, int): Calculated width and height for the canvas.
    """
    num_atoms = molecule.GetNumAtoms()
    if num_atoms == 0:
        # Nothing to measure; max()/min() on empty coordinate lists would raise.
        return base_size, base_size

    conf = molecule.GetConformer()
    atom_positions = [conf.GetAtomPosition(i) for i in range(num_atoms)]
    x_coords = [pos.x for pos in atom_positions]
    y_coords = [pos.y for pos in atom_positions]

    width = max(x_coords) - min(x_coords)
    height = max(y_coords) - min(y_coords)

    # A zero extent in either direction would give a ratio of 0 or a division
    # by zero below, so treat such layouts as square.
    if width <= 0 or height <= 0:
        return base_size, base_size

    aspect_ratio = width / height
    if aspect_ratio > 1:
        # Wider than tall: stretch the width, keep the height at base_size.
        return max(base_size, int(base_size * aspect_ratio)), base_size
    if aspect_ratio < 1:
        # Taller than wide: stretch the height, keep the width at base_size.
        return base_size, max(base_size, int(base_size / aspect_ratio))
    return base_size, base_size
def visualize_gradients(smiles, bw=True, padding=0.05):
    """
    Visualizes atom-wise gradients or importance scores for a given molecule
    based on the SMILES representation as a similarity map.

    The predicted taste (see get_taste_from_smiles) is printed before drawing.

    Parameters:
    - smiles (str): The SMILES string of the molecule to visualize.
    - bw (bool): If True, renders the molecule in black and white (default is True).
    - padding (float): Padding assigned to the drawing options (default is 0.05).

    Returns:
    - PIL.Image.Image: The rendered similarity map, decoded from the PNG bytes.

    Raises:
    - ValueError: If RDKit cannot parse the SMILES string.
    """
    print(get_taste_from_smiles(smiles))

    # Convert SMILES string to RDKit molecule object
    molecule = Chem.MolFromSmiles(smiles)
    if molecule is None:
        # MolFromSmiles signals a parse failure by returning None; fail with a
        # clear message instead of an opaque error from the calls below.
        raise ValueError(f"Could not parse SMILES string: {smiles!r}")
    Chem.rdDepictor.Compute2DCoords(molecule)

    # Set up canvas size based on aspect ratio
    base_size = 400
    width, height = calculate_aspect_ratio(molecule, base_size)
    d = Draw.MolDraw2DCairo(width, height)
    draw_options = d.drawOptions()
    draw_options.padding = padding

    # Optionally set black and white palette
    if bw:
        draw_options.useBWAtomPalette()

    # Get token importance scores and map to atoms: keep the scores of purely
    # alphabetic tokens, in order, and cut the list to the atom count.
    # NOTE(review): this assumes alphabetic tokens line up one-to-one with
    # atoms in RDKit order — confirm against the tokenizer's vocabulary.
    token_importance = cls_explainer(smiles)
    atom_importance = [score for token, score in
                       ((c[0], c[1]) for c in token_importance)
                       if token.isalpha()]
    atom_importance = atom_importance[:molecule.GetNumAtoms()]

    # Generate a similarity map based on atom importance scores
    SimilarityMaps.GetSimilarityMapFromWeights(molecule, atom_importance, draw2d=d)

    # Convert drawing to an image. `Image` is the PIL module here, so the PNG
    # bytes must be decoded with Image.open rather than by calling Image(...).
    d.FinishDrawing()
    png_data = d.GetDrawingText()
    img = Image.open(io.BytesIO(png_data))
    return img
def save_high_quality_png(smiles, title, bw=True, padding=0.05):
    """
    Generates a high-quality PNG of atom-wise gradients or importance scores for a molecule.

    The image is drawn on a 1500x1500 canvas, written to "<title>.png", and
    also returned to the caller.

    Parameters:
    - smiles (str): The SMILES string of the molecule to visualize.
    - title (str): Output path without extension; ".png" is appended.
    - bw (bool): If True, renders the molecule in black and white.
    - padding (float): Padding for molecule drawing.

    Returns:
    - PIL.Image.Image: The rendered similarity map, decoded from the PNG bytes.

    Raises:
    - ValueError: If RDKit cannot parse the SMILES string.
    """
    # Convert SMILES string to RDKit molecule object
    molecule = Chem.MolFromSmiles(smiles)
    if molecule is None:
        # MolFromSmiles signals a parse failure by returning None; fail with a
        # clear message instead of an opaque error from the calls below.
        raise ValueError(f"Could not parse SMILES string: {smiles!r}")
    Chem.rdDepictor.Compute2DCoords(molecule)

    # Get token importance scores and map to atoms: keep the scores of purely
    # alphabetic tokens, in order, and cut the list to the atom count.
    # NOTE(review): this assumes alphabetic tokens line up one-to-one with
    # atoms in RDKit order — confirm against the tokenizer's vocabulary.
    token_importance = cls_explainer(smiles)
    atom_importance = [c[1] for c in token_importance if c[0].isalpha()]
    atom_importance = atom_importance[:molecule.GetNumAtoms()]

    # Set a large canvas size for high resolution
    d = Draw.MolDraw2DCairo(1500, 1500)
    dopts = d.drawOptions()
    dopts.padding = padding
    dopts.maxFontSize = 2000
    dopts.bondLineWidth = 5

    # Optionally set black and white palette
    if bw:
        dopts.useBWAtomPalette()

    # Generate a similarity map based on atom importance scores
    SimilarityMaps.GetSimilarityMapFromWeights(molecule, atom_importance, draw2d=d)

    # Finish the drawing once and reuse the same PNG bytes for both the file
    # and the returned image.
    d.FinishDrawing()
    png_data = d.GetDrawingText()

    # Save to PNG file with high quality
    with open(f"{title}.png", "wb") as png_file:
        png_file.write(png_data)
    print(f"High-quality PNG file saved as {title}.png")

    # `Image` is the PIL module here, so the PNG bytes must be decoded with
    # Image.open rather than by calling Image(...).
    img = Image.open(io.BytesIO(png_data))
    return img