# MEIRa/pytorch_utils/utils.py
import torch
import omegaconf
from omegaconf import open_dict


def print_model_info(model):
"""Prints model parameters and their total count"""
total_params = 0
for name, param in model.named_parameters():
if param.requires_grad:
            # numel() returns the total number of elements in the tensor
            total_params += param.numel()
            if "lm_encoder." not in name:
print(name, param.data.size())
print("\nTotal Params:{:.2f} (in millions)".format(total_params / 10**6))
def enough_memory():
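    """Return True if the first CUDA device has more than 40 GB of total memory."""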
if torch.cuda.is_available():
memory_in_gb = torch.cuda.get_device_properties(0).total_memory // (1024**3)
        return memory_in_gb > 40
    return False


def get_sequence_mask(sequence_len):
"""Returns Sequence Mask.
sequence_len: Tensor of size (B,) with entries indicating length of seq.
"""
    batch_size = sequence_len.size(0)
max_len = torch.max(sequence_len)
tmp = torch.arange(max_len, device=sequence_len.device).expand(batch_size, max_len)
    return tmp < sequence_len.unsqueeze(1)


def get_span_mask(start_ids, end_ids, max_len):
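    """Return a (B, max_len) float mask that is 1 inside [start, end] (inclusive) and 0 elsewhere."""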
tmp = (
torch.arange(max_len, device=start_ids.device)
.unsqueeze(0)
.expand(start_ids.shape[0], -1)
)
batch_start_ids = start_ids.unsqueeze(1).expand_as(tmp)
batch_end_ids = end_ids.unsqueeze(1).expand_as(tmp)
mask = (tmp >= batch_start_ids).float() * (tmp <= batch_end_ids).float()
    return mask


def check_nan_grad(model):
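    """Print the names of parameters whose gradients contain NaN entries."""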
for name, param in model.named_parameters():
        if param.grad is None:
            continue
        # A non-zero NaN count is truthy, so only offending parameters are printed
        num_nan = torch.sum(torch.isnan(param.grad.data))
        if num_nan:
            print(name)


def get_l2_norm(model, debug=False):
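    """Return the summed per-parameter L2 norms of weights and gradients.

    Only parameters that currently have gradients are counted.
    """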
total_l2_norm = {"param": 0, "grad": 0}
param_norm_list = []
for name, param in model.named_parameters():
if param.grad is not None:
param_norm = torch.norm(param.data, p=2)
if torch.isnan(param_norm):
print("NaN parameter:", name)
param_norm_list.append((name, param_norm.item()))
total_l2_norm["param"] += torch.norm(param.data, p=2).item()
total_l2_norm["grad"] += torch.norm(param.grad.data, p=2).item()
if debug:
print("Summation of L2 norm: %.3f" % total_l2_norm["param"])
# Sort param list by L2 norm
sorted_param_list = sorted(param_norm_list, key=lambda x: x[1], reverse=True)
topk_list = sorted_param_list[:5]
for name, param_norm in topk_list:
print(
"Name: %s\tNorm (%%): %.3f\tNorm: %.3f"
% (name, param_norm * 100 / total_l2_norm["param"], param_norm)
)
    return total_l2_norm


def fill_missing_configs(dict1, dict2):
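    """Recursively copy keys that are present in dict2 but missing from dict1; return dict1."""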
with open_dict(dict1):
for key, value in dict2.items():
if key not in dict1:
dict1[key] = value
print(f"Added {key} to config, with value {value}")
elif isinstance(value, omegaconf.dictconfig.DictConfig):
dict1[key] = fill_missing_configs(dict1[key], value)
return dict1
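

if __name__ == "__main__":
    # Minimal sanity check; illustrative usage only, not part of the original module.
    # Two sequences of lengths 3 and 5 yield a (2, 5) boolean mask; the span mask
    # covers positions 1..3 (inclusive) of a length-5 sequence.
    lengths = torch.tensor([3, 5])
    print(get_sequence_mask(lengths))
    print(get_span_mask(torch.tensor([1]), torch.tensor([3]), max_len=5))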