hjc-owo
init repo
966ae59
# -*- coding: utf-8 -*-
# Copyright (c) XiMing Xing. All rights reserved.
# Author: XiMing Xing
# Description:
import sys
from collections import OrderedDict
import numpy as np
import torch
# Composite modules that receive a forward hook even though they own child
# modules; every other module is hooked only when it has no children.
layer_modules = (torch.nn.MultiheadAttention,)


def _normalize_input_shapes(input_shape):
    """Return ``input_shape`` as a list holding one shape per model input.

    A flat shape such as ``(3, 32, 32)`` or ``[3, 32, 32]`` becomes a
    one-element list; a tuple or list of shapes is returned as a list.
    """
    if input_shape is None:
        return None
    is_flat = all(isinstance(dim, (int, np.integer)) for dim in input_shape)
    if isinstance(input_shape, tuple):
        return [input_shape] if is_flat else list(input_shape)
    if is_flat and len(input_shape) > 0:
        return [tuple(input_shape)]
    return list(input_shape)


def _input_nbytes(data):
    """Return the number of bytes held by the tensors inside ``data``.

    Lists, tuples and dicts are walked recursively; any other non-tensor
    object falls back to ``sys.getsizeof``.
    """
    if torch.is_tensor(data):
        return data.nelement() * data.element_size()
    if isinstance(data, dict):
        return sum(_input_nbytes(value) for value in data.values())
    if isinstance(data, (list, tuple)):
        return sum(_input_nbytes(value) for value in data)
    return sys.getsizeof(data)


def _output_shape(outputs, batch_size):
    """Return ``[batch_size] + <output dims without the leading one>``.

    For list/tuple outputs only the first element is inspected.
    """
    if isinstance(outputs, (list, tuple)):
        first = outputs[0]
        try:
            dims = first.size()
        except AttributeError:
            # pack_padded_seq and pad_packed_seq store feature into data attribute
            dims = first.data.size()
    else:
        dims = outputs.size()
    return [batch_size] + list(dims)[1:]


def summary(model, input_data=None, input_data_args=None, input_shape=None, input_dtype=torch.FloatTensor,
            batch_size=-1,
            *args, **kwargs):
    """
    Run one forward pass of ``model`` and print/return a per-layer summary table.

    Give example input data in at least one of the ways below (when several
    are given, the first match in the order input_data, input_shape,
    input_data_args is used):
      1. input_data               ---> model.forward(input_data)
      2. input_data_args          ---> model.forward(*input_data_args)
      3. input_shape & input_dtype --->
         model.forward(*[torch.rand(2, *size).type(input_dtype) for size in input_shape])

    ``input_shape`` may be a single shape (tuple or list of ints) or a
    tuple/list of shapes, one per model input. Extra ``*args``/``**kwargs``
    are forwarded to the model call. ``batch_size`` is only used for display
    and for the size estimates, which assume 4 bytes per number.

    Returns:
        The summary table as a single string (it is also printed).

    Raises:
        Any exception raised by the forward pass is re-raised after the
        hooks have been removed.
    """
    rule_thin = "--------------------------------------------------------------------------"
    rule_thick = "=========================================================================="
    row_format = "{:<30} {:>20} {:>20}"
    bytes_per_mb = 1024 ** 2.

    # Build the example inputs before touching the model, so a malformed
    # input_shape cannot leave forward hooks attached.
    input_shape = _normalize_input_shapes(input_shape)
    if input_data is not None:
        x = [input_data]
    elif input_shape is not None:
        # batch_size of 2 for batchnorm
        x = [torch.rand(2, *size).type(input_dtype) for size in input_shape]
    elif input_data_args is not None:
        x = input_data_args
    else:
        x = []

    hooks = []
    records = OrderedDict()

    def register_hook(module):
        def hook(module, inputs, outputs):
            key = "%s-%i" % (type(module).__name__, len(records) + 1)
            info = OrderedDict()
            info["id"] = id(module)
            info["out"] = _output_shape(outputs, batch_size)
            info["params_nt"], info["params"] = 0, 0
            for param in module.parameters():
                if param.requires_grad:
                    info["params"] += param.nelement()
                else:
                    info["params_nt"] += param.nelement()
            records[key] = info

        # ignore Sequential and ModuleList and other containers
        if isinstance(module, layer_modules) or not module._modules:
            hooks.append(module.register_forward_hook(hook))

    model.apply(register_hook)

    try:
        with torch.no_grad():
            model(*x, *args, **kwargs)
    except Exception:
        # This can be useful for debugging
        print("Failed to run summary...")
        raise
    finally:
        for handle in hooks:
            handle.remove()

    summary_logs = [
        rule_thin,
        row_format.format("Layer (type)", "Output Shape", "Param #"),
        rule_thick,
    ]

    total_params = 0
    total_output = 0
    trainable_params = 0
    counted_modules = set()
    for layer, info in records.items():
        layer_params = info["params"] + info["params_nt"]
        summary_logs.append(row_format.format(layer, str(info["out"]), "{0:,}".format(layer_params)))
        total_output += np.prod(info["out"])
        # A module invoked several times in one forward pass produces several
        # rows, but its parameters exist only once: count them once.
        if info["id"] not in counted_modules:
            counted_modules.add(info["id"])
            total_params += layer_params
            trainable_params += info["params"]

    if input_data is not None:
        # real storage size of the supplied tensors
        total_input_size = _input_nbytes(input_data) / bytes_per_mb
    elif input_shape is not None:
        # assume 4 bytes/number; sum over inputs instead of multiplying shapes together
        numbers_per_sample = sum(np.prod(shape) for shape in input_shape)
        total_input_size = abs(numbers_per_sample * batch_size * 4. / bytes_per_mb)
    elif input_data_args is not None:
        total_input_size = _input_nbytes(input_data_args) / bytes_per_mb
    else:
        total_input_size = 0.0

    # abs() because the default batch_size of -1 makes the products negative
    total_output_size = abs(2. * total_output * 4. / bytes_per_mb)  # x2 for gradients
    total_params_size = abs(total_params * 4. / bytes_per_mb)
    total_size = total_params_size + total_output_size + total_input_size

    summary_logs.append(rule_thick)
    summary_logs.append("Total params: {0:,}".format(total_params))
    summary_logs.append("Trainable params: {0:,}".format(trainable_params))
    summary_logs.append("Non-trainable params: {0:,}".format(total_params - trainable_params))
    summary_logs.append(rule_thin)
    summary_logs.append("Input size (MB): %0.6f" % total_input_size)
    summary_logs.append("Forward/backward pass size (MB): %0.6f" % total_output_size)
    summary_logs.append("Params size (MB): %0.6f" % total_params_size)
    summary_logs.append("Estimated Total Size (MB): %0.6f" % total_size)
    summary_logs.append(rule_thin)

    summary_info = "\n".join(summary_logs)
    print(summary_info)
    return summary_info