File size: 5,130 Bytes
966ae59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
# -*- coding: utf-8 -*-
# Copyright (c) XiMing Xing. All rights reserved.
# Author: XiMing Xing
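# Description: print a layer-by-layer summary (output shape, parameter count, memory estimate) of a PyTorch model.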

import sys
from collections import OrderedDict

import numpy as np
import torch

# Module types that get their own summary row even though they own child
# modules (every other module is hooked only when it is a leaf).
layer_modules = (torch.nn.MultiheadAttention,)


def _output_shape(outputs, batch_size):
    """Return the displayed output shape of a hooked call, or ``None``.

    The first dimension is replaced by ``batch_size``. For list/tuple outputs
    only the first element is described. Objects that keep their tensor in a
    ``data`` attribute (pack_padded_seq / pad_packed_seq results) are unwrapped.
    ``None`` is returned when no size can be determined (dict or ``None``
    outputs, empty tuples, ...), instead of crashing the forward pass.
    """
    if isinstance(outputs, (list, tuple)):
        if not outputs:
            return None
        outputs = outputs[0]
    for candidate in (outputs, getattr(outputs, "data", None)):
        size = getattr(candidate, "size", None)
        if callable(size):
            return [batch_size] + list(size())[1:]
    return None


def _normalize_input_shapes(input_shape):
    """Return ``input_shape`` as a list of per-input shapes (``None`` passes through).

    A bare tuple of dimensions such as ``(3, 32, 32)`` or a ``torch.Size`` is a
    single input and gets wrapped; a tuple or list of shapes becomes a list.
    """
    if input_shape is None:
        return None
    if isinstance(input_shape, tuple) and not any(isinstance(dim, (tuple, list)) for dim in input_shape):
        return [input_shape]
    return list(input_shape)


def _nbytes(obj):
    """Approximate memory held by ``obj`` in bytes.

    Tensors count their element storage; lists/tuples/dicts are summed
    recursively; anything else falls back to the shallow ``sys.getsizeof``.
    """
    if torch.is_tensor(obj):
        return obj.nelement() * obj.element_size()
    if isinstance(obj, dict):
        return sum(_nbytes(value) for value in obj.values())
    if isinstance(obj, (list, tuple)):
        return sum(_nbytes(item) for item in obj)
    return sys.getsizeof(obj)


def _snapshot_buffers(model):
    """Copy every initialized buffer of ``model`` so forward side effects can be undone."""
    return [(buf, buf.detach().clone()) for buf in model.buffers()
            if not torch.nn.parameter.is_lazy(buf)]


def _restore_buffers(snapshot):
    """Write the values captured by ``_snapshot_buffers`` back in place."""
    with torch.no_grad():
        for buf, saved in snapshot:
            buf.copy_(saved)


def summary(model, input_data=None, input_data_args=None, input_shape=None, input_dtype=torch.FloatTensor,
            batch_size=-1,
            *args, **kwargs):
    """Print and return a layer-by-layer summary of ``model``.

    Give example input in at least one of these ways (checked in this order):
    ① input_data ---> model.forward(input_data)
    ② input_shape & input_dtype ---> model.forward(*[torch.rand(2, *size).type(input_dtype) for size in input_shape])
       (``input_shape`` is one shape tuple, or a list/tuple of shapes for several inputs)
    ③ input_data_args ---> model.forward(*input_data_args)
    With none of them, the model is called without positional input.
    Extra ``*args``/``**kwargs`` are appended to the forward call.

    A forward hook is attached to every leaf module and to every instance of
    ``layer_modules``; each hooked call becomes one row showing its output
    shape (first dimension replaced by ``batch_size``) and its parameter count.
    Hooks are always removed and all buffers are restored afterwards, so the
    random example input does not alter e.g. BatchNorm running statistics.

    Returns:
        The summary table as one string (it is also printed to stdout).

    Raises:
        Whatever the model's forward pass raises (after printing a notice).
    """
    hooks = []
    layer_summaries = OrderedDict()

    def hook(module, inputs, outputs):
        class_name = str(module.__class__).split(".")[-1].split("'")[0]
        key = "%s-%i" % (class_name, len(layer_summaries) + 1)

        info = OrderedDict()
        info["id"] = id(module)
        info["out"] = _output_shape(outputs, batch_size)
        # parameters() recurses, so a hooked module from ``layer_modules`` is
        # credited with the parameters of its child modules as well.
        info["params_nt"], info["params"] = 0, 0
        for param in module.parameters():
            if param.requires_grad:
                info["params"] += param.nelement()
            else:
                info["params_nt"] += param.nelement()

        layer_summaries[key] = info

    def register_hook(module):
        # ignore Sequential, ModuleList and other containers: their children
        # already produce rows. ``layer_modules`` are the explicit exception.
        if isinstance(module, layer_modules) or not module._modules:
            hooks.append(module.register_forward_hook(hook))

    input_shapes = _normalize_input_shapes(input_shape)

    # Build the example input BEFORE attaching hooks, so a bad shape cannot
    # leave hooks registered on the caller's model.
    if input_data is not None:
        x = [input_data]
    elif input_shapes is not None:
        # batch_size of 2 for batchnorm
        x = [torch.rand(2, *size).type(input_dtype) for size in input_shapes]
    elif input_data_args is not None:
        x = input_data_args
    else:
        x = []

    saved_buffers = _snapshot_buffers(model)
    try:
        model.apply(register_hook)
        with torch.no_grad():
            model(*x, *args, **kwargs)
    except Exception:
        # This can be useful for debugging
        print("Failed to run summary...")
        raise
    finally:
        for handle in hooks:
            handle.remove()
        _restore_buffers(saved_buffers)

    summary_logs = [
        "--------------------------------------------------------------------------",
        "{:<30}  {:>20} {:>20}".format("Layer (type)", "Output Shape", "Param #"),
        "==========================================================================",
    ]
    total_params = 0
    total_output = 0
    trainable_params = 0
    for layer, info in layer_summaries.items():
        layer_params = info["params"] + info["params_nt"]
        out = info["out"]
        # layer, output_shape, params
        summary_logs.append("{:<30}  {:>20} {:>20}".format(
            layer,
            "-" if out is None else str(out),
            "{0:,}".format(layer_params),
        ))
        total_params += layer_params
        trainable_params += info["params"]
        if out is not None:
            total_output += np.prod(out)

    if input_data is not None:
        total_input_size = _nbytes(input_data) / (1024 ** 2.)
    elif input_shapes is not None:
        # Shapes carry no dtype: assume 4 bytes/number and |batch_size| samples
        # of every input; element counts of several inputs are summed.
        numel = sum(np.prod(size) for size in input_shapes)
        total_input_size = abs(numel * batch_size * 4. / (1024 ** 2.))
    elif input_data_args is not None:
        total_input_size = _nbytes(input_data_args) / (1024 ** 2.)
    else:
        total_input_size = 0.0
    # assume 4 bytes/number; abs() because batch_size defaults to -1
    total_output_size = abs(2. * total_output * 4. / (1024 ** 2.))  # x2 for gradients
    total_params_size = abs(total_params * 4. / (1024 ** 2.))
    total_size = total_params_size + total_output_size + total_input_size

    summary_logs.append("==========================================================================")
    summary_logs.append("Total params: {0:,}".format(total_params))
    summary_logs.append("Trainable params: {0:,}".format(trainable_params))
    summary_logs.append("Non-trainable params: {0:,}".format(total_params - trainable_params))
    summary_logs.append("--------------------------------------------------------------------------")
    summary_logs.append("Input size (MB): %0.6f" % total_input_size)
    summary_logs.append("Forward/backward pass size (MB): %0.6f" % total_output_size)
    summary_logs.append("Params size (MB): %0.6f" % total_params_size)
    summary_logs.append("Estimated Total Size (MB): %0.6f" % total_size)
    summary_logs.append("--------------------------------------------------------------------------")

    summary_info = "\n".join(summary_logs)

    print(summary_info)
    return summary_info