import torch
import omegaconf
from omegaconf import open_dict


def print_model_info(model):
    """Prints model parameters and their total count."""
    total_params = 0
    for name, param in model.named_parameters():
        if param.requires_grad:
            total_params += param.numel()
            # Skip parameters under "lm_encoder." to keep the printout short
            if "lm_encoder." not in name:
                print(name, param.data.size())
    print("\nTotal Params: {:.2f} (in millions)".format(total_params / 10**6))


def enough_memory():
    """Returns True if the first CUDA device has more than 40 GB of total memory."""
    if torch.cuda.is_available():
        memory_in_gb = torch.cuda.get_device_properties(0).total_memory // (1024**3)
        if memory_in_gb > 40:
            return True

    return False


def get_sequence_mask(sequence_len):
    """Returns Sequence Mask.
    sequence_len: Tensor of size (B,) with entries indicating length of seq.
    """
    batch_size = sequence_len.size()[0]
    max_len = torch.max(sequence_len)
    tmp = torch.arange(max_len, device=sequence_len.device).expand(batch_size, max_len)
    return tmp < sequence_len.unsqueeze(1)


def get_span_mask(start_ids, end_ids, max_len):
    """Returns a float mask of shape (B, max_len) that is 1.0 inside [start, end] (inclusive) for each span."""
    tmp = (
        torch.arange(max_len, device=start_ids.device)
        .unsqueeze(0)
        .expand(start_ids.shape[0], -1)
    )
    batch_start_ids = start_ids.unsqueeze(1).expand_as(tmp)
    batch_end_ids = end_ids.unsqueeze(1).expand_as(tmp)
    mask = (tmp >= batch_start_ids).float() * (tmp <= batch_end_ids).float()
    return mask


def check_nan_grad(model):
    """Prints the names of parameters whose gradients contain NaNs."""
    for name, param in model.named_parameters():
        if param.grad is None:
            continue
        num_nan = torch.sum(torch.isnan(param.grad.data))
        if num_nan:
            print(name)


def get_l2_norm(model, debug=False):
    """Returns the summed L2 norms of parameters and gradients over params that have a gradient."""
    total_l2_norm = {"param": 0, "grad": 0}
    param_norm_list = []
    for name, param in model.named_parameters():
        if param.grad is not None:
            param_norm = torch.norm(param.data, p=2)
            if torch.isnan(param_norm):
                print("NaN parameter:", name)
            param_norm_list.append((name, param_norm.item()))
            total_l2_norm["param"] += param_norm.item()
            total_l2_norm["grad"] += torch.norm(param.grad.data, p=2).item()
    if debug:
        print("Summation of L2 norm: %.3f" % total_l2_norm["param"])
        # Sort param list by L2 norm
        sorted_param_list = sorted(param_norm_list, key=lambda x: x[1], reverse=True)
        topk_list = sorted_param_list[:5]
        for name, param_norm in topk_list:
            print(
                "Name: %s\tNorm (%%): %.3f\tNorm: %.3f"
                % (name, param_norm * 100 / total_l2_norm["param"], param_norm)
            )

    return total_l2_norm


def fill_missing_configs(dict1, dict2):
    """Recursively copies keys from dict2 that are missing in dict1 and returns the updated dict1."""
    with open_dict(dict1):
        for key, value in dict2.items():
            if key not in dict1:
                dict1[key] = value
                print(f"Added {key} to config, with value {value}")
            elif isinstance(value, omegaconf.dictconfig.DictConfig):
                dict1[key] = fill_missing_configs(dict1[key], value)
    return dict1
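

# Minimal usage sketch (illustrative only, not part of the original utilities): shows the
# mask and config helpers on toy inputs. The tensor shapes and config keys below are
# assumptions made for this example, not values taken from any training pipeline.
if __name__ == "__main__":
    seq_lens = torch.tensor([2, 4, 3])
    # Boolean mask of shape (3, 4); row i has seq_lens[i] leading True entries
    print(get_sequence_mask(seq_lens))

    starts = torch.tensor([0, 1])
    ends = torch.tensor([1, 3])
    # Float mask of shape (2, 5); 1.0 inside [start, end] inclusive, 0.0 elsewhere
    print(get_span_mask(starts, ends, max_len=5))

    base = omegaconf.OmegaConf.create({"model": {"hidden_size": 768}})
    defaults = omegaconf.OmegaConf.create({"model": {"dropout": 0.1}, "seed": 42})
    # Copies "model.dropout" and "seed" into base without touching "model.hidden_size"
    print(fill_missing_configs(base, defaults))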