import gc
import torch
from typing import TYPE_CHECKING, Tuple
from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList

if TYPE_CHECKING:
    from transformers.modeling_utils import PreTrainedModel


class AverageMeter:
    r"""
    Computes and stores the average and current value.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
    r"""
    Returns the number of trainable parameters and the total number of parameters in the model.
    """
    trainable_params, all_param = 0, 0
    for param in model.parameters():
        num_params = param.numel()
        # If using DeepSpeed ZeRO-3, the weights are initialized empty and the
        # true parameter count is stored in the `ds_numel` attribute.
        if num_params == 0 and hasattr(param, "ds_numel"):
            num_params = param.ds_numel

        # Due to the design of 4-bit linear layers from bitsandbytes, one
        # stored element packs two parameters, so multiply the count by 2.
        if param.__class__.__name__ == "Params4bit":
            num_params = num_params * 2

        all_param += num_params
        if param.requires_grad:
            trainable_params += num_params

    return trainable_params, all_param


def get_logits_processor() -> LogitsProcessorList:
    r"""
    Returns a logits processor list that removes NaN and Inf values from the logits during generation.
    """
    logits_processor = LogitsProcessorList()
    logits_processor.append(InfNanRemoveLogitsProcessor())
    return logits_processor


def torch_gc() -> None:
    r"""
    Collects GPU memory.
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()


def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel":
    r"""
    Dispatches a pre-trained model to GPUs with balanced memory.
    Borrowed from: https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/modeling_utils.py#L2803
    """
    if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
        return model  # quantized models are already placed on devices at load time

    if torch.cuda.device_count() > 1:
        from accelerate import dispatch_model
        from accelerate.utils import infer_auto_device_map, get_balanced_memory

        if model._no_split_modules is None:
            raise ValueError("The model class needs to implement the `_no_split_modules` attribute.")

        kwargs = {"dtype": model.dtype, "no_split_module_classes": model._no_split_modules}
        max_memory = get_balanced_memory(model, **kwargs)
        # Make sure tied weights are tied before creating the device map.
        model.tie_weights()
        device_map = infer_auto_device_map(model, max_memory=max_memory, **kwargs)
        return dispatch_model(model, device_map)
    else:
        return model.cuda()
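

# --- Hedged sketch: wiring get_logits_processor() into generation ---
# Left commented out because it downloads a checkpoint; "gpt2" is only an
# illustrative model id, not one assumed anywhere in this module.
# from transformers import AutoModelForCausalLM, AutoTokenizer
# tokenizer = AutoTokenizer.from_pretrained("gpt2")
# lm = AutoModelForCausalLM.from_pretrained("gpt2")
# inputs = tokenizer("Hello, world", return_tensors="pt")
# output_ids = lm.generate(**inputs, max_new_tokens=8, logits_processor=get_logits_processor())
# print(tokenizer.decode(output_ids[0], skip_special_tokens=True))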
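

# --- Minimal usage sketch (illustrative, not part of the original module) ---
# Demonstrates AverageMeter tracking a running loss and count_parameters on a
# tiny torch module; the toy Linear layer is an assumption chosen so the demo
# runs offline without loading a pre-trained model.
if __name__ == "__main__":
    meter = AverageMeter()
    for step, loss in enumerate([0.9, 0.7, 0.4], start=1):
        meter.update(loss)  # n=1 per batch by default
        print(f"step {step}: current={meter.val:.2f}, running avg={meter.avg:.2f}")

    toy_model = torch.nn.Linear(in_features=16, out_features=4)
    trainable, total = count_parameters(toy_model)
    print(f"trainable params: {trainable} / {total}")  # 68 / 68 for this layer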