{ "cells": [ { "cell_type": "code", "execution_count": 38, "id": "14ff5741-629c-445a-a8a9-b3d9db1f3ddb", "metadata": {}, "outputs": [], "source": [ "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss\n", "from torch.utils.data import DataLoader\n", "\n", "import re\n", "import numpy as np\n", "import os\n", "import pandas as pd\n", "import copy\n", "\n", "import transformers, datasets\n", "from transformers.modeling_outputs import TokenClassifierOutput\n", "from transformers.models.t5.modeling_t5 import T5Config, T5PreTrainedModel, T5Stack\n", "from transformers.utils.model_parallel_utils import assert_device_map, get_device_map\n", "from transformers import T5EncoderModel, T5Tokenizer\n", "from transformers.models.esm.modeling_esm import EsmPreTrainedModel, EsmModel\n", "from transformers import AutoTokenizer\n", "from transformers import TrainingArguments, Trainer, set_seed\n", "from transformers import DataCollatorForTokenClassification\n", "\n", "from dataclasses import dataclass\n", "from typing import Dict, List, Optional, Tuple, Union\n", "\n", "# for custom DataCollator\n", "from transformers.data.data_collator import DataCollatorMixin\n", "from transformers.tokenization_utils_base import PreTrainedTokenizerBase\n", "from transformers.utils import PaddingStrategy\n", "\n", "from datasets import Dataset\n", "\n", "from scipy.special import expit\n", "\n", "import peft\n", "from peft import get_peft_config, PeftModel, PeftConfig, inject_adapter_in_model, LoraConfig" ] }, { "cell_type": "code", "execution_count": 6, "id": "5ec16a71-ed5d-46a6-98b2-55bc5d0fbe07", "metadata": {}, "outputs": [], "source": [ "cnn_head=True #False set True for Rostlab/prot_t5_xl_half_uniref50-enc\n", "ffn_head=False #False\n", "transformer_head=False\n", "custom_lora=True #False #only true for Rostlab/prot_t5_xl_half_uniref50-enc" ] }, { "cell_type": "code", "execution_count": 8, "id": "cc7151ca-0daf-4e75-a865-ab52f9b28f2e", "metadata": {}, "outputs": [], "source": [ "class ClassConfig:\n", " def __init__(self, dropout=0.2, num_labels=3):\n", " self.dropout_rate = dropout\n", " self.num_labels = num_labels\n", "\n", "class T5EncoderForTokenClassification(T5PreTrainedModel):\n", "\n", " def __init__(self, config: T5Config, class_config: ClassConfig):\n", " super().__init__(config)\n", " self.num_labels = class_config.num_labels\n", " self.config = config\n", "\n", " self.shared = nn.Embedding(config.vocab_size, config.d_model)\n", "\n", " encoder_config = copy.deepcopy(config)\n", " encoder_config.use_cache = False\n", " encoder_config.is_encoder_decoder = False\n", " self.encoder = T5Stack(encoder_config, self.shared)\n", "\n", " self.dropout = nn.Dropout(class_config.dropout_rate)\n", "\n", " # Initialize different heads based on class_config\n", " if cnn_head:\n", " self.cnn = nn.Conv1d(config.hidden_size, 512, kernel_size=3, padding=1)\n", " self.classifier = nn.Linear(512, class_config.num_labels)\n", " elif ffn_head:\n", " # Multi-layer feed-forward network (FFN) head\n", " self.ffn = nn.Sequential(\n", " nn.Linear(config.hidden_size, 512),\n", " nn.ReLU(),\n", " nn.Linear(512, 256),\n", " nn.ReLU(),\n", " nn.Linear(256, class_config.num_labels)\n", " )\n", " elif transformer_head:\n", " # Transformer layer head\n", " encoder_layer = nn.TransformerEncoderLayer(d_model=config.hidden_size, nhead=8)\n", " self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=1)\n", " self.classifier = 
nn.Linear(config.hidden_size, class_config.num_labels)\n", " else:\n", " # Default classification head\n", " self.classifier = nn.Linear(config.hidden_size, class_config.num_labels)\n", " \n", " self.post_init()\n", "\n", " # Model parallel\n", " self.model_parallel = False\n", " self.device_map = None\n", "\n", " def parallelize(self, device_map=None):\n", " self.device_map = (\n", " get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))\n", " if device_map is None\n", " else device_map\n", " )\n", " assert_device_map(self.device_map, len(self.encoder.block))\n", " self.encoder.parallelize(self.device_map)\n", " self.classifier = self.classifier.to(self.encoder.first_device)\n", " self.model_parallel = True\n", "\n", " def deparallelize(self):\n", " self.encoder.deparallelize()\n", " self.encoder = self.encoder.to(\"cpu\")\n", " self.model_parallel = False\n", " self.device_map = None\n", " torch.cuda.empty_cache()\n", "\n", " def get_input_embeddings(self):\n", " return self.shared\n", "\n", " def set_input_embeddings(self, new_embeddings):\n", " self.shared = new_embeddings\n", " self.encoder.set_input_embeddings(new_embeddings)\n", "\n", " def get_encoder(self):\n", " return self.encoder\n", "\n", " def _prune_heads(self, heads_to_prune):\n", " for layer, heads in heads_to_prune.items():\n", " self.encoder.layer[layer].attention.prune_heads(heads)\n", "\n", " def forward(\n", " self,\n", " input_ids=None,\n", " attention_mask=None,\n", " head_mask=None,\n", " inputs_embeds=None,\n", " labels=None,\n", " output_attentions=None,\n", " output_hidden_states=None,\n", " return_dict=None,\n", " ):\n", " return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n", "\n", " outputs = self.encoder(\n", " input_ids=input_ids,\n", " attention_mask=attention_mask,\n", " inputs_embeds=inputs_embeds,\n", " head_mask=head_mask,\n", " output_attentions=output_attentions,\n", " output_hidden_states=output_hidden_states,\n", " return_dict=return_dict,\n", " )\n", "\n", " sequence_output = outputs[0]\n", " sequence_output = self.dropout(sequence_output)\n", "\n", " # Forward pass through the selected head\n", " if cnn_head:\n", " # CNN head\n", " sequence_output = sequence_output.permute(0, 2, 1) # Prepare shape for CNN\n", " cnn_output = self.cnn(sequence_output)\n", " cnn_output = F.relu(cnn_output)\n", " cnn_output = cnn_output.permute(0, 2, 1) # Shape back for classifier\n", " logits = self.classifier(cnn_output)\n", " elif ffn_head:\n", " # FFN head\n", " logits = self.ffn(sequence_output)\n", " elif transformer_head:\n", " # Transformer head\n", " transformer_output = self.transformer_encoder(sequence_output)\n", " logits = self.classifier(transformer_output)\n", " else:\n", " # Default classification head\n", " logits = self.classifier(sequence_output)\n", "\n", " loss = None\n", " if labels is not None:\n", " loss_fct = CrossEntropyLoss()\n", " active_loss = attention_mask.view(-1) == 1\n", " active_logits = logits.view(-1, self.num_labels)\n", " active_labels = torch.where(\n", " active_loss, labels.view(-1), torch.tensor(-100).type_as(labels)\n", " )\n", " valid_logits = active_logits[active_labels != -100]\n", " valid_labels = active_labels[active_labels != -100]\n", " valid_labels = valid_labels.to(valid_logits.device)\n", " valid_labels = valid_labels.long()\n", " loss = loss_fct(valid_logits, valid_labels)\n", "\n", " if not return_dict:\n", " output = (logits,) + outputs[2:]\n", " return ((loss,) + output) if loss is not None else output\n", 
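"        # Otherwise return a TokenClassifierOutput (return_dict is typically True, e.g. when running under the Hugging Face Trainer) so loss and logits can be accessed by name\n",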
"\n", " return TokenClassifierOutput(\n", " loss=loss,\n", " logits=logits,\n", " hidden_states=outputs.hidden_states,\n", " attentions=outputs.attentions,\n", " )" ] }, { "cell_type": "code", "execution_count": 10, "id": "e5e751ba-f4d3-4a28-bea0-82633f1dabb4", "metadata": {}, "outputs": [], "source": [ "# Modifies an existing transformer and introduce the LoRA layers\n", "\n", "class CustomLoRAConfig:\n", " def __init__(self):\n", " self.lora_rank = 4\n", " self.lora_init_scale = 0.01\n", " self.lora_modules = \".*SelfAttention|.*EncDecAttention\"\n", " self.lora_layers = \"q|k|v|o\"\n", " self.trainable_param_names = \".*layer_norm.*|.*lora_[ab].*\"\n", " self.lora_scaling_rank = 1\n", " # lora_modules and lora_layers are speicified with regular expressions\n", " # see https://www.w3schools.com/python/python_regex.asp for reference\n", " \n", "class LoRALinear(nn.Module):\n", " def __init__(self, linear_layer, rank, scaling_rank, init_scale):\n", " super().__init__()\n", " self.in_features = linear_layer.in_features\n", " self.out_features = linear_layer.out_features\n", " self.rank = rank\n", " self.scaling_rank = scaling_rank\n", " self.weight = linear_layer.weight\n", " self.bias = linear_layer.bias\n", " if self.rank > 0:\n", " self.lora_a = nn.Parameter(torch.randn(rank, linear_layer.in_features) * init_scale)\n", " if init_scale < 0:\n", " self.lora_b = nn.Parameter(torch.randn(linear_layer.out_features, rank) * init_scale)\n", " else:\n", " self.lora_b = nn.Parameter(torch.zeros(linear_layer.out_features, rank))\n", " if self.scaling_rank:\n", " self.multi_lora_a = nn.Parameter(\n", " torch.ones(self.scaling_rank, linear_layer.in_features)\n", " + torch.randn(self.scaling_rank, linear_layer.in_features) * init_scale\n", " )\n", " if init_scale < 0:\n", " self.multi_lora_b = nn.Parameter(\n", " torch.ones(linear_layer.out_features, self.scaling_rank)\n", " + torch.randn(linear_layer.out_features, self.scaling_rank) * init_scale\n", " )\n", " else:\n", " self.multi_lora_b = nn.Parameter(torch.ones(linear_layer.out_features, self.scaling_rank))\n", "\n", " def forward(self, input):\n", " if self.scaling_rank == 1 and self.rank == 0:\n", " # parsimonious implementation for ia3 and lora scaling\n", " if self.multi_lora_a.requires_grad:\n", " hidden = F.linear((input * self.multi_lora_a.flatten()), self.weight, self.bias)\n", " else:\n", " hidden = F.linear(input, self.weight, self.bias)\n", " if self.multi_lora_b.requires_grad:\n", " hidden = hidden * self.multi_lora_b.flatten()\n", " return hidden\n", " else:\n", " # general implementation for lora (adding and scaling)\n", " weight = self.weight\n", " if self.scaling_rank:\n", " weight = weight * torch.matmul(self.multi_lora_b, self.multi_lora_a) / self.scaling_rank\n", " if self.rank:\n", " weight = weight + torch.matmul(self.lora_b, self.lora_a) / self.rank\n", " return F.linear(input, weight, self.bias)\n", "\n", " def extra_repr(self):\n", " return \"in_features={}, out_features={}, bias={}, rank={}, scaling_rank={}\".format(\n", " self.in_features, self.out_features, self.bias is not None, self.rank, self.scaling_rank\n", " )\n", "\n", "\n", "def modify_with_lora(transformer, config):\n", " for m_name, module in dict(transformer.named_modules()).items():\n", " if re.fullmatch(config.lora_modules, m_name):\n", " for c_name, layer in dict(module.named_children()).items():\n", " if re.fullmatch(config.lora_layers, c_name):\n", " assert isinstance(\n", " layer, nn.Linear\n", " ), f\"LoRA can only be applied to torch.nn.Linear, but 
{layer} is {type(layer)}.\"\n", " setattr(\n", " module,\n", " c_name,\n", " LoRALinear(layer, config.lora_rank, config.lora_scaling_rank, config.lora_init_scale),\n", " )\n", " return transformer\n", "\n" ] }, { "cell_type": "code", "execution_count": 12, "id": "43a56311-3279-466a-bc95-590381f1b13c", "metadata": {}, "outputs": [], "source": [ "def load_T5_model_classification(checkpoint, num_labels, half_precision, full = False, deepspeed=True):\n", " # Load model and tokenizer\n", "\n", " if \"ankh\" in checkpoint :\n", " model = T5EncoderModel.from_pretrained(checkpoint)\n", " tokenizer = AutoTokenizer.from_pretrained(checkpoint)\n", "\n", " elif \"prot_t5\" in checkpoint:\n", " # possible to load the half precision model (thanks to @pawel-rezo for pointing that out)\n", " if half_precision and deepspeed:\n", " #tokenizer = T5Tokenizer.from_pretrained('Rostlab/prot_t5_xl_half_uniref50-enc', do_lower_case=False)\n", " #model = T5EncoderModel.from_pretrained(\"Rostlab/prot_t5_xl_half_uniref50-enc\", torch_dtype=torch.float16)#.to(torch.device('cuda')\n", " tokenizer = T5Tokenizer.from_pretrained(checkpoint, do_lower_case=False)\n", " model = T5EncoderModel.from_pretrained(checkpoint, torch_dtype=torch.float16).to(torch.device('cuda'))\n", " else:\n", " model = T5EncoderModel.from_pretrained(checkpoint)\n", " tokenizer = T5Tokenizer.from_pretrained(checkpoint)\n", " \n", " elif \"ProstT5\" in checkpoint:\n", " if half_precision and deepspeed: \n", " tokenizer = T5Tokenizer.from_pretrained(checkpoint, do_lower_case=False)\n", " model = T5EncoderModel.from_pretrained(checkpoint, torch_dtype=torch.float16).to(torch.device('cuda'))\n", " else:\n", " model = T5EncoderModel.from_pretrained(checkpoint)\n", " tokenizer = T5Tokenizer.from_pretrained(checkpoint) \n", " \n", " # Create new Classifier model with PT5 dimensions\n", " class_config=ClassConfig(num_labels=num_labels)\n", " class_model=T5EncoderForTokenClassification(model.config,class_config)\n", " \n", " # Set encoder and embedding weights to checkpoint weights\n", " class_model.shared=model.shared\n", " class_model.encoder=model.encoder \n", " \n", " # Delete the checkpoint model\n", " model=class_model\n", " del class_model\n", " \n", " if full == True:\n", " return model, tokenizer \n", " \n", " # Print number of trainable parameters\n", " model_parameters = filter(lambda p: p.requires_grad, model.parameters())\n", " params = sum([np.prod(p.size()) for p in model_parameters])\n", " print(\"T5_Classfier\\nTrainable Parameter: \"+ str(params)) \n", "\n", " if custom_lora:\n", " #the linear CustomLoRAConfig allows better quality predictions, but more memory is needed\n", " # Add model modification lora\n", " config = CustomLoRAConfig()\n", " \n", " # Add LoRA layers\n", " model = modify_with_lora(model, config)\n", " \n", " # Freeze Embeddings and Encoder (except LoRA)\n", " for (param_name, param) in model.shared.named_parameters():\n", " param.requires_grad = False\n", " for (param_name, param) in model.encoder.named_parameters():\n", " param.requires_grad = False \n", " \n", " for (param_name, param) in model.named_parameters():\n", " if re.fullmatch(config.trainable_param_names, param_name):\n", " param.requires_grad = True\n", "\n", " else:\n", " # lora modification\n", " peft_config = LoraConfig(\n", " r=4, lora_alpha=1, bias=\"all\", target_modules=[\"q\",\"k\",\"v\",\"o\"]\n", " )\n", " \n", " model = inject_adapter_in_model(peft_config, model)\n", " \n", " # Unfreeze the prediction head\n", " for (param_name, param) in 
model.classifier.named_parameters():\n", " param.requires_grad = True \n", "\n", " # Print trainable Parameter \n", " model_parameters = filter(lambda p: p.requires_grad, model.parameters())\n", " params = sum([np.prod(p.size()) for p in model_parameters])\n", " print(\"T5_LoRA_Classfier\\nTrainable Parameter: \"+ str(params) + \"\\n\")\n", " \n", " return model, tokenizer" ] }, { "cell_type": "code", "execution_count": 14, "id": "7ba720bc-a003-4984-a965-cb2f42344e85", "metadata": {}, "outputs": [], "source": [ "class EsmForTokenClassificationCustom(EsmPreTrainedModel):\n", " _keys_to_ignore_on_load_unexpected = [r\"pooler\"]\n", " _keys_to_ignore_on_load_missing = [r\"position_ids\", r\"cnn\", r\"ffn\", r\"transformer\"]\n", "\n", " def __init__(self, config):\n", " super().__init__(config)\n", " self.num_labels = config.num_labels\n", " self.esm = EsmModel(config, add_pooling_layer=False)\n", " self.dropout = nn.Dropout(config.hidden_dropout_prob)\n", "\n", " if cnn_head:\n", " self.cnn = nn.Conv1d(config.hidden_size, 512, kernel_size=3, padding=1)\n", " self.classifier = nn.Linear(512, config.num_labels)\n", " elif ffn_head:\n", " # Multi-layer feed-forward network (FFN) as an alternative head\n", " self.ffn = nn.Sequential(\n", " nn.Linear(config.hidden_size, 512),\n", " nn.ReLU(),\n", " nn.Linear(512, 256),\n", " nn.ReLU(),\n", " nn.Linear(256, config.num_labels)\n", " )\n", " elif transformer_head:\n", " # Transformer layer as an alternative head\n", " encoder_layer = nn.TransformerEncoderLayer(d_model=config.hidden_size, nhead=8)\n", " self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=1)\n", " self.classifier = nn.Linear(config.hidden_size, config.num_labels)\n", " else:\n", " # Default classification head\n", " self.classifier = nn.Linear(config.hidden_size, config.num_labels)\n", "\n", " self.init_weights()\n", "\n", " def forward(\n", " self,\n", " input_ids: Optional[torch.LongTensor] = None,\n", " attention_mask: Optional[torch.Tensor] = None,\n", " position_ids: Optional[torch.LongTensor] = None,\n", " head_mask: Optional[torch.Tensor] = None,\n", " inputs_embeds: Optional[torch.FloatTensor] = None,\n", " labels: Optional[torch.LongTensor] = None,\n", " output_attentions: Optional[bool] = None,\n", " output_hidden_states: Optional[bool] = None,\n", " return_dict: Optional[bool] = None,\n", " ) -> Union[Tuple, TokenClassifierOutput]:\n", " return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n", " outputs = self.esm(\n", " input_ids,\n", " attention_mask=attention_mask,\n", " position_ids=position_ids,\n", " head_mask=head_mask,\n", " inputs_embeds=inputs_embeds,\n", " output_attentions=output_attentions,\n", " output_hidden_states=output_hidden_states,\n", " return_dict=return_dict,\n", " )\n", " \n", " sequence_output = outputs[0]\n", " sequence_output = self.dropout(sequence_output)\n", "\n", " if cnn_head:\n", " sequence_output = sequence_output.transpose(1, 2)\n", " sequence_output = self.cnn(sequence_output)\n", " sequence_output = sequence_output.transpose(1, 2)\n", " logits = self.classifier(sequence_output)\n", " elif ffn_head:\n", " logits = self.ffn(sequence_output)\n", " elif transformer_head:\n", " # Apply transformer encoder for the transformer head\n", " sequence_output = self.transformer_encoder(sequence_output)\n", " logits = self.classifier(sequence_output)\n", " else:\n", " logits = self.classifier(sequence_output)\n", "\n", " loss = None\n", " if labels is not None:\n", " loss_fct = 
CrossEntropyLoss()\n", " active_loss = attention_mask.view(-1) == 1\n", " active_logits = logits.view(-1, self.num_labels)\n", " active_labels = torch.where(\n", " active_loss, labels.view(-1), torch.tensor(-100).type_as(labels)\n", " )\n", " valid_logits = active_logits[active_labels != -100]\n", " valid_labels = active_labels[active_labels != -100]\n", " valid_labels = valid_labels.type(torch.LongTensor).to('cuda:0')\n", " loss = loss_fct(valid_logits, valid_labels)\n", "\n", " if not return_dict:\n", " output = (logits,) + outputs[2:]\n", " return ((loss,) + output) if loss is not None else output\n", "\n", " return TokenClassifierOutput(\n", " loss=loss,\n", " logits=logits,\n", " hidden_states=outputs.hidden_states,\n", " attentions=outputs.attentions,\n", " )\n", "\n", " def _init_weights(self, module):\n", " if isinstance(module, nn.Linear) or isinstance(module, nn.Conv1d):\n", " module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)\n", " if module.bias is not None:\n", " module.bias.data.zero_()\n", "\n", "# based on transformers DataCollatorForTokenClassification\n", "@dataclass\n", "class DataCollatorForTokenClassificationESM(DataCollatorMixin):\n", " \"\"\"\n", " Data collator that will dynamically pad the inputs received, as well as the labels.\n", " Args:\n", " tokenizer ([`PreTrainedTokenizer`] or [`PreTrainedTokenizerFast`]):\n", " The tokenizer used for encoding the data.\n", " padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):\n", " Select a strategy to pad the returned sequences (according to the model's padding side and padding index)\n", " among:\n", " - `True` or `'longest'` (default): Pad to the longest sequence in the batch (or no padding if only a single\n", " sequence is provided).\n", " - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum\n", " acceptable input length for the model if that argument is not provided.\n", " - `False` or `'do_not_pad'`: No padding (i.e., can output a batch with sequences of different lengths).\n", " max_length (`int`, *optional*):\n", " Maximum length of the returned list and optionally padding length (see above).\n", " pad_to_multiple_of (`int`, *optional*):\n", " If set will pad the sequence to a multiple of the provided value.\n", " This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=\n", " 7.5 (Volta).\n", " label_pad_token_id (`int`, *optional*, defaults to -100):\n", " The id to use when padding the labels (-100 will be automatically ignore by PyTorch loss functions).\n", " return_tensors (`str`):\n", " The type of Tensor to return. 
Allowable values are \"np\", \"pt\" and \"tf\".\n", " \"\"\"\n", "\n", " tokenizer: PreTrainedTokenizerBase\n", " padding: Union[bool, str, PaddingStrategy] = True\n", " max_length: Optional[int] = None\n", " pad_to_multiple_of: Optional[int] = None\n", " label_pad_token_id: int = -100\n", " return_tensors: str = \"pt\"\n", "\n", " def torch_call(self, features):\n", " import torch\n", "\n", " label_name = \"label\" if \"label\" in features[0].keys() else \"labels\"\n", " labels = [feature[label_name] for feature in features] if label_name in features[0].keys() else None\n", "\n", " no_labels_features = [{k: v for k, v in feature.items() if k != label_name} for feature in features]\n", "\n", " batch = self.tokenizer.pad(\n", " no_labels_features,\n", " padding=self.padding,\n", " max_length=self.max_length,\n", " pad_to_multiple_of=self.pad_to_multiple_of,\n", " return_tensors=\"pt\",\n", " )\n", "\n", " if labels is None:\n", " return batch\n", "\n", " sequence_length = batch[\"input_ids\"].shape[1]\n", " padding_side = self.tokenizer.padding_side\n", "\n", " def to_list(tensor_or_iterable):\n", " if isinstance(tensor_or_iterable, torch.Tensor):\n", " return tensor_or_iterable.tolist()\n", " return list(tensor_or_iterable)\n", "\n", " if padding_side == \"right\":\n", " batch[label_name] = [\n", " # to_list(label) + [self.label_pad_token_id] * (sequence_length - len(label)) for label in labels\n", " # changed to pad the special tokens at the beginning and end of the sequence\n", " [self.label_pad_token_id] + to_list(label) + [self.label_pad_token_id] * (sequence_length - len(label)-1) for label in labels\n", " ]\n", " else:\n", " batch[label_name] = [\n", " [self.label_pad_token_id] * (sequence_length - len(label)) + to_list(label) for label in labels\n", " ]\n", "\n", " batch[label_name] = torch.tensor(batch[label_name], dtype=torch.float)\n", " return batch\n", "\n", "def _torch_collate_batch(examples, tokenizer, pad_to_multiple_of: Optional[int] = None):\n", " \"\"\"Collate `examples` into a batch, using the information in `tokenizer` for padding if necessary.\"\"\"\n", " import torch\n", "\n", " # Tensorize if necessary.\n", " if isinstance(examples[0], (list, tuple, np.ndarray)):\n", " examples = [torch.tensor(e, dtype=torch.long) for e in examples]\n", "\n", " length_of_first = examples[0].size(0)\n", "\n", " # Check if padding is necessary.\n", "\n", " are_tensors_same_length = all(x.size(0) == length_of_first for x in examples)\n", " if are_tensors_same_length and (pad_to_multiple_of is None or length_of_first % pad_to_multiple_of == 0):\n", " return torch.stack(examples, dim=0)\n", "\n", " # If yes, check if we have a `pad_token`.\n", " if tokenizer._pad_token is None:\n", " raise ValueError(\n", " \"You are attempting to pad samples but the tokenizer you are using\"\n", " f\" ({tokenizer.__class__.__name__}) does not have a pad token.\"\n", " )\n", "\n", " # Creating the full tensor and filling it with our data.\n", " max_length = max(x.size(0) for x in examples)\n", " if pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):\n", " max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of\n", " result = examples[0].new_full([len(examples), max_length], tokenizer.pad_token_id)\n", " for i, example in enumerate(examples):\n", " if tokenizer.padding_side == \"right\":\n", " result[i, : example.shape[0]] = example\n", " else:\n", " result[i, -example.shape[0] :] = example\n", " return result\n", "\n", "def tolist(x):\n", " if isinstance(x, 
list):\n", " return x\n", " elif hasattr(x, \"numpy\"): # Checks for TF tensors without needing the import\n", " x = x.numpy()\n", " return x.tolist()" ] }, { "cell_type": "code", "execution_count": 16, "id": "ea511812-1244-4e51-b63c-b4da7822f0b7", "metadata": {}, "outputs": [], "source": [ "#load ESM2 models\n", "def load_esm_model_classification(checkpoint, num_labels, half_precision, full=False, deepspeed=True):\n", " \n", " tokenizer = AutoTokenizer.from_pretrained(checkpoint)\n", "\n", " \n", " if half_precision and deepspeed:\n", " model = EsmForTokenClassificationCustom.from_pretrained(checkpoint, \n", " num_labels = num_labels, \n", " ignore_mismatched_sizes=True,\n", " torch_dtype = torch.float16)\n", " else:\n", " model = EsmForTokenClassificationCustom.from_pretrained(checkpoint, \n", " num_labels = num_labels,\n", " ignore_mismatched_sizes=True)\n", " \n", " if full == True:\n", " return model, tokenizer \n", " \n", " peft_config = LoraConfig(\n", " r=4, lora_alpha=1, bias=\"all\", target_modules=[\"query\",\"key\",\"value\",\"dense\"]\n", " )\n", " \n", " model = inject_adapter_in_model(peft_config, model)\n", "\n", " #model.gradient_checkpointing_enable()\n", " \n", " # Unfreeze the prediction head\n", " for (param_name, param) in model.classifier.named_parameters():\n", " param.requires_grad = True \n", " \n", " return model, tokenizer" ] }, { "cell_type": "code", "execution_count": 22, "id": "8941bbbb-57c5-4f3d-89d9-12b2d306e7a1", "metadata": {}, "outputs": [], "source": [ "checkpoint='../Pretrained/Rostlab/prot_t5_xl_uniref50'\n", "best_model_path='../refined_models/ChallengeFinetuning/Rostlab/prot_t5_xl_uniref50/manual_checkpoint/cpt.pth'\n", "full=False\n", "deepspeed=False\n", "mixed=False \n", "num_labels=2" ] }, { "cell_type": "code", "execution_count": null, "id": "4f007331-34d4-4c1d-9311-e91db23d9ed5", "metadata": {}, "outputs": [], "source": [ "/home/frohlkin/Projects/PLM/Publication/hf_webpage/pretrained" ] }, { "cell_type": "code", "execution_count": 24, "id": "18d4ad06-b195-4cc6-a3c8-fa3e761838dc", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "../Pretrained/Rostlab/prot_t5_xl_uniref50 2 False False False\n", "T5_Classfier\n", "Trainable Parameter: 1209716226\n", "T5_LoRA_Classfier\n", "Trainable Parameter: 4082178\n", "\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ "print(checkpoint, num_labels, mixed, full, deepspeed)\n", " \n", "# Determine model type and load accordingly\n", "if \"esm\" in checkpoint:\n", " model, tokenizer = load_esm_model_classification(checkpoint, num_labels, mixed, full, deepspeed)\n", "else:\n", " model, tokenizer = load_T5_model_classification(checkpoint, num_labels, mixed, full, deepspeed)\n", "\n", "# Load the best model state\n", "state_dict = torch.load(best_model_path, weights_only=True)\n", "model.load_state_dict(state_dict)" ] }, { "cell_type": "code", "execution_count": 30, "id": "4e215923-dfe2-4426-aedf-5cb81f7f0db2", "metadata": {}, "outputs": [], "source": [ "test_one_letter_sequence='AWYAAK'\n", "max_length=1500" ] }, { "cell_type": "code", "execution_count": 40, "id": "7174ea02-ed51-46f5-84c0-6bcd760670d4", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(7,)" ] }, "execution_count": 40, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def create_dataset(tokenizer,seqs,labels,checkpoint):\n", " \n", " tokenized = tokenizer(seqs, max_length=max_length, padding=False, 
truncation=True)\n", " dataset = Dataset.from_dict(tokenized)\n", " \n", " if (\"esm\" in checkpoint) or (\"ProstT5\" in checkpoint):\n", " labels = [l[:max_length-2] for l in labels] \n", " else:\n", " labels = [l[:max_length-1] for l in labels] \n", " \n", " dataset = dataset.add_column(\"labels\", labels)\n", " \n", " return dataset\n", " \n", "def convert_predictions(input_logits):\n", " all_probs = []\n", " for logits in input_logits:\n", " logits = logits.reshape(-1, 2)\n", "\n", " # Mask out irrelevant regions\n", " # Compute probabilities for class 1\n", " probabilities_class1 = expit(logits[:, 1] - logits[:, 0])\n", " \n", " all_probs.append(probabilities_class1)\n", " \n", " return np.concatenate(all_probs)\n", " \n", " \n", "dummy_labels=[np.zeros(len(test_one_letter_sequence))]\n", "# Replace uncommon amino acids with \"X\"\n", "test_one_letter_sequence = test_one_letter_sequence.replace(\"O\", \"X\").replace(\"B\", \"X\").replace(\"U\", \"X\").replace(\"Z\", \"X\").replace(\"J\", \"X\")\n", "\n", "# Add spaces between each amino acid for ProtT5 and ProstT5 models\n", "if \"Rostlab\" in checkpoint:\n", " test_one_letter_sequence = \" \".join(test_one_letter_sequence)\n", "\n", "# Add for ProstT5 model input format\n", "if \"ProstT5\" in checkpoint:\n", " test_one_letter_sequence = \" \" + test_one_letter_sequence\n", " \n", "test_dataset=create_dataset(tokenizer,[test_one_letter_sequence],dummy_labels,checkpoint)\n", "\n", "if (\"esm\" in checkpoint) or (\"ProstT5\" in checkpoint):\n", " data_collator = DataCollatorForTokenClassificationESM(tokenizer)\n", "else:\n", " data_collator = DataCollatorForTokenClassification(tokenizer)\n", "\n", "test_loader = DataLoader(test_dataset, batch_size=1, collate_fn=data_collator)\n", "\n", "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", "model.to(device)\n", "for batch in test_loader:\n", " input_ids = batch['input_ids'].to(device)\n", " attention_mask = batch['attention_mask'].to(device)\n", " labels = batch['labels'] # Ensure to get labels from batch\n", "\n", " outputs = model(input_ids, attention_mask=attention_mask)\n", " logits = outputs.logits.detach().cpu().numpy()\n", "\n", "logits=convert_predictions(logits)\n", "logits.shape\n", "\n", "def normalize_scores(scores):\n", " min_score = np.min(scores)\n", " max_score = np.max(scores)\n", " return (scores - min_score) / (max_score - min_score) if max_score > min_score else scores\n", "\n", "normalized_scores = normalize_scores(logits)\n", "\n", "normalized_scores.shape" ] }, { "cell_type": "code", "execution_count": null, "id": "58b5ae4d-9e8e-4d07-ab46-76d23cc29016", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python [conda env:LLM] *", "language": "python", "name": "conda-env-LLM-py" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.2" } }, "nbformat": 4, "nbformat_minor": 5 }