# Copyright (c) OpenMMLab. All rights reserved.
import logging
from typing import Optional, Union

import torch
from mmengine import print_log
from mmengine.hooks import Hook
from mmengine.model.wrappers import is_model_wrapper
from torch.utils._pytree import tree_flatten

from xtuner.parallel.sequence import get_sequence_parallel_world_size

DATA_BATCH = Optional[Union[dict, tuple, list]]


class ThroughputHook(Hook):
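    """Track training throughput (TFLOPs and tokens/s per GPU).

    FLOPs are estimated analytically from the model's hyper-parameters
    (hidden size, number of layers, vocabulary size, MLP ratio, and
    whether attention is causal); any value not given explicitly is
    guessed from the model config in ``before_run``.
    """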

    # The priority must be higher than LoggerHook's (60) and lower than
    # IterTimerHook's (50): in MMEngine a smaller value means higher
    # priority, so this hook runs after the iteration timer but before
    # the logger.
    priority = 55

    def __init__(self,
                 use_activation_checkpointing=None,
                 hidden_size=None,
                 num_layers=None,
                 vocab_size=None,
                 mlp_ratio=None,
                 is_casual=None):
        self.use_activation_checkpointing = use_activation_checkpointing
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.vocab_size = vocab_size
        self.mlp_ratio = mlp_ratio
        self.is_casual = is_casual

    @staticmethod
    def _guess_is_casual_attn(model):
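        """Guess whether causal attention is used by looking for an
        ``is_causal`` attribute on any submodule; defaults to True."""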
        for module in model.modules():
            if hasattr(module, 'is_causal'):
                return module.is_causal
        print_log(
            'Could not determine whether causal attention is used; '
            'FLOPs will be calculated assuming `causal = True`.', 'current')
        return True

    @staticmethod
    def _get_batch_size_and_sequence_len(data_batch):
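        """Return ``(batch_size, sequence_len)`` taken from the first
        tensor found in the (possibly nested) data batch."""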
        data_list, _ = tree_flatten(data_batch)
        for data in data_list:
            if isinstance(data, torch.Tensor):
                return data.size(0), data.size(1)
        raise RuntimeError('No tensor found in the batch')

    @staticmethod
    def _guess_use_activation_checkpointing(model):
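        """Guess whether activation checkpointing is enabled by looking
        for a ``gradient_checkpointing`` attribute on any submodule;
        defaults to False."""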
        for module in model.modules():
            if hasattr(module, 'gradient_checkpointing'):
                return module.gradient_checkpointing
        return False

    def before_run(self, runner) -> None:
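        """Fill in any hyper-parameters that were not set explicitly by
        reading them from the (possibly wrapped) model and its config."""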
        if is_model_wrapper(runner.model):
            model = runner.model.module
        else:
            model = runner.model
        self.use_activation_checkpointing = \
            (self.use_activation_checkpointing or
             self._guess_use_activation_checkpointing(model))
        self.hidden_size = self.hidden_size or model.config.hidden_size
        self.num_layers = self.num_layers or model.config.num_hidden_layers
        self.vocab_size = self.vocab_size or model.config.vocab_size
        self.mlp_ratio = self.mlp_ratio or (model.config.intermediate_size /
                                            model.config.hidden_size)
        self.mlp_ratio *= 1.5  # gate_proj: 3 MLP matmuls instead of 2
        self.is_casual = self.is_casual if self.is_casual is not None \
            else self._guess_is_casual_attn(model)

        use_varlen_attn = getattr(model, 'use_varlen_attn', False)
        if use_varlen_attn:
            print_log(
                'Variable-length Flash Attention is enabled, so the FLOPs'
                ' reported by this hook will be over-estimated.',
                'current',
                level=logging.WARNING)

    def after_train_iter(self,
                         runner,
                         batch_idx: int,
                         data_batch: DATA_BATCH = None,
                         outputs: Optional[dict] = None) -> None:
        """Calc flops based on the paper of Megatron
        https://deepakn94.github.io/assets/papers/megatron-sc21.pdf."""

        batch_size, sequence_len = self._get_batch_size_and_sequence_len(
            data_batch)
        sequence_parallel_size = get_sequence_parallel_world_size()
        sequence_len /= sequence_parallel_size

        message_hub = runner.message_hub
        iter_time = message_hub.get_scalar('train/time').current()

        # We consider a language model with l transformer layers, hidden
        # size h, sequence length s, vocabulary size V, and training
        # batch size B.
        # An A_{m x k} x X_{k x n} matrix multiplication requires
        # 2m x k x n FLOPs (the factor of 2 accounts for multiplies and
        # adds).

        # Attention Layer:
        # qkv_proj + o_proj: 8B * s * h^2
        # attn (QK^T and attn @ V): 4B * s^2 * h (causal=False) and
        # 4B * s^2 * h / 2 (causal=True)

        # MLP Layer:
        # up_proj + down_proj + gate_proj: 4B * s * h^2 * mlp_ratio
        # (in Llama, mlp_ratio = intermediate_size / hidden_size * 1.5,
        # where the factor 1.5 accounts for gate_proj)

        # The backward pass requires double the number of FLOPs since we
        # need to calculate the gradients with respect to both input and
        # weight tensors. In addition, we are using activation recomputation,
        # which requires an additional forward pass before the backward pass.

        # Sequence parallelism affects the FLOPs calculation in attention.
        # Suppose the sequence length on one GPU is s and the sequence
        # parallel world size is `sp_size`. The total sequence length in
        # the attention calculation is then `s * sp_size`, while the
        # number of attention heads per GPU decreases to
        # `num_heads / sp_size`. Hence, the attention FLOPs become:
        # 4B * (s * sp_size)^2 * (h / sp_size) (causal=False) and
        # 4B * (s * sp_size)^2 * (h / sp_size) / 2 (causal=True)
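        #
        # As an illustrative sanity check (assuming Llama-7B-like values,
        # which are not taken from this file: B=1, s=2048, h=4096, l=32,
        # V=32000, mlp_ratio = 11008 / 4096 * 1.5 = 4.03125, causal=True,
        # activation checkpointing enabled, sp_size=1):
        #   flops_qkvo_proj = 8 * 2048 * 4096^2           ~= 2.75e11
        #   flops_attn      = 4 * 2048^2 * 4096 / 2       ~= 3.44e10
        #   flops_mlp       = 4 * 4.03125 * 2048 * 4096^2 ~= 5.54e11
        #   flops_wo_head   = 4 * (sum of the above) * 32 ~= 1.11e14
        #   flops_head      = 3 * 2 * 2048 * 4096 * 32000 ~= 1.61e12
        #   flops_per_iteration ~= 1.12e14, i.e. ~112 TFLOPs per iteration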

        flops_qkvo_proj = 8 * batch_size * sequence_len * self.hidden_size**2
        flops_attn = 4 * batch_size * sequence_len**2 * self.hidden_size * \
            sequence_parallel_size / (int(self.is_casual) + 1)
        flops_mlp = 4 * self.mlp_ratio * batch_size * sequence_len * \
            self.hidden_size**2
        flops_wo_head = (3 + int(self.use_activation_checkpointing)) * (
            flops_qkvo_proj + flops_attn + flops_mlp) * self.num_layers
        flops_head = 3 * 2 * batch_size * sequence_len * self.hidden_size * \
            self.vocab_size
        flops_per_iteration = flops_wo_head + flops_head

        avg_tflops_per_gpu = flops_per_iteration / 1e12 / (iter_time + 1e-12)
        tokens_per_sec_per_gpu = batch_size * sequence_len / (
            iter_time + 1e-12)

        message_hub.update_scalar('train/tflops', avg_tflops_per_gpu)
        message_hub.update_scalar('train/tokens_per_sec',
                                  tokens_per_sec_per_gpu)
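

# A minimal usage sketch (hedged: the registry import path of
# ThroughputHook depends on where this file lives in the project and is
# not shown here). In an MMEngine-style config, the hook can be added via
# `custom_hooks`:
#
#     custom_hooks = [dict(type=ThroughputHook)]
#
# or constructed explicitly and passed to the Runner:
#
#     from mmengine.runner import Runner
#     runner = Runner(..., custom_hooks=[ThroughputHook()])
#
# Each training iteration then writes `train/tflops` and
# `train/tokens_per_sec` to the message hub, where LoggerHook picks them
# up.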