{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#!/usr/bin/env python\n", "# coding: utf-8" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Generative Pre-Training from Molecules" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import os\n", "#os.environ[\"CUDA_VISIBLE_DEVICES\"] = ['1',\"2\"]\n", "from pprint import pprint\n", "import sys\n", "sys.path.append('/home/jmwang/drugai/iupac-gpt')\n", "from tqdm import tqdm\n", "try:\n", " import iupac_gpt as gpt\n", "except ImportError:\n", " import sys\n", " sys.path.extend([\"..\"]) # Parent directory stores `smiles_gpt` package.\n", " import iupac_gpt as gpt\n", "import torch" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For demonstration purposes, we use only 10K subset of PubChem data made available by
\n", "[ChemBERTa](https://arxiv.org/abs/2010.09885) developers. The original model was pretrained
\n", "on the first 5M compounds with the following hyperparameters:
\n", "```python
\n", "hyperparams = {\"batch_size\": 128, \"max_epochs\": 2, \"max_length\": 512,
\n", " \"learning_rate\": 5e-4, \"weight_decay\": 0.0,
\n", " \"adam_eps\": 1e-8, \"adam_betas\": (0.9, 0.999),
\n", " \"scheduler_T_max\": 150_000, \"final_learning_rate\": 5e-8,
\n", " \"vocab_size\": 1_000, \"min_frequency\": 2, \"top_p\": 0.96,
\n", " \"n_layer\": 4, \"n_head\": 8, \"n_embd\": 512}
\n", "```
\n", "Tokenizer, model, optimizer, scheduler, and trainer hyperparameters." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "hyperparams = {\"batch_size\": 128, \"max_epochs\": 10, \"max_length\": 1280,\n", " \"learning_rate\": 5e-4, \"weight_decay\": 0.0,\n", " \"adam_eps\": 1e-8, \"adam_betas\": (0.9, 0.999),\n", " \"scheduler_T_max\": 1_000, \"final_learning_rate\": 5e-8,\n", " \"vocab_size\": 1491, \"min_frequency\": 2, \"top_p\": 0.96,\n", " \"n_layer\": 8, \"n_head\": 8, \"n_embd\": 256}" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "iupac_vocab_size: 1491\n", "training... 1491\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\r", " 0%| | 0/144537 [00:00\"` (beginning-of-SMILES) and\n", "# `\"\"` (end-of-SMILES) special tokens. `smiles_gpt.SMILESAlphabet` stores 72 possible\n", "# characters as an initial vocabulary.\n", "device = 'gpu'\n", "train_dataloader,iupac_tokenizer = gpt.get_data_loader(is_train=1,dataset_filename = './pubchem_iupac_smile_gpt.csv')\n", "pbar = tqdm(train_dataloader) #train_dataloader.cuda()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "
\n", "for inputs in pbar:
\n", " src_label = Variable(inputs[\"labels\"].to(device))
\n", " inputs = prepare_input(inputs,device)
\n", " src = Variable(inputs[\"input_ids\"].to(device))
\n", " #self.tokenizer._convert_token_to_id
\n", " print(src[:,:].shape,src_label)
\n" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "tokenizer = iupac_tokenizer\n", "#start mark 2, end mark 1, pad 0" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "iupac_string = \"2-amino-4-(2-amino-3-hydroxyphenyl)-4-oxobutanoic acid\"\n", "iupac_encoded = tokenizer(iupac_string)\n", "iupac_encoded['input_ids'] = [2]+iupac_encoded['input_ids']" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "iupac_merges = [tokenizer.decode(i) for i in iupac_encoded['input_ids']]\n", "#iupac_encoded['attention_mask']" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[2, 5, 150, 165, 150, 7, 150, 154, 5, 150, 165, 150, 6, 150, 174, 158, 153, 150, 7, 150, 166, 173, 160, 169, 198, 1]\n", "['', '2', '-', 'amino', '-', '4', '-', '(', '2', '-', 'amino', '-', '3', '-', 'hydroxy', 'phenyl', ')', '-', '4', '-', 'oxo', 'but', 'an', 'o', 'ic acid', '']\n" ] } ], "source": [ "print(iupac_encoded['input_ids'])\n", "print(iupac_merges)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2 1 1491\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/home/jmwang/drugai/iupac-gpt/iupac_gpt/iupac_dataset.py:103: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " input_ids = torch.tensor(input_ids)\n", "/home/jmwang/drugai/iupac-gpt/iupac_gpt/iupac_tokenization_iupac.py:44: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " batch[k] = pad_sequence([torch.tensor(r[k]) for r in 
records],\n", " 0%| | 0/144537 [00:12\n", "
\n", "Now we load HuggingFace
\n", "[`GPT2LMHeadModel`](https://huggingface.co/transformers/model_doc/gpt2.html#gpt2lmheadmodel)
\n", "with the configuration composed of previously
\n", "defined model hyperparameters. The model processes mini-batch of input ids and labels, then
\n", "returns predictions and cross-entropy loss between labels and predictions." ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "from transformers import GPT2Config, GPT2LMHeadModel" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "config = GPT2Config(vocab_size=tokenizer.vocab_size,\n", " bos_token_id=tokenizer.unk_token_id,\n", " eos_token_id=tokenizer.eos_token_id,\n", " n_layer=hyperparams[\"n_layer\"],\n", " n_head=hyperparams[\"n_head\"],\n", " n_embd=hyperparams[\"n_embd\"],\n", " n_positions=hyperparams[\"max_length\"],\n", " n_ctx=hyperparams[\"max_length\"])\n", "model = GPT2LMHeadModel(config)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "odel= torch.nn.DataParallel(model.cuda(),device_ids=gpus,output_device=gpus[0])" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "odict_keys(['loss', 'logits', 'past_key_values'])\n" ] } ], "source": [ "outputs = model(**batch)\n", "print(outputs.keys())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "'loss', 'logits', 'past_key_values']
\n", "## Trainer
\n", "
\n", "GPT-2 is trained with the autoregressive language modeling objective:
\n", "$$
\n", "P(\\boldsymbol{s}) = P(s_1) \\cdot P(s_2 | s_1) \\cdots P(s_T | s_1, \\ldots, s_{T-1}) =
\n", "\\prod_{t=1}^{T} P(s_t | s_{j < t}),
\n", "$$
\n", "where $\\boldsymbol{s}$ is a tokenized (encoded) IUPAC name string, $s_t$ is a token from the pretrained
\n", "vocabulary $\\mathcal{V}$.
\n", "
\n", "We use `pytorch_lightning.Trainer` to train GPT-2. Since `Trainer` requires lightning modules,
\n", "we import our
\n", "[`smiles_gpt.GPT2LitModel`](https://github.com/sanjaradylov/smiles-gpt/blob/master/smiles_gpt/language_modeling.py#L10)
\n", "wrapper that implements training phases for
\n", "`GPT2LMHeadModel`, configures an `Adam` optimizer with `CosineAnnealingLR` scheduler, and
\n", "logs average perplexity every epoch." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In[8]:" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "from pytorch_lightning import Trainer\n", "from pytorch_lightning.callbacks.early_stopping import EarlyStopping" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "checkpoint = \"../checkpoints/iupac\"" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "ename": "MisconfigurationException", "evalue": "`Trainer(strategy='ddp')` or `Trainer(accelerator='ddp')` is not compatible with an interactive environment. Run your code as a script, or choose one of the compatible backends: dp, ddp_spawn, ddp_sharded_spawn, tpu_spawn. In case you are spawning processes yourself, make sure to include the Trainer creation inside the worker function.", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mMisconfigurationException\u001b[0m Traceback (most recent call last)", "Input \u001b[0;32mIn [14]\u001b[0m, in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0m trainer \u001b[38;5;241m=\u001b[39m \u001b[43mTrainer\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 2\u001b[0m \u001b[43m \u001b[49m\u001b[43mgpus\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mgpus\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3\u001b[0m \u001b[43m \u001b[49m\u001b[43mmax_epochs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mhyperparams\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mmax_epochs\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 4\u001b[0m \u001b[43m 
\u001b[49m\u001b[43mcallbacks\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m[\u001b[49m\u001b[43mEarlyStopping\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mppl\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m0.1\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m3\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m#[EarlyStopping(\"ppl\", 0.2, 2)]\u001b[39;49;00m\n\u001b[1;32m 5\u001b[0m \u001b[43m \u001b[49m\u001b[43mauto_lr_find\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m# Set to True to search for optimal learning rate.\u001b[39;49;00m\n\u001b[1;32m 6\u001b[0m \u001b[43m \u001b[49m\u001b[43mauto_scale_batch_size\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m# Set to True to scale batch size\u001b[39;49;00m\n\u001b[1;32m 7\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;66;43;03m# accelerator=\"dp\" # Uncomment for GPU training.\u001b[39;49;00m\n\u001b[1;32m 8\u001b[0m \u001b[43m \u001b[49m\u001b[43maccelerator\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mgpu\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m#devices=4,\u001b[39;49;00m\n\u001b[1;32m 9\u001b[0m \u001b[43m \u001b[49m\u001b[43mstrategy\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mddp\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\n\u001b[1;32m 10\u001b[0m \u001b[43m)\u001b[49m\n\u001b[1;32m 11\u001b[0m lit_model \u001b[38;5;241m=\u001b[39m gpt\u001b[38;5;241m.\u001b[39mGPT2LitModel(\n\u001b[1;32m 12\u001b[0m model,\n\u001b[1;32m 13\u001b[0m 
batch_size\u001b[38;5;241m=\u001b[39mhyperparams[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mbatch_size\u001b[39m\u001b[38;5;124m\"\u001b[39m],\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 19\u001b[0m scheduler_T_max\u001b[38;5;241m=\u001b[39mhyperparams[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mscheduler_T_max\u001b[39m\u001b[38;5;124m\"\u001b[39m],\n\u001b[1;32m 20\u001b[0m save_model_every\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1\u001b[39m, checkpoint\u001b[38;5;241m=\u001b[39mcheckpoint)\n\u001b[1;32m 21\u001b[0m trainer\u001b[38;5;241m.\u001b[39mfit(lit_model, train_dataloader)\n", "File \u001b[0;32m~/anaconda3/envs/smiles-gpt/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/env_vars_connector.py:38\u001b[0m, in \u001b[0;36m_defaults_from_env_vars..insert_env_defaults\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 35\u001b[0m kwargs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mdict\u001b[39m(\u001b[38;5;28mlist\u001b[39m(env_variables\u001b[38;5;241m.\u001b[39mitems()) \u001b[38;5;241m+\u001b[39m \u001b[38;5;28mlist\u001b[39m(kwargs\u001b[38;5;241m.\u001b[39mitems()))\n\u001b[1;32m 37\u001b[0m \u001b[38;5;66;03m# all args were already moved to kwargs\u001b[39;00m\n\u001b[0;32m---> 38\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", "File \u001b[0;32m~/anaconda3/envs/smiles-gpt/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py:431\u001b[0m, in \u001b[0;36mTrainer.__init__\u001b[0;34m(self, logger, checkpoint_callback, enable_checkpointing, callbacks, default_root_dir, gradient_clip_val, gradient_clip_algorithm, process_position, num_nodes, num_processes, devices, gpus, auto_select_gpus, tpu_cores, ipus, log_gpu_memory, progress_bar_refresh_rate, enable_progress_bar, overfit_batches, 
track_grad_norm, check_val_every_n_epoch, fast_dev_run, accumulate_grad_batches, max_epochs, min_epochs, max_steps, min_steps, max_time, limit_train_batches, limit_val_batches, limit_test_batches, limit_predict_batches, val_check_interval, flush_logs_every_n_steps, log_every_n_steps, accelerator, strategy, sync_batchnorm, precision, enable_model_summary, weights_summary, weights_save_path, num_sanity_val_steps, resume_from_checkpoint, profiler, benchmark, deterministic, reload_dataloaders_every_n_epochs, reload_dataloaders_every_epoch, auto_lr_find, replace_sampler_ddp, detect_anomaly, auto_scale_batch_size, prepare_data_per_node, plugins, amp_backend, amp_level, move_metrics_to_cpu, multiple_trainloader_mode, stochastic_weight_avg, terminate_on_nan)\u001b[0m\n\u001b[1;32m 428\u001b[0m \u001b[38;5;66;03m# init connectors\u001b[39;00m\n\u001b[1;32m 429\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_data_connector \u001b[38;5;241m=\u001b[39m DataConnector(\u001b[38;5;28mself\u001b[39m, multiple_trainloader_mode)\n\u001b[0;32m--> 431\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_accelerator_connector \u001b[38;5;241m=\u001b[39m \u001b[43mAcceleratorConnector\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 432\u001b[0m \u001b[43m \u001b[49m\u001b[43mnum_processes\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 433\u001b[0m \u001b[43m \u001b[49m\u001b[43mdevices\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 434\u001b[0m \u001b[43m \u001b[49m\u001b[43mtpu_cores\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 435\u001b[0m \u001b[43m \u001b[49m\u001b[43mipus\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 436\u001b[0m \u001b[43m \u001b[49m\u001b[43maccelerator\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 437\u001b[0m \u001b[43m \u001b[49m\u001b[43mstrategy\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 438\u001b[0m \u001b[43m \u001b[49m\u001b[43mgpus\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 439\u001b[0m \u001b[43m 
\u001b[49m\u001b[43mgpu_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 440\u001b[0m \u001b[43m \u001b[49m\u001b[43mnum_nodes\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 441\u001b[0m \u001b[43m \u001b[49m\u001b[43msync_batchnorm\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 442\u001b[0m \u001b[43m \u001b[49m\u001b[43mbenchmark\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 443\u001b[0m \u001b[43m \u001b[49m\u001b[43mreplace_sampler_ddp\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 444\u001b[0m \u001b[43m \u001b[49m\u001b[43mdeterministic\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 445\u001b[0m \u001b[43m \u001b[49m\u001b[43mprecision\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 446\u001b[0m \u001b[43m \u001b[49m\u001b[43mamp_backend\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 447\u001b[0m \u001b[43m \u001b[49m\u001b[43mamp_level\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 448\u001b[0m \u001b[43m \u001b[49m\u001b[43mplugins\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 449\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 450\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlogger_connector \u001b[38;5;241m=\u001b[39m LoggerConnector(\u001b[38;5;28mself\u001b[39m, log_gpu_memory)\n\u001b[1;32m 451\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_callback_connector \u001b[38;5;241m=\u001b[39m CallbackConnector(\u001b[38;5;28mself\u001b[39m)\n", "File \u001b[0;32m~/anaconda3/envs/smiles-gpt/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/accelerator_connector.py:164\u001b[0m, in \u001b[0;36mAcceleratorConnector.__init__\u001b[0;34m(self, num_processes, devices, tpu_cores, ipus, accelerator, strategy, gpus, gpu_ids, num_nodes, sync_batchnorm, benchmark, replace_sampler_ddp, deterministic, precision, amp_type, amp_level, plugins)\u001b[0m\n\u001b[1;32m 161\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mselect_accelerator_type()\n\u001b[1;32m 163\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m 
\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstrategy \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m--> 164\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_set_training_type_plugin\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 165\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 166\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mset_distributed_mode()\n", "File \u001b[0;32m~/anaconda3/envs/smiles-gpt/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/accelerator_connector.py:311\u001b[0m, in \u001b[0;36mAcceleratorConnector._set_training_type_plugin\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 309\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_training_type_plugin \u001b[38;5;241m=\u001b[39m TrainingTypePluginsRegistry\u001b[38;5;241m.\u001b[39mget(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstrategy)\n\u001b[1;32m 310\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstrategy, \u001b[38;5;28mstr\u001b[39m):\n\u001b[0;32m--> 311\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mset_distributed_mode\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstrategy\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 312\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstrategy, TrainingTypePlugin):\n\u001b[1;32m 313\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_training_type_plugin \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstrategy\n", "File 
\u001b[0;32m~/anaconda3/envs/smiles-gpt/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/accelerator_connector.py:902\u001b[0m, in \u001b[0;36mAcceleratorConnector.set_distributed_mode\u001b[0;34m(self, strategy)\u001b[0m\n\u001b[1;32m 899\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_distrib_type \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 901\u001b[0m \u001b[38;5;66;03m# finished configuring self._distrib_type, check ipython environment\u001b[39;00m\n\u001b[0;32m--> 902\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcheck_interactive_compatibility\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 904\u001b[0m \u001b[38;5;66;03m# for DDP overwrite nb processes by requested GPUs\u001b[39;00m\n\u001b[1;32m 905\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_device_type \u001b[38;5;241m==\u001b[39m DeviceType\u001b[38;5;241m.\u001b[39mGPU \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_distrib_type \u001b[38;5;129;01min\u001b[39;00m (\n\u001b[1;32m 906\u001b[0m DistributedType\u001b[38;5;241m.\u001b[39mDDP,\n\u001b[1;32m 907\u001b[0m DistributedType\u001b[38;5;241m.\u001b[39mDDP_SPAWN,\n\u001b[1;32m 908\u001b[0m ):\n", "File \u001b[0;32m~/anaconda3/envs/smiles-gpt/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/accelerator_connector.py:943\u001b[0m, in \u001b[0;36mAcceleratorConnector.check_interactive_compatibility\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 940\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mpytorch_lightning\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mutilities\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m _IS_INTERACTIVE\n\u001b[1;32m 942\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m _IS_INTERACTIVE \u001b[38;5;129;01mand\u001b[39;00m 
\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_distrib_type \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_distrib_type\u001b[38;5;241m.\u001b[39mis_interactive_compatible():\n\u001b[0;32m--> 943\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m MisconfigurationException(\n\u001b[1;32m 944\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m`Trainer(strategy=\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_distrib_type\u001b[38;5;241m.\u001b[39mvalue\u001b[38;5;132;01m!r}\u001b[39;00m\u001b[38;5;124m)` or\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 945\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m `Trainer(accelerator=\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_distrib_type\u001b[38;5;241m.\u001b[39mvalue\u001b[38;5;132;01m!r}\u001b[39;00m\u001b[38;5;124m)` is not compatible with an interactive\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 946\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m environment. 
Run your code as a script, or choose one of the compatible backends:\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 947\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m, \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;241m.\u001b[39mjoin(DistributedType\u001b[38;5;241m.\u001b[39minteractive_compatible_types())\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 948\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m In case you are spawning processes yourself, make sure to include the Trainer\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 949\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m creation inside the worker function.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 950\u001b[0m )\n", "\u001b[0;31mMisconfigurationException\u001b[0m: `Trainer(strategy='ddp')` or `Trainer(accelerator='ddp')` is not compatible with an interactive environment. Run your code as a script, or choose one of the compatible backends: dp, ddp_spawn, ddp_sharded_spawn, tpu_spawn. In case you are spawning processes yourself, make sure to include the Trainer creation inside the worker function." 
] } ], "source": [ "trainer = Trainer(\n", " gpus=gpus,\n", " max_epochs=hyperparams[\"max_epochs\"],\n", " callbacks=[EarlyStopping(\"ppl\", 0.1, 3)], #[EarlyStopping(\"ppl\", 0.2, 2)]\n", " auto_lr_find=False, # Set to True to search for optimal learning rate.\n", " auto_scale_batch_size=False, # Set to True to scale batch size\n", " # accelerator=\"dp\" # Uncomment for GPU training.\n", " accelerator=\"gpu\", #devices=4,\n", " strategy=\"dp\"\n", ")\n", "lit_model = gpt.GPT2LitModel(\n", " model,\n", " batch_size=hyperparams[\"batch_size\"],\n", " learning_rate=hyperparams[\"learning_rate\"],\n", " final_learning_rate=hyperparams[\"final_learning_rate\"],\n", " weight_decay=hyperparams[\"weight_decay\"],\n", " adam_eps=hyperparams[\"adam_eps\"],\n", " adam_betas=hyperparams[\"adam_betas\"],\n", " scheduler_T_max=hyperparams[\"scheduler_T_max\"],\n", " save_model_every=1, checkpoint=checkpoint)\n", "trainer.fit(lit_model, train_dataloader)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "odel.module.save_pretrained('./pretrained')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "model.save_pretrained('./pretrained')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Interpretability
\n", "
\n", "[BertViz](https://github.com/jessevig/bertviz) inspects attention heads of transformers
\n", "capturing specific patterns in data. Each head can be representative of some syntactic
\n", "or short-/long-term relationships between tokens." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In[9]:" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "('attention', 'tokens', 'sentence_b_start', 'prettify_tokens', 'layer', 'heads', 'encoder_attention', 'decoder_attention', 'cross_attention', 'encoder_tokens', 'decoder_tokens', 'include_layers', 'html_action', 'slice_a', 'slice_b', 'vis_id', 'options', 'select_html', 'vis_html', 'd', 'attn_seq_len_left', 'attn_seq_len_right', 'params', '__location__', 'vis_js', 'html1', 'html2', 'html3', 'script', 'head_html')\n" ] } ], "source": [ "import torch\n", "from bertviz import head_view\n", "print(head_view.__code__.co_varnames)" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "data": { "text/html": [ " \n", "
\n", " \n", " Layer: \n", " \n", " \n", "
\n", "
\n", " \n", "\n" ], "text/plain": [ "" ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" } ], "source": [ "input_ids_list = iupac_encoded['input_ids']\n", "model = GPT2LMHeadModel.from_pretrained(checkpoint, output_attentions=True)\n", "attention = model(torch.LongTensor(input_ids_list[1:-1]))[-1]\n", "tokens = [tokenizer.decode(i) for i in input_ids_list]\n", "#print(input_ids_list,attention,tokens)\n", "# Don't worry if a snippet is not displayed---just rerun this cell.\n", "\n", "a=head_view(attention = attention, tokens=tokens[1:-1],html_action='return')\n", "\n", "with open(\"iupac_head_view.html\", 'w') as file:\n", " file.write(a.data)\n", "a" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "('attention', 'tokens', 'sentence_b_start', 'prettify_tokens', 'display_mode', 'encoder_attention', 'decoder_attention', 'cross_attention', 'encoder_tokens', 'decoder_tokens', 'include_layers', 'include_heads', 'html_action', 'n_heads', 'slice_a', 'slice_b', 'vis_id', 'options', 'select_html', 'vis_html', 'd', 'attn_seq_len_left', 'attn_seq_len_right', 'params', '__location__', 'vis_js', 'html1', 'html2', 'html3', 'script', 'head_html')\n" ] } ], "source": [ "from bertviz import model_view\n", "print(model_view.__code__.co_varnames)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Don't worry if a snippet is not displayed---just rerun this cell." ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "data": { "text/html": [ " \n", "
\n", " \n", " \n", " \n", "
\n", "
\n", " \n", "\n" ], "text/plain": [ "" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "a=model_view(attention, tokens[1:-1],html_action='return')\n", "\n", "with open(\"iupac_model_view.html\", 'w') as file:\n", " file.write(a.data)\n", "a" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Sampling
\n", "
\n", "Finally, we generate novel IUPAC name strings with top-$p$ sampling, i.e., sampling from the
\n", "smallest vocabulary subset $\\mathcal{V}^{(p)} \\subset \\mathcal{V}$ such that it contains the most
\n", "probable tokens whose cumulative probability mass exceeds $p$, $0 < p < 1$. The model
\n", "terminates the procedure upon encountering the end-of-sequence token or reaching the maximum length
\n", "`hyperparams[\"max_length\"]`. Special tokens are eventually removed." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import tqdm" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "model.eval() # Set the base model to evaluation mode." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "generated_smiles_list = []\n", "n_generated = 30000" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "for _ in tqdm.tqdm(range(n_generated)):\n", " # Generate from \"\" so that the next token is arbitrary.\n", " smiles_start = torch.LongTensor([[tokenizer.unk_token_id]])\n", " # Get generated token IDs.\n", " generated_ids = model.generate(smiles_start,\n", " max_length=hyperparams[\"max_length\"],\n", " do_sample=True,top_p=hyperparams[\"top_p\"],\n", " repetition_penalty=1.2,\n", " pad_token_id=tokenizer.eos_token_id)\n", " # Decode the IDs into tokens and remove \"\" and \"\".\n", " generated_smiles = tokenizer.decode(generated_ids[0],\n", " skip_special_tokens=True)\n", " generated_smiles_list.append(generated_smiles)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "print(generated_smiles_list[:10])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import pandas as pd" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "df2 = pd.DataFrame(generated_smiles_list, columns=['iupac']) " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "df2.to_csv(\"iupacGPT2-gen30K.csv\",index=None,mode='a')" ] } ], "metadata": { "kernelspec": { "display_name": "smiles-gpt", "language": "python", "name": "smiles-gpt" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, 
"file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.13" } }, "nbformat": 4, "nbformat_minor": 2 }