diff --git "a/examples/gene_classification.ipynb" "b/examples/gene_classification.ipynb"
deleted file mode 100644
--- "a/examples/gene_classification.ipynb"
+++ /dev/null
@@ -1,1251 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "markdown",
- "id": "08f41458-5304-48c5-9e92-f9b56ab052c4",
- "metadata": {},
- "source": [
- "## Geneformer Fine-Tuning for Classification of Dosage-Sensitive vs. -Insensitive Transcription Factors (TFs)"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "79539e95-2c9c-4162-835c-f0d158abb15d",
- "metadata": {},
- "source": [
- "### Please note that, as usual with deep learning models, we **highly** recommend tuning learning hyperparameters for all fine-tuning applications, as this can significantly improve model performance. The example below uses default hyperparameters; please see the \"hyperparam_optimiz_for_disease_classifier\" script for an example of how to tune hyperparameters for downstream applications."
- ]
- },
- {
- "cell_type": "markdown",
- "id": "51b4852a-9f03-4bc3-ba33-79eaa4582d50",
- "metadata": {},
- "source": [
- "### Train gene classifier with 5-fold cross-validation:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 1,
- "id": "58d59e09-5e6c-4fba-ba2b-3aee103869fd",
- "metadata": {},
- "outputs": [],
- "source": [
- "import datetime\n",
- "import pickle\n",
- "from geneformer import Classifier\n",
- "\n",
- "current_date = datetime.datetime.now()\n",
- "datestamp = f\"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}{current_date.hour:02d}{current_date.minute:02d}{current_date.second:02d}\"\n",
- "datestamp_min = f\"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}\"\n",
- "\n",
- "output_prefix = \"tf_dosage_sens_test\"\n",
- "output_dir = f\"/path/to/output_dir/{datestamp}\"\n",
- "!mkdir $output_dir"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "id": "9e33942f-39e4-4db4-a3de-5949bed9fa5d",
- "metadata": {},
- "outputs": [],
- "source": [
- "# Example input_data_file: https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/blob/main/example_input_files/gene_classification/dosage_sensitive_tfs/dosage_sensitivity_TFs.pickle\n",
- "with open(\"/path/to/dosage_sensitivity_TFs.pickle\", \"rb\") as fp:\n",
- " gene_class_dict = pickle.load(fp)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 3,
- "id": "f4053ee9-3506-4c97-b544-8d667f0adfab",
- "metadata": {},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "Hyperparameter tuning is highly recommended for optimal results. No training_args provided; using default hyperparameters.\n"
- ]
- }
- ],
- "source": [
- "# OF NOTE: token_dictionary_file must be set to the gc-30M token dictionary if using a 30M series model\n",
- "# (otherwise the Classifier will use the current default model dictionary)\n",
- "# 30M token dictionary: https://huggingface.co/ctheodoris/Geneformer/blob/main/geneformer/gene_dictionaries_30m/token_dictionary_gc30M.pkl\n",
- "cc = Classifier(classifier=\"gene\",\n",
- " gene_class_dict = gene_class_dict,\n",
- " max_ncells = 10_000,\n",
- " freeze_layers = 4,\n",
- " num_crossval_splits = 5,\n",
- " forward_batch_size=200,\n",
- " nproc=16)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 4,
- "id": "e4855e53-1cd7-4af0-b786-02b6c0e55f8c",
- "metadata": {},
- "outputs": [
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "6a3f7bcf2a314368b00f49c74a775571",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "Saving the dataset (0/1 shards): 0%| | 0/33558 [00:00, ? examples/s]"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- }
- ],
- "source": [
- "# Example input_data_file for 30M model series: https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/tree/main/example_input_files/gene_classification/dosage_sensitive_tfs/gc-30M_sample50k.dataset\n",
- "cc.prepare_data(input_data_file=\"/path/to/gc-30M_sample50k.dataset\",\n",
- " output_directory=output_dir,\n",
- " output_prefix=output_prefix)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 5,
- "id": "3a1b2c1b-36e0-4b92-9e1d-33b3328f41a2",
- "metadata": {},
- "outputs": [
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "8ca67c71086d44f480fa241da36139ce",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "0it [00:00, ?it/s]"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "****** Validation split: 1/5 ******\n",
- "\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "7059e063e1f3481eabf425ba82150079",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "Filter (num_proc=16): 0%| | 0/33558 [00:00, ? examples/s]"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "7c2c27cd9ecd44f7b2d7c7fafd587cfb",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "Filter (num_proc=16): 0%| | 0/33558 [00:00, ? examples/s]"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "ed07a7fd5d4046869f1983250134cef0",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "Map (num_proc=16): 0%| | 0/10000 [00:00, ? examples/s]"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "430c9612c7984be8ba367e96e0ed5de3",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "Map (num_proc=16): 0%| | 0/10000 [00:00, ? examples/s]"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "Some weights of BertForTokenClassification were not initialized from the model checkpoint at /gladstone/theodoris/home/ctheodoris/Geneformer and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
- "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
- "Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
- "/gladstone/theodoris/home/ctheodoris/Geneformer/geneformer/collator_for_classification.py:581: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
- " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n"
- ]
- },
- {
- "data": {
- "text/html": [
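- "\n",
- "    <div>\n",
- "      \n",
- "      <progress value='834' max='834' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
- "      [834/834 02:37, Epoch 1/1]\n",
- "    </div>\n",
- "    <table border=\"1\" class=\"dataframe\">\n",
- "  <thead>\n",
- " <tr style=\"text-align: left;\">\n",
- "      <th>Step</th>\n",
- "      <th>Training Loss</th>\n",
- "    </tr>\n",
- "  </thead>\n",
- "  <tbody>\n",
- "    <tr>\n",
- "      <td>83</td>\n",
- "      <td>0.729100</td>\n",
- "    </tr>\n",
- "    <tr>\n",
- "      <td>166</td>\n",
- "      <td>0.667600</td>\n",
- "    </tr>\n",
- "    <tr>\n",
- "      <td>249</td>\n",
- "      <td>0.553100</td>\n",
- "    </tr>\n",
- "    <tr>\n",
- "      <td>332</td>\n",
- "      <td>0.409100</td>\n",
- "    </tr>\n",
- "    <tr>\n",
- "      <td>415</td>\n",
- "      <td>0.294300</td>\n",
- "    </tr>\n",
- "    <tr>\n",
- "      <td>498</td>\n",
- "      <td>0.197000</td>\n",
- "    </tr>\n",
- "    <tr>\n",
- "      <td>581</td>\n",
- "      <td>0.138300</td>\n",
- "    </tr>\n",
- "    <tr>\n",
- "      <td>664</td>\n",
- "      <td>0.099900</td>\n",
- "    </tr>\n",
- "    <tr>\n",
- "      <td>747</td>\n",
- "      <td>0.083700</td>\n",
- "    </tr>\n",
- "    <tr>\n",
- "      <td>830</td>\n",
- "      <td>0.072300</td>\n",
- "    </tr>\n",
- "  </tbody>\n",
- "</table><p>"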
- ],
- "text/plain": [
- ""
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "****** Validation split: 2/5 ******\n",
- "\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "d186836393d84c19b9c0dffafb31a09c",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "Filter (num_proc=16): 0%| | 0/33558 [00:00, ? examples/s]"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "26cb17f7d5b7440192ed7ada0070fa7d",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "Filter (num_proc=16): 0%| | 0/33558 [00:00, ? examples/s]"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "9259ec2ea7db4203bb361b1b7de3773f",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "Map (num_proc=16): 0%| | 0/10000 [00:00, ? examples/s]"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "ee45940e40b94f83a431fc1bd8081b8b",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "Map (num_proc=16): 0%| | 0/10000 [00:00, ? examples/s]"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "Some weights of BertForTokenClassification were not initialized from the model checkpoint at /gladstone/theodoris/home/ctheodoris/Geneformer and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
- "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
- "Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
- "/gladstone/theodoris/home/ctheodoris/Geneformer/geneformer/collator_for_classification.py:581: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
- " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n"
- ]
- },
- {
- "data": {
- "text/html": [
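- "\n",
- "    <div>\n",
- "      \n",
- "      <progress value='834' max='834' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
- "      [834/834 02:34, Epoch 1/1]\n",
- "    </div>\n",
- "    <table border=\"1\" class=\"dataframe\">\n",
- "  <thead>\n",
- " <tr style=\"text-align: left;\">\n",
- "      <th>Step</th>\n",
- "      <th>Training Loss</th>\n",
- "    </tr>\n",
- "  </thead>\n",
- "  <tbody>\n",
- "    <tr>\n",
- "      <td>83</td>\n",
- "      <td>0.695400</td>\n",
- "    </tr>\n",
- "    <tr>\n",
- "      <td>166</td>\n",
- "      <td>0.634600</td>\n",
- "    </tr>\n",
- "    <tr>\n",
- "      <td>249</td>\n",
- "      <td>0.540200</td>\n",
- "    </tr>\n",
- "    <tr>\n",
- "      <td>332</td>\n",
- "      <td>0.414800</td>\n",
- "    </tr>\n",
- "    <tr>\n",
- "      <td>415</td>\n",
- "      <td>0.298500</td>\n",
- "    </tr>\n",
- "    <tr>\n",
- "      <td>498</td>\n",
- "      <td>0.199100</td>\n",
- "    </tr>\n",
- "    <tr>\n",
- "      <td>581</td>\n",
- "      <td>0.133200</td>\n",
- "    </tr>\n",
- "    <tr>\n",
- "      <td>664</td>\n",
- "      <td>0.096300</td>\n",
- "    </tr>\n",
- "    <tr>\n",
- "      <td>747</td>\n",
- "      <td>0.078100</td>\n",
- "    </tr>\n",
- "    <tr>\n",
- "      <td>830</td>\n",
- "      <td>0.068100</td>\n",
- "    </tr>\n",
- "  </tbody>\n",
- "</table><p>"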
- ],
- "text/plain": [
- ""
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "****** Validation split: 3/5 ******\n",
- "\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "93e9c12bc6e243b39224994add37ce21",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "Filter (num_proc=16): 0%| | 0/33558 [00:00, ? examples/s]"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "dc429098c2a14f00be1e5921cde897dc",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "Filter (num_proc=16): 0%| | 0/33558 [00:00, ? examples/s]"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "fcd68a1b1f114cc7aeafb8f1fc3a26d0",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "Map (num_proc=16): 0%| | 0/10000 [00:00, ? examples/s]"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "51d522a3d10142518022f7a024cc0398",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "Map (num_proc=16): 0%| | 0/10000 [00:00, ? examples/s]"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "Some weights of BertForTokenClassification were not initialized from the model checkpoint at /gladstone/theodoris/home/ctheodoris/Geneformer and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
- "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
- "Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
- "/gladstone/theodoris/home/ctheodoris/Geneformer/geneformer/collator_for_classification.py:581: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
- " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n"
- ]
- },
- {
- "data": {
- "text/html": [
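- "\n",
- "    <div>\n",
- "      \n",
- "      <progress value='834' max='834' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
- "      [834/834 02:35, Epoch 1/1]\n",
- "    </div>\n",
- "    <table border=\"1\" class=\"dataframe\">\n",
- "  <thead>\n",
- " <tr style=\"text-align: left;\">\n",
- "      <th>Step</th>\n",
- "      <th>Training Loss</th>\n",
- "    </tr>\n",
- "  </thead>\n",
- "  <tbody>\n",
- "    <tr>\n",
- "      <td>83</td>\n",
- "      <td>0.708600</td>\n",
- "    </tr>\n",
- "    <tr>\n",
- "      <td>166</td>\n",
- "      <td>0.656300</td>\n",
- "    </tr>\n",
- "    <tr>\n",
- "      <td>249</td>\n",
- "      <td>0.553600</td>\n",
- "    </tr>\n",
- "    <tr>\n",
- "      <td>332</td>\n",
- "      <td>0.430600</td>\n",
- "    </tr>\n",
- "    <tr>\n",
- "      <td>415</td>\n",
- "      <td>0.300000</td>\n",
- "    </tr>\n",
- "    <tr>\n",
- "      <td>498</td>\n",
- "      <td>0.202900</td>\n",
- "    </tr>\n",
- "    <tr>\n",
- "      <td>581</td>\n",
- "      <td>0.144700</td>\n",
- "    </tr>\n",
- "    <tr>\n",
- "      <td>664</td>\n",
- "      <td>0.109900</td>\n",
- "    </tr>\n",
- "    <tr>\n",
- "      <td>747</td>\n",
- "      <td>0.096000</td>\n",
- "    </tr>\n",
- "    <tr>\n",
- "      <td>830</td>\n",
- "      <td>0.086700</td>\n",
- "    </tr>\n",
- "  </tbody>\n",
- "</table><p>"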
- ],
- "text/plain": [
- ""
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "****** Validation split: 4/5 ******\n",
- "\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "1a9cebe980534274907ae3858a706c37",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "Filter (num_proc=16): 0%| | 0/33558 [00:00, ? examples/s]"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "b5596a2de83f477facdcf484c4b3cfc9",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "Filter (num_proc=16): 0%| | 0/33558 [00:00, ? examples/s]"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "7e3be2a6e2084240b6f657964466ccf2",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "Map (num_proc=16): 0%| | 0/10000 [00:00, ? examples/s]"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "623bdd4b2755406396eb21332db13c3a",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "Map (num_proc=16): 0%| | 0/10000 [00:00, ? examples/s]"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "Some weights of BertForTokenClassification were not initialized from the model checkpoint at /gladstone/theodoris/home/ctheodoris/Geneformer and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
- "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
- "Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
- "/gladstone/theodoris/home/ctheodoris/Geneformer/geneformer/collator_for_classification.py:581: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
- " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n"
- ]
- },
- {
- "data": {
- "text/html": [
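- "\n",
- "    <div>\n",
- "      \n",
- "      <progress value='834' max='834' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
- "      [834/834 02:35, Epoch 1/1]\n",
- "    </div>\n",
- "    <table border=\"1\" class=\"dataframe\">\n",
- "  <thead>\n",
- " <tr style=\"text-align: left;\">\n",
- "      <th>Step</th>\n",
- "      <th>Training Loss</th>\n",
- "    </tr>\n",
- "  </thead>\n",
- "  <tbody>\n",
- "    <tr>\n",
- "      <td>83</td>\n",
- "      <td>0.697500</td>\n",
- "    </tr>\n",
- "    <tr>\n",
- "      <td>166</td>\n",
- "      <td>0.632000</td>\n",
- "    </tr>\n",
- "    <tr>\n",
- "      <td>249</td>\n",
- "      <td>0.524600</td>\n",
- "    </tr>\n",
- "    <tr>\n",
- "      <td>332</td>\n",
- "      <td>0.394300</td>\n",
- "    </tr>\n",
- "    <tr>\n",
- "      <td>415</td>\n",
- "      <td>0.264700</td>\n",
- "    </tr>\n",
- "    <tr>\n",
- "      <td>498</td>\n",
- "      <td>0.180100</td>\n",
- "    </tr>\n",
- "    <tr>\n",
- "      <td>581</td>\n",
- "      <td>0.128300</td>\n",
- "    </tr>\n",
- "    <tr>\n",
- "      <td>664</td>\n",
- "      <td>0.094200</td>\n",
- "    </tr>\n",
- "    <tr>\n",
- "      <td>747</td>\n",
- "      <td>0.082200</td>\n",
- "    </tr>\n",
- "    <tr>\n",
- "      <td>830</td>\n",
- "      <td>0.078500</td>\n",
- "    </tr>\n",
- "  </tbody>\n",
- "</table><p>"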
- ],
- "text/plain": [
- ""
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "****** Validation split: 5/5 ******\n",
- "\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "455067153dc145cba4e3cfdc63f129cc",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "Filter (num_proc=16): 0%| | 0/33558 [00:00, ? examples/s]"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "e664d317146c4edea145e2d80efd6960",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "Filter (num_proc=16): 0%| | 0/33558 [00:00, ? examples/s]"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "30e3c07d77d14bb3a483ad48570d8b3a",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "Map (num_proc=16): 0%| | 0/10000 [00:00, ? examples/s]"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "fa63cf84534a435c938de556aba7da5c",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "Map (num_proc=16): 0%| | 0/10000 [00:00, ? examples/s]"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "Some weights of BertForTokenClassification were not initialized from the model checkpoint at /gladstone/theodoris/home/ctheodoris/Geneformer and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
- "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
- "Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
- "/gladstone/theodoris/home/ctheodoris/Geneformer/geneformer/collator_for_classification.py:581: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
- " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n"
- ]
- },
- {
- "data": {
- "text/html": [
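- "\n",
- "    <div>\n",
- "      \n",
- "      <progress value='834' max='834' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
- "      [834/834 02:35, Epoch 1/1]\n",
- "    </div>\n",
- "    <table border=\"1\" class=\"dataframe\">\n",
- "  <thead>\n",
- " <tr style=\"text-align: left;\">\n",
- "      <th>Step</th>\n",
- "      <th>Training Loss</th>\n",
- "    </tr>\n",
- "  </thead>\n",
- "  <tbody>\n",
- "    <tr>\n",
- "      <td>83</td>\n",
- "      <td>0.711400</td>\n",
- "    </tr>\n",
- "    <tr>\n",
- "      <td>166</td>\n",
- "      <td>0.644000</td>\n",
- "    </tr>\n",
- "    <tr>\n",
- "      <td>249</td>\n",
- "      <td>0.535900</td>\n",
- "    </tr>\n",
- "    <tr>\n",
- "      <td>332</td>\n",
- "      <td>0.395400</td>\n",
- "    </tr>\n",
- "    <tr>\n",
- "      <td>415</td>\n",
- "      <td>0.275400</td>\n",
- "    </tr>\n",
- "    <tr>\n",
- "      <td>498</td>\n",
- "      <td>0.193600</td>\n",
- "    </tr>\n",
- "    <tr>\n",
- "      <td>581</td>\n",
- "      <td>0.129300</td>\n",
- "    </tr>\n",
- "    <tr>\n",
- "      <td>664</td>\n",
- "      <td>0.093300</td>\n",
- "    </tr>\n",
- "    <tr>\n",
- "      <td>747</td>\n",
- "      <td>0.070000</td>\n",
- "    </tr>\n",
- "    <tr>\n",
- "      <td>830</td>\n",
- "      <td>0.067100</td>\n",
- "    </tr>\n",
- "  </tbody>\n",
- "</table><p>"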
- ],
- "text/plain": [
- ""
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- }
- ],
- "source": [
- "# 6 layer 30M Geneformer model: https://huggingface.co/ctheodoris/Geneformer/blob/main/gf-6L-30M-i2048/model.safetensors\n",
- "all_metrics = cc.validate(model_directory=\"/path/to/Geneformer\",\n",
- " prepared_input_data_file=f\"{output_dir}/{output_prefix}_labeled.dataset\",\n",
- " id_class_dict_file=f\"{output_dir}/{output_prefix}_id_class_dict.pkl\",\n",
- " output_directory=output_dir,\n",
- " output_prefix=output_prefix)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 6,
- "id": "11a1329b-4968-45f3-ac7a-2438b574404e",
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- ""
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "image/png": "",
- "text/plain": [
- ""
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "text/plain": [
- ""
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- }
- ],
- "source": [
- "cc.plot_conf_mat(\n",
- " conf_mat_dict={\"Geneformer\": all_metrics[\"conf_matrix\"]},\n",
- " output_directory=output_dir,\n",
- " output_prefix=output_prefix,\n",
- ")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 7,
- "id": "edf6ffd9-8b84-4d31-8b39-11959140382f",
- "metadata": {},
- "outputs": [
- {
- "data": {
- "image/png": "",
- "text/plain": [
- ""
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "text/plain": [
- ""
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- }
- ],
- "source": [
- "cc.plot_roc(\n",
- " roc_metric_dict={\"Geneformer\": all_metrics[\"all_roc_metrics\"]},\n",
- " model_style_dict={\"Geneformer\": {\"color\": \"red\", \"linestyle\": \"-\"}},\n",
- " title=\"Dosage-sensitive vs -insensitive factors\",\n",
- " output_directory=output_dir,\n",
- " output_prefix=output_prefix,\n",
- ")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 8,
- "id": "d10ac27f-8d70-400e-8a00-d0b84c1d02b4",
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "{'conf_matrix': Dosage-sensitive TFs Dosage-insensitive TFs\n",
- " Dosage-sensitive TFs 61229.0 14801.0\n",
- " Dosage-insensitive TFs 9094.0 73907.0,\n",
- " 'macro_f1': [0.8489695337205987,\n",
- " 0.8637730998133415,\n",
- " 0.9122635701525341,\n",
- " 0.8180200155972593,\n",
- " 0.7913574275548942],\n",
- " 'acc': [0.8544562281799618,\n",
- " 0.8647275498539312,\n",
- " 0.9122812348079727,\n",
- " 0.8182044035899506,\n",
- " 0.798060129740519],\n",
- " 'all_roc_metrics': {'mean_tpr': array([0. , 0.29330305, 0.39824459, 0.48477052, 0.53910681,\n",
- " 0.58654819, 0.62233428, 0.65499297, 0.68383714, 0.7105218 ,\n",
- " 0.7331015 , 0.75404762, 0.77191402, 0.79007262, 0.80530801,\n",
- " 0.81812243, 0.83182971, 0.84348565, 0.85308334, 0.86179954,\n",
- " 0.87018186, 0.87841599, 0.88666193, 0.89398957, 0.90104605,\n",
- " 0.90768847, 0.91468381, 0.92081589, 0.92687436, 0.93170239,\n",
- " 0.93600138, 0.93963402, 0.9430781 , 0.94641134, 0.94881205,\n",
- " 0.95143243, 0.95361201, 0.95556462, 0.95766077, 0.95966244,\n",
- " 0.96118109, 0.96277551, 0.96448544, 0.96590662, 0.96726595,\n",
- " 0.96852001, 0.96991619, 0.97113487, 0.9723888 , 0.97361378,\n",
- " 0.97487929, 0.97591807, 0.97725326, 0.97856005, 0.97952476,\n",
- " 0.98071045, 0.98164245, 0.98264028, 0.98393822, 0.9850845 ,\n",
- " 0.98620898, 0.9872157 , 0.98857151, 0.98954745, 0.99058733,\n",
- " 0.99138259, 0.99226871, 0.99306583, 0.99380789, 0.99461065,\n",
- " 0.99527049, 0.99592002, 0.99655526, 0.99691174, 0.99757778,\n",
- " 0.9978895 , 0.99816814, 0.99852539, 0.99874352, 0.99896924,\n",
- " 0.99925024, 0.9993954 , 0.99949426, 0.99964604, 0.99974177,\n",
- " 0.99977018, 0.9998233 , 0.99984802, 0.99990114, 0.99994688,\n",
- " 0.99996108, 0.99997159, 1. , 1. , 1. ,\n",
- " 1. , 1. , 1. , 1. , 1. ]),\n",
- " 'mean_fpr': array([0. , 0.01010101, 0.02020202, 0.03030303, 0.04040404,\n",
- " 0.05050505, 0.06060606, 0.07070707, 0.08080808, 0.09090909,\n",
- " 0.1010101 , 0.11111111, 0.12121212, 0.13131313, 0.14141414,\n",
- " 0.15151515, 0.16161616, 0.17171717, 0.18181818, 0.19191919,\n",
- " 0.2020202 , 0.21212121, 0.22222222, 0.23232323, 0.24242424,\n",
- " 0.25252525, 0.26262626, 0.27272727, 0.28282828, 0.29292929,\n",
- " 0.3030303 , 0.31313131, 0.32323232, 0.33333333, 0.34343434,\n",
- " 0.35353535, 0.36363636, 0.37373737, 0.38383838, 0.39393939,\n",
- " 0.4040404 , 0.41414141, 0.42424242, 0.43434343, 0.44444444,\n",
- " 0.45454545, 0.46464646, 0.47474747, 0.48484848, 0.49494949,\n",
- " 0.50505051, 0.51515152, 0.52525253, 0.53535354, 0.54545455,\n",
- " 0.55555556, 0.56565657, 0.57575758, 0.58585859, 0.5959596 ,\n",
- " 0.60606061, 0.61616162, 0.62626263, 0.63636364, 0.64646465,\n",
- " 0.65656566, 0.66666667, 0.67676768, 0.68686869, 0.6969697 ,\n",
- " 0.70707071, 0.71717172, 0.72727273, 0.73737374, 0.74747475,\n",
- " 0.75757576, 0.76767677, 0.77777778, 0.78787879, 0.7979798 ,\n",
- " 0.80808081, 0.81818182, 0.82828283, 0.83838384, 0.84848485,\n",
- " 0.85858586, 0.86868687, 0.87878788, 0.88888889, 0.8989899 ,\n",
- " 0.90909091, 0.91919192, 0.92929293, 0.93939394, 0.94949495,\n",
- " 0.95959596, 0.96969697, 0.97979798, 0.98989899, 1. ]),\n",
- " 'all_roc_auc': [0.9373324264902606,\n",
- " 0.9410936383111078,\n",
- " 0.9635257667493496,\n",
- " 0.8903987740960708,\n",
- " 0.8781592994811886],\n",
- " 'roc_auc': 0.9141830130444975,\n",
- " 'roc_auc_sd': 0.03204329033266111}}"
- ]
- },
- "execution_count": 8,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "all_metrics"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "7007e45e-16c2-47a3-962c-92b9fe867bde",
- "metadata": {},
- "source": [
- "### Train gene classifier with all data:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 9,
- "id": "6df82c21-937c-4563-ba6b-a52ce287f542",
- "metadata": {},
- "outputs": [],
- "source": [
- "import datetime\n",
- "import pickle\n",
- "from geneformer import Classifier\n",
- "\n",
- "current_date = datetime.datetime.now()\n",
- "datestamp = f\"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}{current_date.hour:02d}{current_date.minute:02d}{current_date.second:02d}\"\n",
- "datestamp_min = f\"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}\"\n",
- "\n",
- "\n",
- "output_prefix = \"tf_dosage_sens_alldata\"\n",
- "output_dir = f\"/path/to/output_dir/{datestamp}\"\n",
- "!mkdir $output_dir"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 10,
- "id": "f031131c-54fd-4ad1-a925-bf0846cc3235",
- "metadata": {},
- "outputs": [],
- "source": [
- "# Example input_data_file: https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/blob/main/example_input_files/gene_classification/dosage_sensitive_tfs/dosage_sensitivity_TFs.pickle\n",
- "with open(\"/path/to/dosage_sensitivity_TFs.pickle\", \"rb\") as fp:\n",
- " gene_class_dict = pickle.load(fp)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 11,
- "id": "cd27b15c-52d4-46a6-af8c-812c8731f82c",
- "metadata": {},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "Hyperparameter tuning is highly recommended for optimal results. No training_args provided; using default hyperparameters.\n"
- ]
- }
- ],
- "source": [
- "cc = Classifier(classifier=\"gene\",\n",
- " gene_class_dict = gene_class_dict,\n",
- " max_ncells = 10_000,\n",
- " freeze_layers = 4,\n",
- " num_crossval_splits = 0,\n",
- " forward_batch_size=200,\n",
- " nproc=16)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 12,
- "id": "3d542bda-fbab-4d63-ab58-00d4caa996b9",
- "metadata": {},
- "outputs": [
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "7f77eaec105642b199a9e797fccdbf4b",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "Saving the dataset (0/1 shards): 0%| | 0/33558 [00:00, ? examples/s]"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- }
- ],
- "source": [
- "# Example input_data_file for 30M model series: https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/tree/main/example_input_files/gene_classification/dosage_sensitive_tfs/gc-30M_sample50k.dataset\n",
- "cc.prepare_data(input_data_file=\"/path/to/gc-30M_sample50k.dataset\",\n",
- " output_directory=output_dir,\n",
- " output_prefix=output_prefix)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 13,
- "id": "b8f421d0-c6b2-4ceb-af93-7f5276a6dfd7",
- "metadata": {},
- "outputs": [
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "8c8df78bcd7c4add9121d75b42ffd3ee",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "Filter (num_proc=16): 0%| | 0/33558 [00:00, ? examples/s]"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "102f40fd802f4b6eb48afddf8130508d",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "Map (num_proc=16): 0%| | 0/10000 [00:00, ? examples/s]"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "mkdir: cannot create directory ‘/gladstone/theodoris/home/ctheodoris/temp/test_suite_output/classifer_tests/240224031008/240224_geneformer_geneClassifier_tf_dosage_sens_alldata/’: File exists\n",
- "Some weights of BertForTokenClassification were not initialized from the model checkpoint at /gladstone/theodoris/home/ctheodoris/Geneformer and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
- "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
- "Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
- "/gladstone/theodoris/home/ctheodoris/Geneformer/geneformer/collator_for_classification.py:581: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
- " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n"
- ]
- },
- {
- "data": {
- "text/html": [
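- "\n",
- "    <div>\n",
- "      \n",
- "      <progress value='834' max='834' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
- "      [834/834 02:35, Epoch 1/1]\n",
- "    </div>\n",
- "    <table border=\"1\" class=\"dataframe\">\n",
- "  <thead>\n",
- " <tr style=\"text-align: left;\">\n",
- "      <th>Step</th>\n",
- "      <th>Training Loss</th>\n",
- "    </tr>\n",
- "  </thead>\n",
- "  <tbody>\n",
- "    <tr>\n",
- "      <td>83</td>\n",
- "      <td>0.700600</td>\n",
- "    </tr>\n",
- "    <tr>\n",
- "      <td>166</td>\n",
- "      <td>0.643100</td>\n",
- "    </tr>\n",
- "    <tr>\n",
- "      <td>249</td>\n",
- "      <td>0.544700</td>\n",
- "    </tr>\n",
- "    <tr>\n",
- "      <td>332</td>\n",
- "      <td>0.412900</td>\n",
- "    </tr>\n",
- "    <tr>\n",
- "      <td>415</td>\n",
- "      <td>0.298600</td>\n",
- "    </tr>\n",
- "    <tr>\n",
- "      <td>498</td>\n",
- "      <td>0.205700</td>\n",
- "    </tr>\n",
- "    <tr>\n",
- "      <td>581</td>\n",
- "      <td>0.138900</td>\n",
- "    </tr>\n",
- "    <tr>\n",
- "      <td>664</td>\n",
- "      <td>0.103200</td>\n",
- "    </tr>\n",
- "    <tr>\n",
- "      <td>747</td>\n",
- "      <td>0.090000</td>\n",
- "    </tr>\n",
- "    <tr>\n",
- "      <td>830</td>\n",
- "      <td>0.083100</td>\n",
- "    </tr>\n",
- "  </tbody>\n",
- "</table><p>"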
- ],
- "text/plain": [
- ""
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- }
- ],
- "source": [
- "# 6 layer 30M Geneformer model: https://huggingface.co/ctheodoris/Geneformer/blob/main/gf-6L-30M-i2048/model.safetensors\n",
- "trainer_test = cc.train_all_data(model_directory=\"/path/to/Geneformer\",\n",
- " prepared_input_data_file=f\"{output_dir}/{output_prefix}_labeled.dataset\",\n",
- " id_class_dict_file=f\"{output_dir}/{output_prefix}_id_class_dict.pkl\",\n",
- " output_directory=output_dir,\n",
- " output_prefix=output_prefix)"
- ]
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 3 (ipykernel)",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.10.15"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 5
-}