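"""Miscellaneous utilities: normalization-module registry, rank-0 detection,
FLOP/parameter analysis, YAML config loading, and batch dtype casting."""
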
import logging
import os
import torch
import torch.distributed as dist
import yaml

from fvcore.nn import FlopCountAnalysis
from fvcore.nn import flop_count_table
from fvcore.nn import flop_count_str


logger = logging.getLogger(__name__)


# Module types treated as normalization layers; extend the list with register_norm_module().
NORM_MODULES = [
    torch.nn.BatchNorm1d,
    torch.nn.BatchNorm2d,
    torch.nn.BatchNorm3d,
    torch.nn.SyncBatchNorm,
    # NaiveSyncBatchNorm inherits from BatchNorm2d
    torch.nn.GroupNorm,
    torch.nn.InstanceNorm1d,
    torch.nn.InstanceNorm2d,
    torch.nn.InstanceNorm3d,
    torch.nn.LayerNorm,
    torch.nn.LocalResponseNorm,
]

def register_norm_module(cls):
    """Class decorator that adds ``cls`` to ``NORM_MODULES`` so it is treated as a normalization layer."""
    NORM_MODULES.append(cls)

    return cls


def is_main_process():
    """Return True on the rank-0 process (or when not launched through OpenMPI)."""
    rank = 0
    if 'OMPI_COMM_WORLD_SIZE' in os.environ:
        rank = int(os.environ['OMPI_COMM_WORLD_RANK'])

    return rank == 0


@torch.no_grad()
def analysis_model(model, dump_input, verbose=False):
    """Log FLOPs and parameter counts for ``model`` traced on the example input ``dump_input``.

    Args:
        model (torch.nn.Module): model to analyze
        dump_input: example input (tensor or tuple of tensors) used to trace the model
        verbose (bool): if True, also log the per-operator flop count string

    Returns:
        tuple: (total flops, flop count table, flop count string)
    """
    model.eval()
    flops = FlopCountAnalysis(model, dump_input)
    total = flops.total()
    model.train()
    params_total = sum(p.numel() for p in model.parameters())
    params_learned = sum(
        p.numel() for p in model.parameters() if p.requires_grad
    )
    logger.info(f"flop count table:\n {flop_count_table(flops)}")
    if verbose:
        logger.info(f"flop count str:\n {flop_count_str(flops)}")
    logger.info(f"  Total flops: {total/1000/1000:.3f}M")
    logger.info(f"  Total params: {params_total/1000/1000:.3f}M")
    logger.info(f"  Learned params: {params_learned/1000/1000:.3f}M")

    return total, flop_count_table(flops), flop_count_str(flops)


def load_config_dict_to_opt(opt, config_dict, splitter='.'):
    """
    Load the key-value pairs from config_dict into opt, overriding any existing values in opt.
    Keys containing the splitter (e.g. 'MODEL.NAME') are expanded into nested dictionaries.
    """
    if not isinstance(config_dict, dict):
        raise TypeError("Config must be a Python dictionary")
    for k, v in config_dict.items():
        k_parts = k.split(splitter)
        pointer = opt
        for k_part in k_parts[:-1]:
            if k_part not in pointer:
                pointer[k_part] = {}
            pointer = pointer[k_part]
            assert isinstance(pointer, dict), "Overriding key needs to be inside a Python dict."
        ori_value = pointer.get(k_parts[-1])
        pointer[k_parts[-1]] = v
        if ori_value is not None:
            print(f"Overrode {k} from {ori_value} to {pointer[k_parts[-1]]}")


def load_opt_from_config_file(conf_file):
    """
    Load opt from the config file.

    Args:
        conf_file: config file path

    Returns:
        dict: a dictionary of opt settings
    """
    opt = {}
    with open(conf_file, encoding='utf-8') as f:
        config_dict = yaml.safe_load(f)
        load_config_dict_to_opt(opt, config_dict)

    return opt

def cast_batch_to_dtype(batch, dtype):
    """
    Cast the float32 tensors in a batch to a specified torch dtype.
    It should be called before feeding the batch to the FP16 DeepSpeed model.

    Args:
        batch (torch.Tensor or container of torch.Tensor): input batch
        dtype (torch.dtype): target floating-point dtype
    Returns:
        return_batch: same type as the input batch, with internal float32 tensors cast to the specified dtype.
    """
    if torch.is_tensor(batch):
        if torch.is_floating_point(batch):
            return_batch = batch.to(dtype)
        else:
            return_batch = batch
    elif isinstance(batch, list):
        return_batch = [cast_batch_to_dtype(t, dtype) for t in batch]
    elif isinstance(batch, tuple):
        return_batch = tuple(cast_batch_to_dtype(t, dtype) for t in batch)
    elif isinstance(batch, dict):
        return_batch = {}
        for k in batch:
            return_batch[k] = cast_batch_to_dtype(batch[k], dtype)
    else:
        logger.debug(f"Can not cast type {type(batch)} to {dtype}. Skipping it in the batch.")
        return_batch = batch

    return return_batch


def cast_batch_to_half(batch):
    """
    Cast the float32 tensors in a batch to float16.
    It should be called before feeding the batch to the FP16 DeepSpeed model.

    Args:
        batch (torch.tensor or container of torch.tensor): input batch
    Returns:
        return_batch: same type as the input batch with internal float32 tensors casted to float16
    """
    return cast_batch_to_dtype(batch, torch.float16)
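

if __name__ == "__main__":
    # Illustrative usage sketch (not part of the original module); the toy
    # model, input shape, and config keys below are hypothetical placeholders.
    logging.basicConfig(level=logging.INFO)

    # FLOP / parameter analysis on a tiny model.
    toy_model = torch.nn.Linear(16, 4)
    dummy_input = torch.randn(1, 16)
    total_flops, _, _ = analysis_model(toy_model, dummy_input)

    # Dotted keys in the config dict are expanded into nested dicts,
    # overriding any existing values.
    opt = {'TRAINER': {'LR': 1e-4}}
    load_config_dict_to_opt(opt, {'TRAINER.LR': 5e-5, 'MODEL.NAME': 'demo'})
    print(opt)  # {'TRAINER': {'LR': 5e-05}, 'MODEL': {'NAME': 'demo'}}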