import torch
import omegaconf
from omegaconf import open_dict


def print_model_info(model):
    """Prints trainable parameters (skipping the LM encoder) and their total count."""
    total_params = 0
    for name, param in model.named_parameters():
        if param.requires_grad:
            total_params += param.numel()
            # Encoder weights are numerous; skip listing them individually.
            if "lm_encoder." not in name:
                print(name, param.data.size())
    print("\nTotal Params: {:.2f} (in millions)".format(total_params / 10**6))


def enough_memory():
    """Returns True if the first visible GPU has more than 40 GB of total memory."""
    if torch.cuda.is_available():
        memory_in_gb = torch.cuda.get_device_properties(0).total_memory // (1024**3)
        if memory_in_gb > 40:
            return True
    return False


def get_sequence_mask(sequence_len):
    """Returns a (B, max_len) boolean mask that is True at valid positions.

    sequence_len: Tensor of size (B,) with entries indicating the length of each sequence.
    """
    batch_size = sequence_len.size()[0]
    max_len = int(torch.max(sequence_len).item())
    tmp = torch.arange(max_len, device=sequence_len.device).expand(batch_size, max_len)
    return tmp < sequence_len.unsqueeze(1)


def get_span_mask(start_ids, end_ids, max_len):
    """Returns a (B, max_len) float mask that is 1 inside the inclusive span
    [start_ids[i], end_ids[i]] of each row and 0 elsewhere.
    """
    tmp = (
        torch.arange(max_len, device=start_ids.device)
        .unsqueeze(0)
        .expand(start_ids.shape[0], -1)
    )
    batch_start_ids = start_ids.unsqueeze(1).expand_as(tmp)
    batch_end_ids = end_ids.unsqueeze(1).expand_as(tmp)
    mask = (tmp >= batch_start_ids).float() * (tmp <= batch_end_ids).float()
    return mask


def check_nan_grad(model):
    """Prints the name of every parameter whose gradient contains NaNs."""
    for name, param in model.named_parameters():
        if param.grad is None:
            continue
        num_nan = torch.sum(torch.isnan(param.grad.data))
        if num_nan:
            print(name)


def get_l2_norm(model, debug=False):
    """Returns the summed L2 norms of parameters and of gradients, computed
    over parameters that received a gradient. With debug=True, also prints the
    five parameters with the largest norms and their share of the total.
    """
    total_l2_norm = {"param": 0.0, "grad": 0.0}
    param_norm_list = []
    for name, param in model.named_parameters():
        if param.grad is not None:
            param_norm = torch.norm(param.data, p=2)
            if torch.isnan(param_norm):
                print("NaN parameter:", name)
            param_norm_list.append((name, param_norm.item()))
            total_l2_norm["param"] += param_norm.item()
            total_l2_norm["grad"] += torch.norm(param.grad.data, p=2).item()

    if debug:
        print("Summation of L2 norm: %.3f" % total_l2_norm["param"])
        # Report the top-5 parameters by L2 norm.
        sorted_param_list = sorted(param_norm_list, key=lambda x: x[1], reverse=True)
        for name, param_norm in sorted_param_list[:5]:
            print(
                "Name: %s\tNorm (%%): %.3f\tNorm: %.3f"
                % (name, param_norm * 100 / total_l2_norm["param"], param_norm)
            )

    return total_l2_norm


def fill_missing_configs(dict1, dict2):
    """Recursively copies keys from dict2 into dict1 when they are missing,
    descending into nested DictConfigs. open_dict temporarily disables struct
    mode so that new keys can be added to dict1.
    """
    with open_dict(dict1):
        for key, value in dict2.items():
            if key not in dict1:
                dict1[key] = value
                print(f"Added {key} to config, with value {value}")
            elif isinstance(value, omegaconf.dictconfig.DictConfig):
                dict1[key] = fill_missing_configs(dict1[key], value)
    return dict1
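

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module): exercises the
# helpers above on toy tensors. `TinyModel` is a hypothetical stand-in for
# whatever model this utility file is actually used with.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch.nn as nn

    class TinyModel(nn.Module):  # hypothetical model, for demonstration only
        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(4, 2)

        def forward(self, x):
            return self.linear(x)

    model = TinyModel()
    print_model_info(model)

    # Sequence mask for a batch of 3 sequences with lengths 2, 4, 3:
    # row i is True at positions < sequence_len[i].
    seq_lens = torch.tensor([2, 4, 3])
    print(get_sequence_mask(seq_lens))

    # Span mask: 1.0 at positions 1..2 of the first row, 0..3 of the second.
    print(get_span_mask(torch.tensor([1, 0]), torch.tensor([2, 3]), max_len=4))

    # The gradient-dependent helpers require a backward pass first.
    loss = model(torch.randn(3, 4)).sum()
    loss.backward()
    check_nan_grad(model)
    print(get_l2_norm(model, debug=True))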