# -*- coding: utf-8 -*-
# Copyright (c) XiMing Xing. All rights reserved.
# Author: XiMing Xing
# Description:
import sys
from collections import OrderedDict

import numpy as np
import torch

# Container-like layers that must still be hooked even though they own
# sub-modules: their forward() does real work (e.g. attention).
layer_modules = (torch.nn.MultiheadAttention,)


def summary(model,
            input_data=None,
            input_data_args=None,
            input_shape=None,
            input_dtype=torch.FloatTensor,
            batch_size=-1,
            *args,
            **kwargs):
    """Print and return a layer-by-layer summary of ``model``.

    Give example input data in at least one of these ways:
      1. ``input_data``       -> ``model.forward(input_data)``
      2. ``input_data_args``  -> ``model.forward(*input_data_args)``
      3. ``input_shape`` & ``input_dtype`` ->
         ``model.forward(*[torch.rand(2, *size).type(input_dtype) for size in input_shape])``

    Args:
        model: the ``torch.nn.Module`` to inspect.
        input_data: a single example input (typically a tensor).
        input_data_args: a sequence of positional example inputs.
        input_shape: one shape tuple, or a list of shape tuples, WITHOUT the
            batch dimension.
        input_dtype: legacy tensor type used to cast the random probe input.
        batch_size: batch size displayed in the output-shape column;
            -1 means "any batch size".
        *args, **kwargs: extra arguments forwarded to ``model.forward``.

    Returns:
        str: the formatted summary text (also printed to stdout).
    """
    hooks = []
    # NOTE: renamed from `summary` — the original shadowed this function's name.
    layer_info = OrderedDict()

    def register_hook(module):

        def hook(module, inputs, outputs):
            class_name = str(module.__class__).split(".")[-1].split("'")[0]
            key = "%s-%i" % (class_name, len(layer_info) + 1)

            info = OrderedDict()
            info["id"] = id(module)
            if isinstance(outputs, (list, tuple)):
                try:
                    info["out"] = [batch_size] + list(outputs[0].size())[1:]
                except AttributeError:
                    # pack_padded_seq / pad_packed_seq store features in .data
                    info["out"] = [batch_size] + list(outputs[0].data.size())[1:]
            else:
                info["out"] = [batch_size] + list(outputs.size())[1:]

            # Split parameter counts into trainable / non-trainable explicitly
            # (the original relied on bool arithmetic: nelement() * requires_grad).
            info["params_nt"], info["params"] = 0, 0
            for _, param in module.named_parameters():
                count = param.nelement()
                if param.requires_grad:
                    info["params"] += count
                else:
                    info["params_nt"] += count

            layer_info[key] = info

        # Hook leaf modules only, plus whitelisted container-like layers;
        # Sequential / ModuleList and other pure containers are skipped.
        if isinstance(module, layer_modules) or not module._modules:
            hooks.append(module.register_forward_hook(hook))

    model.apply(register_hook)

    # Normalize a single shape tuple into a list of shapes (multi-input nets).
    if isinstance(input_shape, tuple):
        input_shape = [input_shape]

    if input_data is not None:
        x = [input_data]
    elif input_shape is not None:
        # batch size of 2 so BatchNorm layers work during the probe pass
        x = [torch.rand(2, *size).type(input_dtype) for size in input_shape]
    elif input_data_args is not None:
        x = input_data_args
    else:
        x = []

    try:
        with torch.no_grad():
            # Equivalent to the original's branchy call: with empty
            # args/kwargs the starred forms add nothing.
            model(*x, *args, **kwargs)
    except Exception:
        # Useful for debugging: report, then re-raise the original error.
        print("Failed to run summary...")
        raise
    finally:
        # Always detach the forward hooks, even when forward() failed.
        for hook in hooks:
            hook.remove()

    summary_logs = []
    summary_logs.append("--------------------------------------------------------------------------")
    line_new = "{:<30} {:>20} {:>20}".format("Layer (type)", "Output Shape", "Param #")
    summary_logs.append(line_new)
    summary_logs.append("==========================================================================")

    total_params = 0
    total_output = 0
    trainable_params = 0
    for key, info in layer_info.items():
        layer_params = info["params"] + info["params_nt"]
        line_new = "{:<30} {:>20} {:>20}".format(
            key,
            str(info["out"]),
            "{0:,}".format(layer_params)
        )
        total_params += layer_params
        total_output += np.prod(info["out"])
        trainable_params += info["params"]
        summary_logs.append(line_new)

    # All size estimates below assume 4 bytes per element (float32).
    if input_data is not None:
        if torch.is_tensor(input_data):
            # FIX: sys.getsizeof() measures only the Python wrapper object,
            # not the tensor storage — use the actual buffer size instead.
            total_input_size = abs(
                input_data.nelement() * input_data.element_size() / (1024 ** 2.))
        else:
            total_input_size = abs(sys.getsizeof(input_data) / (1024 ** 2.))
    elif input_shape is not None:
        total_input_size = abs(np.prod(input_shape) * batch_size * 4. / (1024 ** 2.))
    else:
        total_input_size = 0.0

    total_output_size = abs(2. * total_output * 4. / (1024 ** 2.))  # x2 for gradients
    total_params_size = abs(total_params * 4. / (1024 ** 2.))
    total_size = total_params_size + total_output_size + total_input_size

    summary_logs.append("==========================================================================")
    summary_logs.append("Total params: {0:,}".format(total_params))
    summary_logs.append("Trainable params: {0:,}".format(trainable_params))
    summary_logs.append("Non-trainable params: {0:,}".format(total_params - trainable_params))
    summary_logs.append("--------------------------------------------------------------------------")
    summary_logs.append("Input size (MB): %0.6f" % total_input_size)
    summary_logs.append("Forward/backward pass size (MB): %0.6f" % total_output_size)
    summary_logs.append("Params size (MB): %0.6f" % total_params_size)
    summary_logs.append("Estimated Total Size (MB): %0.6f" % total_size)
    summary_logs.append("--------------------------------------------------------------------------")

    summary_info = "\n".join(summary_logs)

    print(summary_info)
    return summary_info