import io

import numpy as np
import rdkit
from IPython.display import SVG
from PIL import Image
from rdkit import Chem
from rdkit.Chem import Draw
from rdkit.Chem import rdDepictor
from rdkit.Chem.Draw import SimilarityMaps
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from transformers import pipeline
from transformers_interpret import SequenceClassificationExplainer

model_name = "FartLabs/FART_Augmented"
model = AutoModelForSequenceClassification.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
cls_explainer = SequenceClassificationExplainer(model, tokenizer)
# Reuse the already-loaded model/tokenizer instead of loading the same
# checkpoint a second time by name.
pipe = pipeline("text-classification", model=model, tokenizer=tokenizer)

# Taste names indexed by the numeric suffix of the pipeline's "LABEL_<k>" label.
TASTE_LABELS = ['BITTER', 'SOUR', 'SWEET', 'UMAMI', 'UNDEFINED']


def get_taste_from_smiles(smiles):
    """
    Classifies a SMILES string and formats the top prediction as a title.

    Parameters:
    - smiles (str): The SMILES string to classify.

    Returns:
    - str: A string of the form "<TASTE> score: <score with 2 decimals>".
    """
    # The pipeline returns a list of predictions; only the first is used.
    label_info = pipe(smiles)[0]

    # Labels look like "LABEL_<k>"; <k> indexes into TASTE_LABELS.
    label_index = int(label_info['label'].split('_')[1])
    score = label_info['score']

    return f"{TASTE_LABELS[label_index]} score: {score:.2f}"


def calculate_aspect_ratio(molecule, base_size):
    """
    Calculates the canvas width and height based on the molecule's aspect ratio.

    The shorter side of the canvas is always `base_size`; the longer side is
    scaled by the ratio of the x-extent to the y-extent of the atom positions.

    Parameters:
    - molecule (Mol): RDKit molecule object with at least one conformer.
    - base_size (int): The base size of the canvas, typically 400.

    Returns:
    - (int, int): Calculated width and height for the canvas.
    """
    conf = molecule.GetConformer()
    positions = [conf.GetAtomPosition(i) for i in range(molecule.GetNumAtoms())]
    x_coords = [pos.x for pos in positions]
    y_coords = [pos.y for pos in positions]

    width = max(x_coords) - min(x_coords)
    height = max(y_coords) - min(y_coords)

    # A degenerate extent in either direction (single atom, atoms on one
    # horizontal or vertical line) has no meaningful ratio; fall back to a
    # square canvas. Guarding width as well avoids dividing by a zero ratio.
    if width <= 0 or height <= 0:
        return base_size, base_size

    aspect_ratio = width / height
    if aspect_ratio > 1:
        canvas_width = max(base_size, int(base_size * aspect_ratio))
    else:
        canvas_width = base_size
    if aspect_ratio < 1:
        canvas_height = max(base_size, int(base_size / aspect_ratio))
    else:
        canvas_height = base_size

    return canvas_width, canvas_height


def _molecule_from_smiles(smiles):
    """Parses `smiles` into an RDKit molecule with 2D coordinates computed."""
    molecule = Chem.MolFromSmiles(smiles)
    if molecule is None:
        # MolFromSmiles signals a parse failure by returning None.
        raise ValueError(f"Could not parse SMILES string: {smiles!r}")
    rdDepictor.Compute2DCoords(molecule)
    return molecule


def _atom_importance_scores(smiles, molecule):
    """Returns explainer scores of alphabetic tokens, one per atom at most."""
    token_importance = cls_explainer(smiles)
    # Only purely alphabetic tokens are treated as atoms; bond, ring and
    # bracket tokens are dropped.
    # NOTE(review): this assumes alphabetic tokens line up one-to-one with the
    # RDKit atom order - confirm against the tokenizer's vocabulary.
    scores = [score for token, score in token_importance if token.isalpha()]
    return scores[:molecule.GetNumAtoms()]


def _png_to_image(png_data):
    """Decodes PNG bytes into a PIL image."""
    return Image.open(io.BytesIO(png_data))


def visualize_gradients(smiles, bw=True, padding=0.05):
    """
    Visualizes atom-wise importance scores for a molecule as a similarity map.

    Prints the predicted taste for the molecule, then draws the molecule with
    the explainer's per-atom scores overlaid.

    Parameters:
    - smiles (str): The SMILES string of the molecule to visualize.
    - bw (bool): If True, renders the atoms in black and white (default is True).
    - padding (float): Padding passed to the RDKit draw options.

    Returns:
    - PIL.Image.Image: The rendered similarity map.

    Raises:
    - ValueError: If `smiles` cannot be parsed by RDKit.
    """
    print(get_taste_from_smiles(smiles))

    molecule = _molecule_from_smiles(smiles)

    # Set up canvas size based on aspect ratio
    base_size = 400
    width, height = calculate_aspect_ratio(molecule, base_size)
    drawer = Draw.MolDraw2DCairo(width, height)
    draw_options = drawer.drawOptions()
    draw_options.padding = padding

    # Optionally set black and white palette
    if bw:
        draw_options.useBWAtomPalette()

    atom_importance = _atom_importance_scores(smiles, molecule)

    # Draw the similarity map onto the Cairo canvas
    SimilarityMaps.GetSimilarityMapFromWeights(molecule, atom_importance, draw2d=drawer)

    drawer.FinishDrawing()
    return _png_to_image(drawer.GetDrawingText())


def save_high_quality_png(smiles, title, bw=True, padding=0.05):
    """
    Generates a high-resolution PNG of atom-wise importance scores for a molecule.

    The image is rendered on a 1500x1500 canvas and written to "<title>.png".

    Parameters:
    - smiles (str): The SMILES string of the molecule to visualize.
    - title (str): Output path without extension; ".png" is appended.
    - bw (bool): If True, renders the atoms in black and white.
    - padding (float): Padding passed to the RDKit draw options.

    Returns:
    - PIL.Image.Image: The rendered similarity map (same content as the file).

    Raises:
    - ValueError: If `smiles` cannot be parsed by RDKit.
    """
    molecule = _molecule_from_smiles(smiles)
    atom_importance = _atom_importance_scores(smiles, molecule)

    # Set a large canvas size for high resolution
    drawer = Draw.MolDraw2DCairo(1500, 1500)
    draw_options = drawer.drawOptions()
    draw_options.padding = padding
    draw_options.maxFontSize = 2000
    draw_options.bondLineWidth = 5

    # Optionally set black and white palette
    if bw:
        draw_options.useBWAtomPalette()

    # Draw the similarity map onto the Cairo canvas
    SimilarityMaps.GetSimilarityMapFromWeights(molecule, atom_importance, draw2d=drawer)

    # Finish once and reuse the same bytes for both the file and the return value.
    drawer.FinishDrawing()
    png_data = drawer.GetDrawingText()

    with open(f"{title}.png", "wb") as png_file:
        png_file.write(png_data)
    print(f"High-quality PNG file saved as {title}.png")

    return _png_to_image(png_data)