from huggingface_hub import hf_hub_download
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from torch.utils.data import DataLoader
import re
import numpy as np
import os
import pandas as pd
import copy

import transformers, datasets
from transformers.modeling_outputs import TokenClassifierOutput
from transformers.models.t5.modeling_t5 import T5Config, T5PreTrainedModel, T5Stack
from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
from transformers import T5EncoderModel, T5Tokenizer
from transformers.models.esm.modeling_esm import EsmPreTrainedModel, EsmModel
from transformers import AutoTokenizer
from transformers import TrainingArguments, Trainer, set_seed
from transformers import DataCollatorForTokenClassification

from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union

# for custom DataCollator
from transformers.data.data_collator import DataCollatorMixin
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from transformers.utils import PaddingStrategy

from datasets import Dataset

from scipy.special import expit

# peft is required below (LoraConfig, inject_adapter_in_model) for the ESM loader and the
# non-custom LoRA branch of the T5 loader
import peft
from peft import get_peft_config, PeftModel, PeftConfig, inject_adapter_in_model, LoraConfig

# Select the classification head and LoRA variant
cnn_head = True          # set True for Rostlab/prot_t5_xl_half_uniref50-enc
ffn_head = False
transformer_head = False
custom_lora = True       # only True for Rostlab/prot_t5_xl_half_uniref50-enc

class ClassConfig:
    def __init__(self, dropout=0.2, num_labels=3):
        self.dropout_rate = dropout
        self.num_labels = num_labels

class T5EncoderForTokenClassification(T5PreTrainedModel):

    def __init__(self, config: T5Config, class_config: ClassConfig):
        super().__init__(config)
        self.num_labels = class_config.num_labels
        self.config = config

        self.shared = nn.Embedding(config.vocab_size, config.d_model)

        encoder_config = copy.deepcopy(config)
        encoder_config.use_cache = False
        encoder_config.is_encoder_decoder = False
        self.encoder = T5Stack(encoder_config, self.shared)

        self.dropout = nn.Dropout(class_config.dropout_rate)

        # Initialize different heads based on class_config
        if cnn_head:
            self.cnn = nn.Conv1d(config.hidden_size, 512, kernel_size=3, padding=1)
            self.classifier = nn.Linear(512, class_config.num_labels)
        elif ffn_head:
            # Multi-layer feed-forward network (FFN) head
            self.ffn = nn.Sequential(
                nn.Linear(config.hidden_size, 512),
                nn.ReLU(),
                nn.Linear(512, 256),
                nn.ReLU(),
                nn.Linear(256, class_config.num_labels)
            )
        elif transformer_head:
            # Transformer layer head
            encoder_layer = nn.TransformerEncoderLayer(d_model=config.hidden_size, nhead=8)
            self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=1)
            self.classifier = nn.Linear(config.hidden_size, class_config.num_labels)
        else:
            # Default classification head
            self.classifier = nn.Linear(config.hidden_size, class_config.num_labels)

        self.post_init()

        # Model parallel
        self.model_parallel = False
        self.device_map = None

    def parallelize(self, device_map=None):
        self.device_map = (
            get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
            if device_map is None
            else device_map
        )
        assert_device_map(self.device_map, len(self.encoder.block))
        self.encoder.parallelize(self.device_map)
        self.classifier = self.classifier.to(self.encoder.first_device)
        self.model_parallel = True

    def deparallelize(self):
        self.encoder.deparallelize()
        self.encoder = self.encoder.to("cpu")
        self.model_parallel = False
        self.device_map = None
        torch.cuda.empty_cache()

    def get_input_embeddings(self):
        return self.shared

    def set_input_embeddings(self, new_embeddings):
        self.shared = new_embeddings
        self.encoder.set_input_embeddings(new_embeddings)

    def get_encoder(self):
        return self.encoder

    def _prune_heads(self, heads_to_prune):
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]
        sequence_output = self.dropout(sequence_output)

        # Forward pass through the selected head
        if cnn_head:
            # CNN head
            sequence_output = sequence_output.permute(0, 2, 1)  # Prepare shape for CNN
            cnn_output = self.cnn(sequence_output)
            cnn_output = F.relu(cnn_output)
            cnn_output = cnn_output.permute(0, 2, 1)  # Shape back for classifier
            logits = self.classifier(cnn_output)
        elif ffn_head:
            # FFN head
            logits = self.ffn(sequence_output)
        elif transformer_head:
            # Transformer head
            transformer_output = self.transformer_encoder(sequence_output)
            logits = self.classifier(transformer_output)
        else:
            # Default classification head
            logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            active_loss = attention_mask.view(-1) == 1
            active_logits = logits.view(-1, self.num_labels)
            active_labels = torch.where(
                active_loss, labels.view(-1), torch.tensor(-100).type_as(labels)
            )
            valid_logits = active_logits[active_labels != -100]
            valid_labels = active_labels[active_labels != -100]
            valid_labels = valid_labels.to(valid_logits.device)
            valid_labels = valid_labels.long()
            loss = loss_fct(valid_logits, valid_labels)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
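
# Note: all heads above map the encoder output of shape (batch, seq_len, hidden_size) to
# per-residue logits of shape (batch, seq_len, num_labels); positions masked out by
# attention_mask, as well as label positions set to -100, are excluded from the loss.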

# Modifies an existing transformer and introduces the LoRA layers
class CustomLoRAConfig:
    def __init__(self):
        self.lora_rank = 4
        self.lora_init_scale = 0.01
        self.lora_modules = ".*SelfAttention|.*EncDecAttention"
        self.lora_layers = "q|k|v|o"
        self.trainable_param_names = ".*layer_norm.*|.*lora_[ab].*"
        self.lora_scaling_rank = 1
        # lora_modules and lora_layers are specified with regular expressions
        # see https://www.w3schools.com/python/python_regex.asp for reference
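        # For example, a module named "encoder.block.0.layer.0.SelfAttention" matches
        # lora_modules, and its child linear layers "q", "k", "v" and "o" match lora_layers,
        # so those projections get wrapped by LoRALinear in modify_with_lora below.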

class LoRALinear(nn.Module):
    def __init__(self, linear_layer, rank, scaling_rank, init_scale):
        super().__init__()
        self.in_features = linear_layer.in_features
        self.out_features = linear_layer.out_features
        self.rank = rank
        self.scaling_rank = scaling_rank
        self.weight = linear_layer.weight
        self.bias = linear_layer.bias
        if self.rank > 0:
            self.lora_a = nn.Parameter(torch.randn(rank, linear_layer.in_features) * init_scale)
            if init_scale < 0:
                self.lora_b = nn.Parameter(torch.randn(linear_layer.out_features, rank) * init_scale)
            else:
                self.lora_b = nn.Parameter(torch.zeros(linear_layer.out_features, rank))
        if self.scaling_rank:
            self.multi_lora_a = nn.Parameter(
                torch.ones(self.scaling_rank, linear_layer.in_features)
                + torch.randn(self.scaling_rank, linear_layer.in_features) * init_scale
            )
            if init_scale < 0:
                self.multi_lora_b = nn.Parameter(
                    torch.ones(linear_layer.out_features, self.scaling_rank)
                    + torch.randn(linear_layer.out_features, self.scaling_rank) * init_scale
                )
            else:
                self.multi_lora_b = nn.Parameter(torch.ones(linear_layer.out_features, self.scaling_rank))

    def forward(self, input):
        if self.scaling_rank == 1 and self.rank == 0:
            # parsimonious implementation for ia3 and lora scaling
            if self.multi_lora_a.requires_grad:
                hidden = F.linear((input * self.multi_lora_a.flatten()), self.weight, self.bias)
            else:
                hidden = F.linear(input, self.weight, self.bias)
            if self.multi_lora_b.requires_grad:
                hidden = hidden * self.multi_lora_b.flatten()
            return hidden
        else:
            # general implementation for lora (adding and scaling)
            weight = self.weight
            if self.scaling_rank:
                weight = weight * torch.matmul(self.multi_lora_b, self.multi_lora_a) / self.scaling_rank
            if self.rank:
                weight = weight + torch.matmul(self.lora_b, self.lora_a) / self.rank
            return F.linear(input, weight, self.bias)

    def extra_repr(self):
        return "in_features={}, out_features={}, bias={}, rank={}, scaling_rank={}".format(
            self.in_features, self.out_features, self.bias is not None, self.rank, self.scaling_rank
        )
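
# In the general branch of LoRALinear.forward, the effective weight applied to the input is
#   W' = W * (multi_lora_b @ multi_lora_a) / scaling_rank + (lora_b @ lora_a) / rank
# i.e. an (IA)^3-style multiplicative scaling combined with a standard additive low-rank update.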

def modify_with_lora(transformer, config):
    for m_name, module in dict(transformer.named_modules()).items():
        if re.fullmatch(config.lora_modules, m_name):
            for c_name, layer in dict(module.named_children()).items():
                if re.fullmatch(config.lora_layers, c_name):
                    assert isinstance(
                        layer, nn.Linear
                    ), f"LoRA can only be applied to torch.nn.Linear, but {layer} is {type(layer)}."
                    setattr(
                        module,
                        c_name,
                        LoRALinear(layer, config.lora_rank, config.lora_scaling_rank, config.lora_init_scale),
                    )
    return transformer
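
# Minimal usage sketch for modify_with_lora (illustration only; the full wiring is done in
# load_T5_model_classification below, assuming a plain T5 encoder has already been loaded):
#
#   base = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_half_uniref50-enc")
#   base = modify_with_lora(base, CustomLoRAConfig())
#   for name, p in base.named_parameters():
#       p.requires_grad = bool(re.fullmatch(CustomLoRAConfig().trainable_param_names, name))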

def load_T5_model_classification(checkpoint, num_labels, half_precision, full=False, deepspeed=True):
    # Load model and tokenizer
    if "ankh" in checkpoint:
        model = T5EncoderModel.from_pretrained(checkpoint)
        tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    elif "prot_t5" in checkpoint:
        # possible to load the half-precision model (thanks to @pawel-rezo for pointing that out)
        if half_precision and deepspeed:
            #tokenizer = T5Tokenizer.from_pretrained('Rostlab/prot_t5_xl_half_uniref50-enc', do_lower_case=False)
            #model = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_half_uniref50-enc", torch_dtype=torch.float16)#.to(torch.device('cuda')
            tokenizer = T5Tokenizer.from_pretrained(checkpoint, do_lower_case=False)
            model = T5EncoderModel.from_pretrained(checkpoint, torch_dtype=torch.float16).to(torch.device('cuda'))
        else:
            model = T5EncoderModel.from_pretrained(checkpoint)
            tokenizer = T5Tokenizer.from_pretrained(checkpoint)
    elif "ProstT5" in checkpoint:
        if half_precision and deepspeed:
            tokenizer = T5Tokenizer.from_pretrained(checkpoint, do_lower_case=False)
            model = T5EncoderModel.from_pretrained(checkpoint, torch_dtype=torch.float16).to(torch.device('cuda'))
        else:
            model = T5EncoderModel.from_pretrained(checkpoint)
            tokenizer = T5Tokenizer.from_pretrained(checkpoint)

    # Create new classifier model with PT5 dimensions
    class_config = ClassConfig(num_labels=num_labels)
    class_model = T5EncoderForTokenClassification(model.config, class_config)

    # Set encoder and embedding weights to checkpoint weights
    class_model.shared = model.shared
    class_model.encoder = model.encoder

    # Delete the checkpoint model
    model = class_model
    del class_model

    if full:
        return model, tokenizer

    # Print number of trainable parameters
    model_parameters = filter(lambda p: p.requires_grad, model.parameters())
    params = sum([np.prod(p.size()) for p in model_parameters])
    print("T5_Classifier\nTrainable Parameters: " + str(params))

    if custom_lora:
        # The linear CustomLoRAConfig allows better-quality predictions, but needs more memory
        # Add model modification LoRA
        config = CustomLoRAConfig()

        # Add LoRA layers
        model = modify_with_lora(model, config)

        # Freeze Embeddings and Encoder (except LoRA)
        for (param_name, param) in model.shared.named_parameters():
            param.requires_grad = False
        for (param_name, param) in model.encoder.named_parameters():
            param.requires_grad = False
        for (param_name, param) in model.named_parameters():
            if re.fullmatch(config.trainable_param_names, param_name):
                param.requires_grad = True
    else:
        # LoRA modification via peft
        peft_config = LoraConfig(
            r=4, lora_alpha=1, bias="all", target_modules=["q", "k", "v", "o"]
        )
        model = inject_adapter_in_model(peft_config, model)

    # Unfreeze the prediction head
    for (param_name, param) in model.classifier.named_parameters():
        param.requires_grad = True

    # Print trainable parameters
    model_parameters = filter(lambda p: p.requires_grad, model.parameters())
    params = sum([np.prod(p.size()) for p in model_parameters])
    print("T5_LoRA_Classifier\nTrainable Parameters: " + str(params) + "\n")

    return model, tokenizer
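
# Example (sketch): build the ProtT5 classifier with two labels, keeping only LoRA weights
# and the prediction head trainable:
#   model, tokenizer = load_T5_model_classification(
#       "Rostlab/prot_t5_xl_half_uniref50-enc", num_labels=2,
#       half_precision=False, full=False, deepspeed=False,
#   )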

class EsmForTokenClassificationCustom(EsmPreTrainedModel):
    _keys_to_ignore_on_load_unexpected = [r"pooler"]
    _keys_to_ignore_on_load_missing = [r"position_ids", r"cnn", r"ffn", r"transformer"]

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.esm = EsmModel(config, add_pooling_layer=False)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        if cnn_head:
            self.cnn = nn.Conv1d(config.hidden_size, 512, kernel_size=3, padding=1)
            self.classifier = nn.Linear(512, config.num_labels)
        elif ffn_head:
            # Multi-layer feed-forward network (FFN) as an alternative head
            self.ffn = nn.Sequential(
                nn.Linear(config.hidden_size, 512),
                nn.ReLU(),
                nn.Linear(512, 256),
                nn.ReLU(),
                nn.Linear(256, config.num_labels)
            )
        elif transformer_head:
            # Transformer layer as an alternative head
            encoder_layer = nn.TransformerEncoderLayer(d_model=config.hidden_size, nhead=8)
            self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=1)
            self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        else:
            # Default classification head
            self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        self.init_weights()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, TokenClassifierOutput]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.esm(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]
        sequence_output = self.dropout(sequence_output)

        if cnn_head:
            sequence_output = sequence_output.transpose(1, 2)
            sequence_output = self.cnn(sequence_output)
            sequence_output = sequence_output.transpose(1, 2)
            logits = self.classifier(sequence_output)
        elif ffn_head:
            logits = self.ffn(sequence_output)
        elif transformer_head:
            # Apply transformer encoder for the transformer head
            sequence_output = self.transformer_encoder(sequence_output)
            logits = self.classifier(sequence_output)
        else:
            logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            active_loss = attention_mask.view(-1) == 1
            active_logits = logits.view(-1, self.num_labels)
            active_labels = torch.where(
                active_loss, labels.view(-1), torch.tensor(-100).type_as(labels)
            )
            valid_logits = active_logits[active_labels != -100]
            valid_labels = active_labels[active_labels != -100]
            # move the labels to the logits device instead of hard-coding 'cuda:0'
            valid_labels = valid_labels.long().to(valid_logits.device)
            loss = loss_fct(valid_logits, valid_labels)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def _init_weights(self, module):
        if isinstance(module, nn.Linear) or isinstance(module, nn.Conv1d):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()

# based on transformers DataCollatorForTokenClassification
@dataclass
class DataCollatorForTokenClassificationESM(DataCollatorMixin):
    """
    Data collator that will dynamically pad the inputs received, as well as the labels.
    Args:
        tokenizer ([`PreTrainedTokenizer`] or [`PreTrainedTokenizerFast`]):
            The tokenizer used for encoding the data.
        padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
            Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
            among:
            - `True` or `'longest'` (default): Pad to the longest sequence in the batch (or no padding if only a single
              sequence is provided).
            - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
              acceptable input length for the model if that argument is not provided.
            - `False` or `'do_not_pad'`: No padding (i.e., can output a batch with sequences of different lengths).
        max_length (`int`, *optional*):
            Maximum length of the returned list and optionally padding length (see above).
        pad_to_multiple_of (`int`, *optional*):
            If set will pad the sequence to a multiple of the provided value.
            This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
            7.5 (Volta).
        label_pad_token_id (`int`, *optional*, defaults to -100):
            The id to use when padding the labels (-100 will be automatically ignored by PyTorch loss functions).
        return_tensors (`str`):
            The type of Tensor to return. Allowable values are "np", "pt" and "tf".
    """

    tokenizer: PreTrainedTokenizerBase
    padding: Union[bool, str, PaddingStrategy] = True
    max_length: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    label_pad_token_id: int = -100
    return_tensors: str = "pt"

    def torch_call(self, features):
        import torch

        label_name = "label" if "label" in features[0].keys() else "labels"
        labels = [feature[label_name] for feature in features] if label_name in features[0].keys() else None

        no_labels_features = [{k: v for k, v in feature.items() if k != label_name} for feature in features]

        batch = self.tokenizer.pad(
            no_labels_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )

        if labels is None:
            return batch

        sequence_length = batch["input_ids"].shape[1]
        padding_side = self.tokenizer.padding_side

        def to_list(tensor_or_iterable):
            if isinstance(tensor_or_iterable, torch.Tensor):
                return tensor_or_iterable.tolist()
            return list(tensor_or_iterable)

        if padding_side == "right":
            batch[label_name] = [
                # to_list(label) + [self.label_pad_token_id] * (sequence_length - len(label)) for label in labels
                # changed to pad the special tokens at the beginning and end of the sequence
                [self.label_pad_token_id] + to_list(label) + [self.label_pad_token_id] * (sequence_length - len(label) - 1)
                for label in labels
            ]
        else:
            batch[label_name] = [
                [self.label_pad_token_id] * (sequence_length - len(label)) + to_list(label) for label in labels
            ]

        batch[label_name] = torch.tensor(batch[label_name], dtype=torch.float)
        return batch

def _torch_collate_batch(examples, tokenizer, pad_to_multiple_of: Optional[int] = None):
    """Collate `examples` into a batch, using the information in `tokenizer` for padding if necessary."""
    import torch

    # Tensorize if necessary.
    if isinstance(examples[0], (list, tuple, np.ndarray)):
        examples = [torch.tensor(e, dtype=torch.long) for e in examples]

    length_of_first = examples[0].size(0)

    # Check if padding is necessary.
    are_tensors_same_length = all(x.size(0) == length_of_first for x in examples)
    if are_tensors_same_length and (pad_to_multiple_of is None or length_of_first % pad_to_multiple_of == 0):
        return torch.stack(examples, dim=0)

    # If yes, check if we have a `pad_token`.
    if tokenizer._pad_token is None:
        raise ValueError(
            "You are attempting to pad samples but the tokenizer you are using"
            f" ({tokenizer.__class__.__name__}) does not have a pad token."
        )

    # Creating the full tensor and filling it with our data.
    max_length = max(x.size(0) for x in examples)
    if pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
        max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
    result = examples[0].new_full([len(examples), max_length], tokenizer.pad_token_id)
    for i, example in enumerate(examples):
        if tokenizer.padding_side == "right":
            result[i, : example.shape[0]] = example
        else:
            result[i, -example.shape[0]:] = example
    return result


def tolist(x):
    if isinstance(x, list):
        return x
    elif hasattr(x, "numpy"):  # Checks for TF tensors without needing the import
        x = x.numpy()
    return x.tolist()
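
# Minimal usage sketch for the collator (variable names are illustrative only):
#   collator = DataCollatorForTokenClassificationESM(tokenizer=tokenizer)
#   batch = collator([{"input_ids": ids, "attention_mask": mask, "labels": per_residue_labels}])
# Labels are padded with -100 at the special-token positions so the loss ignores them.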

# load ESM2 models
def load_esm_model_classification(checkpoint, num_labels, half_precision, full=False, deepspeed=True):
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)

    if half_precision and deepspeed:
        model = EsmForTokenClassificationCustom.from_pretrained(checkpoint,
                                                                num_labels=num_labels,
                                                                ignore_mismatched_sizes=True,
                                                                torch_dtype=torch.float16)
    else:
        model = EsmForTokenClassificationCustom.from_pretrained(checkpoint,
                                                                num_labels=num_labels,
                                                                ignore_mismatched_sizes=True)

    if full:
        return model, tokenizer

    peft_config = LoraConfig(
        r=4, lora_alpha=1, bias="all", target_modules=["query", "key", "value", "dense"]
    )
    model = inject_adapter_in_model(peft_config, model)

    #model.gradient_checkpointing_enable()

    # Unfreeze the prediction head
    for (param_name, param) in model.classifier.named_parameters():
        param.requires_grad = True

    return model, tokenizer

def load_model(checkpoint, max_length):
    #checkpoint='ThorbenF/prot_t5_xl_uniref50'
    #best_model_path='ThorbenF/prot_t5_xl_uniref50/cpt.pth'
    full = False
    deepspeed = False
    mixed = False
    num_labels = 2

    print(checkpoint, num_labels, mixed, full, deepspeed)

    # Determine model type and load accordingly
    if "esm" in checkpoint:
        model, tokenizer = load_esm_model_classification(checkpoint, num_labels, mixed, full, deepspeed)
    else:
        model, tokenizer = load_T5_model_classification(checkpoint, num_labels, mixed, full, deepspeed)

    # Download the fine-tuned weights (cpt.pth) from the checkpoint repo
    local_file = hf_hub_download(repo_id=checkpoint, filename="cpt.pth")

    # Load the best model state
    state_dict = torch.load(local_file, map_location=torch.device('cpu'), weights_only=True)
    model.load_state_dict(state_dict)

    return model, tokenizer
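
# Usage sketch (the checkpoint name and the toy sequence below are illustrative; the Space
# itself wires load_model into its app code):
if __name__ == "__main__":
    checkpoint = "ThorbenF/prot_t5_xl_uniref50"  # repo is expected to contain the fine-tuned cpt.pth
    model, tokenizer = load_model(checkpoint, max_length=1500)
    model.eval()

    sequence = "MSEQNNTEMTFQIQRIYTKDIS"  # toy amino-acid sequence
    spaced = " ".join(list(sequence))    # ProtT5 tokenizers expect space-separated residues
    inputs = tokenizer(spaced, return_tensors="pt")

    with torch.no_grad():
        logits = model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"]).logits

    # per-residue probability of the positive class, dropping the trailing special token
    probs = torch.softmax(logits, dim=-1)[0, :-1, 1]
    print(probs.shape, probs[:5])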