import gc

import torch
from typing import TYPE_CHECKING, Tuple

from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList

if TYPE_CHECKING:
    from transformers.modeling_utils import PreTrainedModel


class AverageMeter:
    r"""
    Computes and stores the average and current value.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
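
# Illustrative usage (a sketch, not part of the original module): track a running loss
# average across training steps. `train_dataloader` and `compute_loss` are hypothetical
# placeholders here.
#
#     loss_meter = AverageMeter()
#     for batch in train_dataloader:
#         loss = compute_loss(batch)
#         loss_meter.update(loss.item(), n=len(batch))
#     print(f"last loss: {loss_meter.val:.4f} | running average: {loss_meter.avg:.4f}")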


def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
    r"""
    Returns the number of trainable parameters and the total number of parameters in the model.
    """
    trainable_params, all_param = 0, 0
    for param in model.parameters():
        num_params = param.numel()
        # if using DeepSpeed ZeRO-3, the weights are initialized empty and the real size lives in `ds_numel`
        if num_params == 0 and hasattr(param, "ds_numel"):
            num_params = param.ds_numel

        # bitsandbytes 4-bit layers (Params4bit) pack two 4-bit values per byte, so `numel()`
        # reports half of the real parameter count; multiply by 2 to compensate
        if param.__class__.__name__ == "Params4bit":
            num_params = num_params * 2

        all_param += num_params
        if param.requires_grad:
            trainable_params += num_params

    return trainable_params, all_param
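
# Illustrative usage (assumed, not from the original file): report parameter counts for a
# loaded model, in the style of the "trainable params" summary printed by PEFT.
#
#     trainable, total = count_parameters(model)
#     print(f"trainable params: {trainable} || all params: {total} || "
#           f"trainable%: {100 * trainable / total:.4f}")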


def get_logits_processor() -> LogitsProcessorList:
    r"""
    Returns a logits processor list that removes `inf` and `nan` logits during generation.
    """
    logits_processor = LogitsProcessorList()
    logits_processor.append(InfNanRemoveLogitsProcessor())
    return logits_processor
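
# Illustrative usage (assumed): pass the processor list to `model.generate` so that any
# inf/nan logits are removed before sampling.
#
#     outputs = model.generate(**inputs, logits_processor=get_logits_processor())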


def torch_gc() -> None:
    r"""
    Collects GPU memory.
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
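
# Illustrative usage (assumed): free cached GPU memory after a model is no longer needed.
#
#     del model
#     torch_gc()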


def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel":
    r"""
    Dispatches a pre-trained model to GPUs with balanced memory.

    Borrowed from: https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/modeling_utils.py#L2803
    """
    if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):  # do nothing
        return model

    if torch.cuda.device_count() > 1:
        from accelerate import dispatch_model
        from accelerate.utils import infer_auto_device_map, get_balanced_memory

        if model._no_split_modules is None:
            raise ValueError("The model class needs to implement the `_no_split_modules` attribute.")

        kwargs = {"dtype": model.dtype, "no_split_module_classes": model._no_split_modules}
        max_memory = get_balanced_memory(model, **kwargs)
        # Make sure tied weights are tied before creating the device map.
        model.tie_weights()
        device_map = infer_auto_device_map(model, max_memory=max_memory, **kwargs)
        return dispatch_model(model, device_map)
    else:
        return model.cuda()
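
# Illustrative usage (a sketch; the model name is an assumption): load a model on the CPU
# and spread it across all visible GPUs with balanced memory.
#
#     from transformers import AutoModelForCausalLM
#
#     model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", torch_dtype=torch.float16)
#     model = dispatch_model(model)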