{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {"trusted": true},
   "outputs": [],
   "source": [
    "# For any HF basic activities like loading models\n",
    "# and tokenizers for running inference\n",
    "# upgrade is a must for the newest Gemma model\n",
    "!pip install -q --upgrade datasets\n",
    "!pip install -q --upgrade transformers\n",
    "\n",
    "# For doing efficient stuff - PEFT\n",
    "!pip install -q --upgrade peft\n",
    "!pip install -q --upgrade trl\n",
    "!pip install -q bitsandbytes\n",
    "!pip install -q accelerate\n",
    "\n",
    "# for logging and visualizing training progress\n",
    "!pip install -q tensorboard\n",
    "# If creating a new dataset, useful for creating *.jsonl files\n",
    "!pip install -q jsonlines"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {"trusted": true},
   "outputs": [],
   "source": [
    "# gdown is used to fetch the task CSVs into /kaggle/working\n",
    "!conda install -y gdown"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {"trusted": true},
   "outputs": [],
   "source": [
    "import itertools\n",
    "import os\n",
    "import time\n",
    "import warnings\n",
    "\n",
    "import huggingface_hub\n",
    "import numpy as np\n",
    "import torch\n",
    "from datasets import load_dataset\n",
    "from matplotlib import pyplot as plt\n",
    "from peft import LoraConfig, get_peft_model\n",
    "from sklearn.metrics import accuracy_score, recall_score, precision_score, f1_score, matthews_corrcoef, roc_auc_score\n",
    "from tqdm import tqdm\n",
    "from transformers import BertForSequenceClassification, AutoTokenizer, TrainingArguments, Trainer\n",
    "\n",
    "# Read the Hugging Face token from the environment (Kaggle secrets can be\n",
    "# exported as env vars). The previous version referenced an undefined\n",
    "# `hf_token` variable, which raises NameError on a fresh kernel; never\n",
    "# hardcode credentials in a notebook.\n",
    "huggingface_hub.login(token=os.environ[\"HF_TOKEN\"])\n",
    "\n",
    "# Suppress warnings\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "\n",
    "# LoRA target modules: the attention projections and feed-forward layers of\n",
    "# both encoder layers of the (2-layer) model.\n",
    "attention_plus_feed_forward = [\n",
    "    \"bert.encoder.layer.0.attention.self.query\",\n",
    "    \"bert.encoder.layer.0.attention.self.key\",\n",
    "    \"bert.encoder.layer.0.attention.self.value\",\n",
    "    \"bert.encoder.layer.0.attention.output.dense\",\n",
    "    \"bert.encoder.layer.0.intermediate.dense\",\n",
    "    \"bert.encoder.layer.0.output.dense\",\n",
    "    \"bert.encoder.layer.1.attention.self.query\",\n",
    "    \"bert.encoder.layer.1.attention.self.key\",\n",
    "    \"bert.encoder.layer.1.attention.self.value\",\n",
    "    \"bert.encoder.layer.1.attention.output.dense\",\n",
    "    \"bert.encoder.layer.1.intermediate.dense\",\n",
    "    \"bert.encoder.layer.1.output.dense\"\n",
    "]\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained('zhihan1996/DNA_bert_6')\n",
    "\n",
    "\n",
    "def preprocess_function(examples):\n",
    "    \"\"\"Tokenize a batch of DNA sequences.\n",
    "\n",
    "    The sequence column is named 'sequence' or 'Sequence' depending on the\n",
    "    task CSV, so fall back on KeyError.\n",
    "    \"\"\"\n",
    "    try:\n",
    "        sequences = examples['sequence']\n",
    "    except KeyError:\n",
    "        sequences = examples['Sequence']\n",
    "    return tokenizer(\n",
    "        sequences,\n",
    "        padding='max_length',\n",
    "        truncation=True,\n",
    "        max_length=512\n",
    "    )\n",
    "\n",
    "\n",
    "def add_labels(examples):\n",
    "    \"\"\"Copy the 'label'/'Label' column into the 'labels' key Trainer expects.\"\"\"\n",
    "    try:\n",
    "        examples['labels'] = examples['label']\n",
    "    except KeyError:\n",
    "        examples['labels'] = examples['Label']\n",
    "    return examples\n",
    "\n",
    "\n",
    "def create_task_dataset(task_name):\n",
    "    \"\"\"Return (train, test) splits for one task CSV in /kaggle/working.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``task_name`` is not one of the known tasks.\n",
    "    \"\"\"\n",
    "    if task_name not in ('tfbs', 'dnasplice', 'dnaprom'):\n",
    "        raise ValueError(f\"Unknown task: {task_name}\")\n",
    "    data_files = f'/kaggle/working/{task_name}.csv'\n",
    "    # train[0:10000] covers rows 0-9999, so the test split starts at 10000.\n",
    "    # (It previously started at 10001, silently dropping row 10000.)\n",
    "    train_split = load_dataset('csv', data_files=data_files, split='train[0:10000]')\n",
    "    test_split = load_dataset('csv', data_files=data_files, split='train[10000:13122]')\n",
    "    return train_split, test_split\n",
    "\n",
    "\n",
    "def create_dataset_maps(train_dataset, test_dataset):\n",
    "    \"\"\"Tokenize both splits and attach the 'labels' column.\"\"\"\n",
    "    train_dataset = train_dataset.map(preprocess_function, batched=True)\n",
    "    train_dataset = train_dataset.map(add_labels)\n",
    "    test_dataset = test_dataset.map(preprocess_function, batched=True)\n",
    "    test_dataset = test_dataset.map(add_labels)\n",
    "    return train_dataset, test_dataset\n",
    "\n",
    "\n",
    "def train_model(train_dataset, test_dataset, model, task, model_name, config_name):\n",
    "    \"\"\"Fine-tune ``model`` on ``train_dataset`` and evaluate on ``test_dataset``.\n",
    "\n",
    "    Returns:\n",
    "        (total_time, metrics): wall-clock training time in seconds and the\n",
    "        dict produced by ``Trainer.evaluate()``.\n",
    "    \"\"\"\n",
    "    def specificity_score(y_true, y_pred):\n",
    "        # TN / (TN + FP); eps guards against division by zero when there\n",
    "        # are no negative predictions on negative labels.\n",
    "        true_negatives = np.sum((y_pred == 0) & (y_true == 0))\n",
    "        false_positives = np.sum((y_pred == 1) & (y_true == 0))\n",
    "        return true_negatives / (true_negatives + false_positives + np.finfo(float).eps)\n",
    "\n",
    "    def compute_metrics(eval_pred):\n",
    "        logits, labels = eval_pred\n",
    "        predictions = np.argmax(logits, axis=-1)\n",
    "        # The positive-class logit is a valid ranking score for ROC-AUC\n",
    "        # (AUC only depends on the ordering, not on calibrated probabilities).\n",
    "        y_pred = logits[:, 1]\n",
    "\n",
    "        accuracy = accuracy_score(labels, predictions)\n",
    "        recall = recall_score(labels, predictions)\n",
    "        specificity = specificity_score(labels, predictions)\n",
    "        mcc = matthews_corrcoef(labels, predictions)\n",
    "        roc_auc = roc_auc_score(labels, y_pred)\n",
    "        precision = precision_score(labels, predictions)\n",
    "        f1 = f1_score(labels, predictions)\n",
    "\n",
    "        # Raw confusion-matrix counts, logged alongside the derived metrics.\n",
    "        true_pos = np.sum((predictions == 1) & (labels == 1))\n",
    "        true_neg = np.sum((predictions == 0) & (labels == 0))\n",
    "        false_pos = np.sum((predictions == 1) & (labels == 0))\n",
    "        false_neg = np.sum((predictions == 0) & (labels == 1))\n",
    "\n",
    "        return {\n",
    "            'accuracy': accuracy,\n",
    "            'recall': recall,\n",
    "            'specificity': specificity,\n",
    "            'mcc': mcc,\n",
    "            'roc_auc': roc_auc,\n",
    "            'precision': precision,\n",
    "            'f1': f1,\n",
    "            'true_pos': true_pos,\n",
    "            'true_neg': true_neg,\n",
    "            'false_pos': false_pos,\n",
    "            'false_neg': false_neg\n",
    "        }\n",
    "\n",
    "    # Define the training arguments\n",
    "    training_arguments = TrainingArguments(\n",
    "        output_dir=f\"outputs/{task}/{model_name}_{config_name}\",\n",
    "        num_train_epochs=25,\n",
    "        fp16=False,\n",
    "        bf16=False,\n",
    "        per_device_train_batch_size=20,\n",
    "        per_device_eval_batch_size=10,\n",
    "        gradient_accumulation_steps=2,\n",
    "        gradient_checkpointing=True,\n",
    "        max_grad_norm=0.3,\n",
    "        learning_rate=4e-4,\n",
    "        weight_decay=0.01,\n",
    "        optim=\"paged_adamw_32bit\",\n",
    "        lr_scheduler_type=\"linear\",\n",
    "        max_steps=-1,\n",
    "        warmup_ratio=0.03,\n",
    "        group_by_length=True,\n",
    "        save_steps=1000,\n",
    "        logging_steps=25,\n",
    "        dataloader_pin_memory=False,\n",
    "        report_to='tensorboard',\n",
    "        # non-reentrant checkpointing avoids issues with frozen (LoRA) params\n",
    "        gradient_checkpointing_kwargs={'use_reentrant': False}\n",
    "    )\n",
    "\n",
    "    trainer = Trainer(\n",
    "        model=model,\n",
    "        args=training_arguments,\n",
    "        train_dataset=train_dataset,\n",
    "        eval_dataset=test_dataset,\n",
    "        tokenizer=tokenizer,\n",
    "        compute_metrics=compute_metrics,\n",
    "    )\n",
    "\n",
    "    start_time = time.time()\n",
    "    trainer.train()\n",
    "    end_time = time.time()\n",
    "\n",
    "    total_time = end_time - start_time\n",
    "    metrics = trainer.evaluate()\n",
    "\n",
    "    return total_time, metrics\n",
    "\n",
    "\n",
    "# Task loop: for each task, train the full-fine-tune baseline and then the\n",
    "# LoRA (attention + feed-forward) configuration, appending results to a log.\n",
    "task_list = ['dnasplice', 'tfbs', 'dnaprom']\n",
    "log_file = \"training_log.txt\"\n",
    "model_name = 'fabihamakhdoomi/TinyDNABERT'\n",
    "for task in task_list:\n",
    "    print(f\"Running TASK : {task}\")\n",
    "    train_dataset, test_dataset = create_task_dataset(task)\n",
    "    train_dataset, test_dataset = create_dataset_maps(train_dataset, test_dataset)\n",
    "    train_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])\n",
    "    test_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])\n",
    "\n",
    "    # Train the base model first\n",
    "    base_model = BertForSequenceClassification.from_pretrained(\n",
    "        model_name,\n",
    "        num_labels=2\n",
    "    )\n",
    "    config_name = \"base_model\"\n",
    "    print(f\"Training MODEL : {config_name} for task : {task}\")\n",
    "    training_time, metrics = train_model(train_dataset, test_dataset, base_model, task, model_name, config_name)\n",
    "    with open(log_file, \"a\") as log:\n",
    "        log.write(f\"Task: {task}, Model: {model_name}, Config: {config_name}, Training Time: {training_time}, Metrics: {metrics}\\n\")\n",
    "\n",
    "    # Train the LoRA models (fresh base weights so runs are independent)\n",
    "    config_name = \"attention_plus_feed_forward\"\n",
    "    base_model = BertForSequenceClassification.from_pretrained(\n",
    "        model_name,\n",
    "        num_labels=2\n",
    "    )\n",
    "    if task == 'dnasplice':\n",
    "        r_value = 4\n",
    "        print('Setting r value to 4 for dnasplice')\n",
    "    else:\n",
    "        r_value = 8\n",
    "    peft_config = LoraConfig(\n",
    "        lora_alpha=16,\n",
    "        lora_dropout=0.2,\n",
    "        r=r_value,\n",
    "        bias=\"none\",\n",
    "        task_type=\"SEQ_CLS\",\n",
    "        target_modules=attention_plus_feed_forward\n",
    "    )\n",
    "    model = get_peft_model(base_model, peft_config)\n",
    "    print(f\"Training MODEL : {config_name} for task : {task}\")\n",
    "    training_time, metrics = train_model(train_dataset, test_dataset, model, task, model_name, config_name)\n",
    "    with open(log_file, \"a\") as log:\n",
    "        log.write(f\"Task: {task}, Model: {model_name}, Config: {config_name}, Training Time: {training_time}, Metrics: {metrics}\\n\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kaggle": {
   "accelerator": "gpu",
   "dataSources": [],
   "dockerImageVersionId": 30747,
   "isGpuEnabled": true,
   "isInternetEnabled": true,
   "language": "python",
   "sourceType": "notebook"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}