import logging
import os

import torch
import torch.distributed as dist
import yaml
from fvcore.nn import FlopCountAnalysis, flop_count_str, flop_count_table

logger = logging.getLogger(__name__)

NORM_MODULES = [
    torch.nn.BatchNorm1d,
    torch.nn.BatchNorm2d,
    torch.nn.BatchNorm3d,
    torch.nn.SyncBatchNorm,
    # NaiveSyncBatchNorm inherits from BatchNorm2d
    torch.nn.GroupNorm,
    torch.nn.InstanceNorm1d,
    torch.nn.InstanceNorm2d,
    torch.nn.InstanceNorm3d,
    torch.nn.LayerNorm,
    torch.nn.LocalResponseNorm,
]


def register_norm_module(cls):
    """Register an additional normalization module class in NORM_MODULES and return it unchanged."""
    NORM_MODULES.append(cls)
    return cls
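
# A minimal usage sketch: applied as a decorator to a hypothetical custom
# normalization layer (the class below is an assumption, not part of this file):
#
#     @register_norm_module
#     class LayerNorm2d(torch.nn.Module):
#         ...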


def is_main_process():
    """Return True if this process is rank 0; single-process runs always count as the main process."""
    rank = 0
    if 'OMPI_COMM_WORLD_RANK' in os.environ:
        rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
    return rank == 0


@torch.no_grad()
def analysis_model(model, dump_input, verbose=False):
    """Count FLOPs and parameters of `model` on the example input `dump_input` using fvcore."""
    model.eval()
    flops = FlopCountAnalysis(model, dump_input)
    total = flops.total()
    model.train()
    params_total = sum(p.numel() for p in model.parameters())
    params_learned = sum(
        p.numel() for p in model.parameters() if p.requires_grad
    )
    # Render the table/string once and reuse them for logging and the return value.
    flops_table = flop_count_table(flops)
    flops_str = flop_count_str(flops)
    logger.info(f"flop count table:\n {flops_table}")
    if verbose:
        logger.info(f"flop count str:\n {flops_str}")
    logger.info(f" Total flops: {total/1000/1000:.3f}M,")
    logger.info(f" Total params: {params_total/1000/1000:.3f}M,")
    logger.info(f" Learned params: {params_learned/1000/1000:.3f}M")
    return total, flops_table, flops_str
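
# A usage sketch (the torchvision model and the 224x224 input shape are
# assumptions for illustration, not part of this file):
#
#     import torchvision
#     model = torchvision.models.resnet18()
#     dummy = torch.randn(1, 3, 224, 224)
#     total_flops, table, detail = analysis_model(model, (dummy,))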


def load_config_dict_to_opt(opt, config_dict, splitter='.'):
    """
    Load the key-value pairs from config_dict into opt, overriding existing values in opt
    if there are any. Keys containing `splitter` are expanded into nested dictionaries.
    """
    if not isinstance(config_dict, dict):
        raise TypeError("Config must be a Python dictionary")
    for k, v in config_dict.items():
        k_parts = k.split(splitter)
        pointer = opt
        for k_part in k_parts[:-1]:
            if k_part not in pointer:
                pointer[k_part] = {}
            pointer = pointer[k_part]
            assert isinstance(pointer, dict), "Overriding key needs to be inside a Python dict."
        ori_value = pointer.get(k_parts[-1])
        pointer[k_parts[-1]] = v
        if ori_value:
            print(f"Overrode {k} from {ori_value} to {pointer[k_parts[-1]]}")


def load_opt_from_config_file(conf_file):
    """
    Load opt from the config file.

    Args:
        conf_file: config file path

    Returns:
        dict: a dictionary of opt settings
    """
    opt = {}
    with open(conf_file, encoding='utf-8') as f:
        config_dict = yaml.safe_load(f)
        load_config_dict_to_opt(opt, config_dict)
    return opt
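
# A usage sketch (the path is hypothetical; any YAML file with flat or dotted
# keys works):
#
#     opt = load_opt_from_config_file('configs/base.yaml')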


def cast_batch_to_dtype(batch, dtype):
    """
    Cast the float32 tensors in a batch to a specified torch dtype.
    It should be called before feeding the batch to the FP16 DeepSpeed model.

    Args:
        batch (torch.Tensor or container of torch.Tensor): input batch
        dtype (torch.dtype): target dtype for floating-point tensors
    Returns:
        return_batch: same type as the input batch with internal float32 tensors cast to the specified dtype
    """
    if torch.is_tensor(batch):
        if torch.is_floating_point(batch):
            return_batch = batch.to(dtype)
        else:
            return_batch = batch
    elif isinstance(batch, list):
        return_batch = [cast_batch_to_dtype(t, dtype) for t in batch]
    elif isinstance(batch, tuple):
        return_batch = tuple(cast_batch_to_dtype(t, dtype) for t in batch)
    elif isinstance(batch, dict):
        return_batch = {}
        for k in batch:
            return_batch[k] = cast_batch_to_dtype(batch[k], dtype)
    else:
        logger.debug(f"Can not cast type {type(batch)} to {dtype}. Skipping it in the batch.")
        return_batch = batch
    return return_batch


def cast_batch_to_half(batch):
    """
    Cast the float32 tensors in a batch to float16.
    It should be called before feeding the batch to the FP16 DeepSpeed model.

    Args:
        batch (torch.Tensor or container of torch.Tensor): input batch
    Returns:
        return_batch: same type as the input batch with internal float32 tensors cast to float16
    """
    return cast_batch_to_dtype(batch, torch.float16)
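
# A usage sketch (the batch layout below is an assumption for illustration):
#
#     batch = {'image': torch.randn(2, 3, 224, 224), 'label': torch.tensor([0, 1])}
#     half_batch = cast_batch_to_half(batch)
#     # float tensors become float16; integer tensors are left unchanged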