from huggingface_hub import hf_hub_download import torch import torch.nn as nn import torch.nn.functional as F from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from torch.utils.data import DataLoader import re import numpy as np import os import pandas as pd import copy import transformers, datasets from transformers.modeling_outputs import TokenClassifierOutput from transformers.models.t5.modeling_t5 import T5Config, T5PreTrainedModel, T5Stack from transformers.utils.model_parallel_utils import assert_device_map, get_device_map from transformers import T5EncoderModel, T5Tokenizer from transformers.models.esm.modeling_esm import EsmPreTrainedModel, EsmModel from transformers import AutoTokenizer from transformers import TrainingArguments, Trainer, set_seed from transformers import DataCollatorForTokenClassification from dataclasses import dataclass from typing import Dict, List, Optional, Tuple, Union # for custom DataCollator from transformers.data.data_collator import DataCollatorMixin from transformers.tokenization_utils_base import PreTrainedTokenizerBase from transformers.utils import PaddingStrategy from datasets import Dataset from scipy.special import expit #import peft #from peft import get_peft_config, PeftModel, PeftConfig, inject_adapter_in_model, LoraConfig cnn_head=True #False set True for Rostlab/prot_t5_xl_half_uniref50-enc ffn_head=False #False transformer_head=False custom_lora=True #False #only true for Rostlab/prot_t5_xl_half_uniref50-enc class ClassConfig: def __init__(self, dropout=0.2, num_labels=3): self.dropout_rate = dropout self.num_labels = num_labels class T5EncoderForTokenClassification(T5PreTrainedModel): def __init__(self, config: T5Config, class_config: ClassConfig): super().__init__(config) self.num_labels = class_config.num_labels self.config = config self.shared = nn.Embedding(config.vocab_size, config.d_model) encoder_config = copy.deepcopy(config) encoder_config.use_cache = False encoder_config.is_encoder_decoder = False self.encoder = T5Stack(encoder_config, self.shared) self.dropout = nn.Dropout(class_config.dropout_rate) # Initialize different heads based on class_config if cnn_head: self.cnn = nn.Conv1d(config.hidden_size, 512, kernel_size=3, padding=1) self.classifier = nn.Linear(512, class_config.num_labels) elif ffn_head: # Multi-layer feed-forward network (FFN) head self.ffn = nn.Sequential( nn.Linear(config.hidden_size, 512), nn.ReLU(), nn.Linear(512, 256), nn.ReLU(), nn.Linear(256, class_config.num_labels) ) elif transformer_head: # Transformer layer head encoder_layer = nn.TransformerEncoderLayer(d_model=config.hidden_size, nhead=8) self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=1) self.classifier = nn.Linear(config.hidden_size, class_config.num_labels) else: # Default classification head self.classifier = nn.Linear(config.hidden_size, class_config.num_labels) self.post_init() # Model parallel self.model_parallel = False self.device_map = None def parallelize(self, device_map=None): self.device_map = ( get_device_map(len(self.encoder.block), range(torch.cuda.device_count())) if device_map is None else device_map ) assert_device_map(self.device_map, len(self.encoder.block)) self.encoder.parallelize(self.device_map) self.classifier = self.classifier.to(self.encoder.first_device) self.model_parallel = True def deparallelize(self): self.encoder.deparallelize() self.encoder = self.encoder.to("cpu") self.model_parallel = False self.device_map = None torch.cuda.empty_cache() def get_input_embeddings(self): return self.shared def set_input_embeddings(self, new_embeddings): self.shared = new_embeddings self.encoder.set_input_embeddings(new_embeddings) def get_encoder(self): return self.encoder def _prune_heads(self, heads_to_prune): for layer, heads in heads_to_prune.items(): self.encoder.layer[layer].attention.prune_heads(heads) def forward( self, input_ids=None, attention_mask=None, head_mask=None, inputs_embeds=None, labels=None, output_attentions=None, output_hidden_states=None, return_dict=None, ): return_dict = return_dict if return_dict is not None else self.config.use_return_dict outputs = self.encoder( input_ids=input_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, head_mask=head_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) sequence_output = outputs[0] sequence_output = self.dropout(sequence_output) # Forward pass through the selected head if cnn_head: # CNN head sequence_output = sequence_output.permute(0, 2, 1) # Prepare shape for CNN cnn_output = self.cnn(sequence_output) cnn_output = F.relu(cnn_output) cnn_output = cnn_output.permute(0, 2, 1) # Shape back for classifier logits = self.classifier(cnn_output) elif ffn_head: # FFN head logits = self.ffn(sequence_output) elif transformer_head: # Transformer head transformer_output = self.transformer_encoder(sequence_output) logits = self.classifier(transformer_output) else: # Default classification head logits = self.classifier(sequence_output) loss = None if labels is not None: loss_fct = CrossEntropyLoss() active_loss = attention_mask.view(-1) == 1 active_logits = logits.view(-1, self.num_labels) active_labels = torch.where( active_loss, labels.view(-1), torch.tensor(-100).type_as(labels) ) valid_logits = active_logits[active_labels != -100] valid_labels = active_labels[active_labels != -100] valid_labels = valid_labels.to(valid_logits.device) valid_labels = valid_labels.long() loss = loss_fct(valid_logits, valid_labels) if not return_dict: output = (logits,) + outputs[2:] return ((loss,) + output) if loss is not None else output return TokenClassifierOutput( loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) # Modifies an existing transformer and introduce the LoRA layers class CustomLoRAConfig: def __init__(self): self.lora_rank = 4 self.lora_init_scale = 0.01 self.lora_modules = ".*SelfAttention|.*EncDecAttention" self.lora_layers = "q|k|v|o" self.trainable_param_names = ".*layer_norm.*|.*lora_[ab].*" self.lora_scaling_rank = 1 # lora_modules and lora_layers are speicified with regular expressions # see https://www.w3schools.com/python/python_regex.asp for reference class LoRALinear(nn.Module): def __init__(self, linear_layer, rank, scaling_rank, init_scale): super().__init__() self.in_features = linear_layer.in_features self.out_features = linear_layer.out_features self.rank = rank self.scaling_rank = scaling_rank self.weight = linear_layer.weight self.bias = linear_layer.bias if self.rank > 0: self.lora_a = nn.Parameter(torch.randn(rank, linear_layer.in_features) * init_scale) if init_scale < 0: self.lora_b = nn.Parameter(torch.randn(linear_layer.out_features, rank) * init_scale) else: self.lora_b = nn.Parameter(torch.zeros(linear_layer.out_features, rank)) if self.scaling_rank: self.multi_lora_a = nn.Parameter( torch.ones(self.scaling_rank, linear_layer.in_features) + torch.randn(self.scaling_rank, linear_layer.in_features) * init_scale ) if init_scale < 0: self.multi_lora_b = nn.Parameter( torch.ones(linear_layer.out_features, self.scaling_rank) + torch.randn(linear_layer.out_features, self.scaling_rank) * init_scale ) else: self.multi_lora_b = nn.Parameter(torch.ones(linear_layer.out_features, self.scaling_rank)) def forward(self, input): if self.scaling_rank == 1 and self.rank == 0: # parsimonious implementation for ia3 and lora scaling if self.multi_lora_a.requires_grad: hidden = F.linear((input * self.multi_lora_a.flatten()), self.weight, self.bias) else: hidden = F.linear(input, self.weight, self.bias) if self.multi_lora_b.requires_grad: hidden = hidden * self.multi_lora_b.flatten() return hidden else: # general implementation for lora (adding and scaling) weight = self.weight if self.scaling_rank: weight = weight * torch.matmul(self.multi_lora_b, self.multi_lora_a) / self.scaling_rank if self.rank: weight = weight + torch.matmul(self.lora_b, self.lora_a) / self.rank return F.linear(input, weight, self.bias) def extra_repr(self): return "in_features={}, out_features={}, bias={}, rank={}, scaling_rank={}".format( self.in_features, self.out_features, self.bias is not None, self.rank, self.scaling_rank ) def modify_with_lora(transformer, config): for m_name, module in dict(transformer.named_modules()).items(): if re.fullmatch(config.lora_modules, m_name): for c_name, layer in dict(module.named_children()).items(): if re.fullmatch(config.lora_layers, c_name): assert isinstance( layer, nn.Linear ), f"LoRA can only be applied to torch.nn.Linear, but {layer} is {type(layer)}." setattr( module, c_name, LoRALinear(layer, config.lora_rank, config.lora_scaling_rank, config.lora_init_scale), ) return transformer def load_T5_model_classification(checkpoint, num_labels, half_precision, full = False, deepspeed=True): # Load model and tokenizer if "ankh" in checkpoint : model = T5EncoderModel.from_pretrained(checkpoint) tokenizer = AutoTokenizer.from_pretrained(checkpoint) elif "prot_t5" in checkpoint: # possible to load the half precision model (thanks to @pawel-rezo for pointing that out) if half_precision and deepspeed: #tokenizer = T5Tokenizer.from_pretrained('Rostlab/prot_t5_xl_half_uniref50-enc', do_lower_case=False) #model = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_half_uniref50-enc", torch_dtype=torch.float16)#.to(torch.device('cuda') tokenizer = T5Tokenizer.from_pretrained(checkpoint, do_lower_case=False) model = T5EncoderModel.from_pretrained(checkpoint, torch_dtype=torch.float16).to(torch.device('cuda')) else: model = T5EncoderModel.from_pretrained(checkpoint) tokenizer = T5Tokenizer.from_pretrained(checkpoint) elif "ProstT5" in checkpoint: if half_precision and deepspeed: tokenizer = T5Tokenizer.from_pretrained(checkpoint, do_lower_case=False) model = T5EncoderModel.from_pretrained(checkpoint, torch_dtype=torch.float16).to(torch.device('cuda')) else: model = T5EncoderModel.from_pretrained(checkpoint) tokenizer = T5Tokenizer.from_pretrained(checkpoint) # Create new Classifier model with PT5 dimensions class_config=ClassConfig(num_labels=num_labels) class_model=T5EncoderForTokenClassification(model.config,class_config) # Set encoder and embedding weights to checkpoint weights class_model.shared=model.shared class_model.encoder=model.encoder # Delete the checkpoint model model=class_model del class_model if full == True: return model, tokenizer # Print number of trainable parameters model_parameters = filter(lambda p: p.requires_grad, model.parameters()) params = sum([np.prod(p.size()) for p in model_parameters]) print("T5_Classfier\nTrainable Parameter: "+ str(params)) if custom_lora: #the linear CustomLoRAConfig allows better quality predictions, but more memory is needed # Add model modification lora config = CustomLoRAConfig() # Add LoRA layers model = modify_with_lora(model, config) # Freeze Embeddings and Encoder (except LoRA) for (param_name, param) in model.shared.named_parameters(): param.requires_grad = False for (param_name, param) in model.encoder.named_parameters(): param.requires_grad = False for (param_name, param) in model.named_parameters(): if re.fullmatch(config.trainable_param_names, param_name): param.requires_grad = True else: # lora modification peft_config = LoraConfig( r=4, lora_alpha=1, bias="all", target_modules=["q","k","v","o"] ) model = inject_adapter_in_model(peft_config, model) # Unfreeze the prediction head for (param_name, param) in model.classifier.named_parameters(): param.requires_grad = True # Print trainable Parameter model_parameters = filter(lambda p: p.requires_grad, model.parameters()) params = sum([np.prod(p.size()) for p in model_parameters]) print("T5_LoRA_Classfier\nTrainable Parameter: "+ str(params) + "\n") return model, tokenizer class EsmForTokenClassificationCustom(EsmPreTrainedModel): _keys_to_ignore_on_load_unexpected = [r"pooler"] _keys_to_ignore_on_load_missing = [r"position_ids", r"cnn", r"ffn", r"transformer"] def __init__(self, config): super().__init__(config) self.num_labels = config.num_labels self.esm = EsmModel(config, add_pooling_layer=False) self.dropout = nn.Dropout(config.hidden_dropout_prob) if cnn_head: self.cnn = nn.Conv1d(config.hidden_size, 512, kernel_size=3, padding=1) self.classifier = nn.Linear(512, config.num_labels) elif ffn_head: # Multi-layer feed-forward network (FFN) as an alternative head self.ffn = nn.Sequential( nn.Linear(config.hidden_size, 512), nn.ReLU(), nn.Linear(512, 256), nn.ReLU(), nn.Linear(256, config.num_labels) ) elif transformer_head: # Transformer layer as an alternative head encoder_layer = nn.TransformerEncoderLayer(d_model=config.hidden_size, nhead=8) self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=1) self.classifier = nn.Linear(config.hidden_size, config.num_labels) else: # Default classification head self.classifier = nn.Linear(config.hidden_size, config.num_labels) self.init_weights() def forward( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, head_mask: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple, TokenClassifierOutput]: return_dict = return_dict if return_dict is not None else self.config.use_return_dict outputs = self.esm( input_ids, attention_mask=attention_mask, position_ids=position_ids, head_mask=head_mask, inputs_embeds=inputs_embeds, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) sequence_output = outputs[0] sequence_output = self.dropout(sequence_output) if cnn_head: sequence_output = sequence_output.transpose(1, 2) sequence_output = self.cnn(sequence_output) sequence_output = sequence_output.transpose(1, 2) logits = self.classifier(sequence_output) elif ffn_head: logits = self.ffn(sequence_output) elif transformer_head: # Apply transformer encoder for the transformer head sequence_output = self.transformer_encoder(sequence_output) logits = self.classifier(sequence_output) else: logits = self.classifier(sequence_output) loss = None if labels is not None: loss_fct = CrossEntropyLoss() active_loss = attention_mask.view(-1) == 1 active_logits = logits.view(-1, self.num_labels) active_labels = torch.where( active_loss, labels.view(-1), torch.tensor(-100).type_as(labels) ) valid_logits = active_logits[active_labels != -100] valid_labels = active_labels[active_labels != -100] valid_labels = valid_labels.type(torch.LongTensor).to('cuda:0') loss = loss_fct(valid_logits, valid_labels) if not return_dict: output = (logits,) + outputs[2:] return ((loss,) + output) if loss is not None else output return TokenClassifierOutput( loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) def _init_weights(self, module): if isinstance(module, nn.Linear) or isinstance(module, nn.Conv1d): module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() # based on transformers DataCollatorForTokenClassification @dataclass class DataCollatorForTokenClassificationESM(DataCollatorMixin): """ Data collator that will dynamically pad the inputs received, as well as the labels. Args: tokenizer ([`PreTrainedTokenizer`] or [`PreTrainedTokenizerFast`]): The tokenizer used for encoding the data. padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`): Select a strategy to pad the returned sequences (according to the model's padding side and padding index) among: - `True` or `'longest'` (default): Pad to the longest sequence in the batch (or no padding if only a single sequence is provided). - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum acceptable input length for the model if that argument is not provided. - `False` or `'do_not_pad'`: No padding (i.e., can output a batch with sequences of different lengths). max_length (`int`, *optional*): Maximum length of the returned list and optionally padding length (see above). pad_to_multiple_of (`int`, *optional*): If set will pad the sequence to a multiple of the provided value. This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >= 7.5 (Volta). label_pad_token_id (`int`, *optional*, defaults to -100): The id to use when padding the labels (-100 will be automatically ignore by PyTorch loss functions). return_tensors (`str`): The type of Tensor to return. Allowable values are "np", "pt" and "tf". """ tokenizer: PreTrainedTokenizerBase padding: Union[bool, str, PaddingStrategy] = True max_length: Optional[int] = None pad_to_multiple_of: Optional[int] = None label_pad_token_id: int = -100 return_tensors: str = "pt" def torch_call(self, features): import torch label_name = "label" if "label" in features[0].keys() else "labels" labels = [feature[label_name] for feature in features] if label_name in features[0].keys() else None no_labels_features = [{k: v for k, v in feature.items() if k != label_name} for feature in features] batch = self.tokenizer.pad( no_labels_features, padding=self.padding, max_length=self.max_length, pad_to_multiple_of=self.pad_to_multiple_of, return_tensors="pt", ) if labels is None: return batch sequence_length = batch["input_ids"].shape[1] padding_side = self.tokenizer.padding_side def to_list(tensor_or_iterable): if isinstance(tensor_or_iterable, torch.Tensor): return tensor_or_iterable.tolist() return list(tensor_or_iterable) if padding_side == "right": batch[label_name] = [ # to_list(label) + [self.label_pad_token_id] * (sequence_length - len(label)) for label in labels # changed to pad the special tokens at the beginning and end of the sequence [self.label_pad_token_id] + to_list(label) + [self.label_pad_token_id] * (sequence_length - len(label)-1) for label in labels ] else: batch[label_name] = [ [self.label_pad_token_id] * (sequence_length - len(label)) + to_list(label) for label in labels ] batch[label_name] = torch.tensor(batch[label_name], dtype=torch.float) return batch def _torch_collate_batch(examples, tokenizer, pad_to_multiple_of: Optional[int] = None): """Collate `examples` into a batch, using the information in `tokenizer` for padding if necessary.""" import torch # Tensorize if necessary. if isinstance(examples[0], (list, tuple, np.ndarray)): examples = [torch.tensor(e, dtype=torch.long) for e in examples] length_of_first = examples[0].size(0) # Check if padding is necessary. are_tensors_same_length = all(x.size(0) == length_of_first for x in examples) if are_tensors_same_length and (pad_to_multiple_of is None or length_of_first % pad_to_multiple_of == 0): return torch.stack(examples, dim=0) # If yes, check if we have a `pad_token`. if tokenizer._pad_token is None: raise ValueError( "You are attempting to pad samples but the tokenizer you are using" f" ({tokenizer.__class__.__name__}) does not have a pad token." ) # Creating the full tensor and filling it with our data. max_length = max(x.size(0) for x in examples) if pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0): max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of result = examples[0].new_full([len(examples), max_length], tokenizer.pad_token_id) for i, example in enumerate(examples): if tokenizer.padding_side == "right": result[i, : example.shape[0]] = example else: result[i, -example.shape[0] :] = example return result def tolist(x): if isinstance(x, list): return x elif hasattr(x, "numpy"): # Checks for TF tensors without needing the import x = x.numpy() return x.tolist() #load ESM2 models def load_esm_model_classification(checkpoint, num_labels, half_precision, full=False, deepspeed=True): tokenizer = AutoTokenizer.from_pretrained(checkpoint) if half_precision and deepspeed: model = EsmForTokenClassificationCustom.from_pretrained(checkpoint, num_labels = num_labels, ignore_mismatched_sizes=True, torch_dtype = torch.float16) else: model = EsmForTokenClassificationCustom.from_pretrained(checkpoint, num_labels = num_labels, ignore_mismatched_sizes=True) if full == True: return model, tokenizer peft_config = LoraConfig( r=4, lora_alpha=1, bias="all", target_modules=["query","key","value","dense"] ) model = inject_adapter_in_model(peft_config, model) #model.gradient_checkpointing_enable() # Unfreeze the prediction head for (param_name, param) in model.classifier.named_parameters(): param.requires_grad = True return model, tokenizer def load_model(checkpoint,max_length): #checkpoint='ThorbenF/prot_t5_xl_uniref50' #best_model_path='ThorbenF/prot_t5_xl_uniref50/cpt.pth' full=False deepspeed=False mixed=False num_labels=2 print(checkpoint, num_labels, mixed, full, deepspeed) # Determine model type and load accordingly if "esm" in checkpoint: model, tokenizer = load_esm_model_classification(checkpoint, num_labels, mixed, full, deepspeed) else: model, tokenizer = load_T5_model_classification(checkpoint, num_labels, mixed, full, deepspeed) # Download the file local_file = hf_hub_download(repo_id=checkpoint, filename="cpt.pth") # Load the best model state state_dict = torch.load(local_file, map_location=torch.device('cpu'), weights_only=True) model.load_state_dict(state_dict) return model, tokenizer