# -*- coding: utf-8 -*-
# Copyright (c) XiMing Xing. All rights reserved.
# Author: XiMing Xing
# Description: torchsummary-style model summary for PyTorch models
#   (per-layer output shapes, parameter counts, and memory estimates).
import sys
from collections import OrderedDict

import numpy as np
import torch

# Modules hooked as a whole even though they contain submodules:
# nn.MultiheadAttention owns parameters directly (e.g. in_proj_weight),
# so hooking only its leaves would miss them.
layer_modules = (torch.nn.MultiheadAttention,)
def summary(model, input_data=None, input_data_args=None, input_shape=None,
            input_dtype=torch.FloatTensor, batch_size=-1, *args, **kwargs):
    """
    Provide example input data in at least one of the following ways:
    ① input_data ---> model.forward(input_data)
    ② input_data_args ---> model.forward(*input_data_args)
    ③ input_shape & input_dtype ---> model.forward(*[torch.rand(2, *size).type(input_dtype) for size in input_shape])
    """
    hooks = []
    summary_dict = OrderedDict()  # renamed to avoid shadowing the enclosing function
    def register_hook(module):
        def hook(module, inputs, outputs):
            class_name = str(module.__class__).split(".")[-1].split("'")[0]
            module_idx = len(summary_dict)
            key = "%s-%i" % (class_name, module_idx + 1)

            info = OrderedDict()
            info["id"] = id(module)
            if isinstance(outputs, (list, tuple)):
                try:
                    info["out"] = [batch_size] + list(outputs[0].size())[1:]
                except AttributeError:
                    # pack_padded_sequence / pad_packed_sequence store features in the .data attribute
                    info["out"] = [batch_size] + list(outputs[0].data.size())[1:]
            else:
                info["out"] = [batch_size] + list(outputs.size())[1:]

            info["params_nt"], info["params"] = 0, 0
            for _, param in module.named_parameters():
                # requires_grad separates trainable from frozen parameters
                info["params"] += param.nelement() * param.requires_grad
                info["params_nt"] += param.nelement() * (not param.requires_grad)

            summary_dict[key] = info

        # hook only leaf modules (plus whole MultiheadAttention blocks);
        # containers such as Sequential and ModuleList are skipped
        if isinstance(module, layer_modules) or not module._modules:
            hooks.append(module.register_forward_hook(hook))
    model.apply(register_hook)

    # normalize a single shape into a list of shapes (multi-input networks)
    if isinstance(input_shape, tuple):
        input_shape = [input_shape]

    if input_data is not None:
        x = [input_data]
    elif input_shape is not None:
        # batch size of 2 so BatchNorm layers see more than one sample
        x = [torch.rand(2, *size).type(input_dtype) for size in input_shape]
    elif input_data_args is not None:
        x = input_data_args
    else:
        x = []
    try:
        with torch.no_grad():
            model(*x) if not (kwargs or args) else model(*x, *args, **kwargs)
    except Exception:
        # re-raise after printing so the failure stays visible for debugging
        print("Failed to run summary...")
        raise
    finally:
        for hook in hooks:
            hook.remove()
    summary_logs = []
    summary_logs.append("--------------------------------------------------------------------------")
    line_new = "{:<30} {:>20} {:>20}".format("Layer (type)", "Output Shape", "Param #")
    summary_logs.append(line_new)
    summary_logs.append("==========================================================================")

    total_params = 0
    total_output = 0
    trainable_params = 0
    for layer in summary_dict:
        # layer name, output shape, and total (trainable + non-trainable) params
        line_new = "{:<30} {:>20} {:>20}".format(
            layer,
            str(summary_dict[layer]["out"]),
            "{0:,}".format(summary_dict[layer]["params"] + summary_dict[layer]["params_nt"])
        )
        total_params += (summary_dict[layer]["params"] + summary_dict[layer]["params_nt"])
        total_output += np.prod(summary_dict[layer]["out"])
        trainable_params += summary_dict[layer]["params"]
        summary_logs.append(line_new)
    # assume 4 bytes per number (float32)
    if input_data is not None:
        if torch.is_tensor(input_data):
            # measure the tensor's storage directly; sys.getsizeof would only
            # report the Python object overhead, not the underlying data buffer
            total_input_size = abs(input_data.element_size() * input_data.nelement() / (1024 ** 2.))
        else:
            total_input_size = abs(sys.getsizeof(input_data) / (1024 ** 2.))
    elif input_shape is not None:
        total_input_size = abs(np.prod(input_shape) * batch_size * 4. / (1024 ** 2.))
    else:
        total_input_size = 0.0

    total_output_size = abs(2. * total_output * 4. / (1024 ** 2.))  # x2 for gradients
    total_params_size = abs(total_params * 4. / (1024 ** 2.))
    total_size = total_params_size + total_output_size + total_input_size
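
    # Worked example of the 4-bytes/number assumption above: a model with
    # 1,000,000 float32 parameters occupies 1e6 * 4 bytes = 4,000,000 bytes
    # ≈ 3.81 MB, which is what "Params size (MB)" reports below.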
summary_logs.append("==========================================================================")
summary_logs.append("Total params: {0:,}".format(total_params))
summary_logs.append("Trainable params: {0:,}".format(trainable_params))
summary_logs.append("Non-trainable params: {0:,}".format(total_params - trainable_params))
summary_logs.append("--------------------------------------------------------------------------")
summary_logs.append("Input size (MB): %0.6f" % total_input_size)
summary_logs.append("Forward/backward pass size (MB): %0.6f" % total_output_size)
summary_logs.append("Params size (MB): %0.6f" % total_params_size)
summary_logs.append("Estimated Total Size (MB): %0.6f" % total_size)
summary_logs.append("--------------------------------------------------------------------------")
summary_info = "\n".join(summary_logs)
print(summary_info)
return summary_info
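
# A minimal usage sketch (not part of the original file): builds a tiny CNN
# and prints its summary via two of the three supported input modes. The
# model and shapes below are illustrative assumptions, not part of the utility.
if __name__ == "__main__":
    import torch.nn as nn

    net = nn.Sequential(
        nn.Conv2d(3, 8, kernel_size=3, padding=1),
        nn.BatchNorm2d(8),
        nn.ReLU(),
        nn.Flatten(),
        nn.Linear(8 * 32 * 32, 10),
    )

    # Mode ③: pass input_shape; summary() builds a random batch of 2 itself.
    summary(net, input_shape=(3, 32, 32))

    # Mode ①: pass a concrete tensor as input_data.
    summary(net, input_data=torch.rand(2, 3, 32, 32))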