# Copyright (c) OpenMMLab. All rights reserved.
import logging
from typing import Optional, Union

import torch
from mmengine import print_log
from mmengine.hooks import Hook
from mmengine.model.wrappers import is_model_wrapper
from torch.utils._pytree import tree_flatten

from xtuner.parallel.sequence import get_sequence_parallel_world_size

DATA_BATCH = Optional[Union[dict, tuple, list]]

class ThroughputHook(Hook):
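    """Estimate training throughput with the Megatron-LM FLOPs formula.

    After every training iteration the hook writes two scalars to the
    runner's message hub: ``train/tflops`` (average TFLOPS per GPU) and
    ``train/tokens_per_sec`` (tokens processed per second per GPU).
    Hyperparameters that are not passed to ``__init__`` are inferred from
    the model and its ``config`` in ``before_run``.

    A minimal usage sketch (the registration style is an assumption based
    on the usual MMEngine convention, not something defined in this file)::

        custom_hooks = [dict(type='ThroughputHook')]
    """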
    # Priority 55 places this hook between IterTimerHook (NORMAL, 50) and
    # LoggerHook (BELOW_NORMAL, 60): the iteration time must already be
    # recorded, and the throughput scalars must be updated before they are
    # logged.
    priority = 55

    def __init__(self,
                 use_activation_checkpointing=None,
                 hidden_size=None,
                 num_layers=None,
                 vocab_size=None,
                 mlp_ratio=None,
                 is_casual=None):
        self.use_activation_checkpointing = use_activation_checkpointing
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.vocab_size = vocab_size
        self.mlp_ratio = mlp_ratio
        self.is_casual = is_casual

    @staticmethod
    def _guess_is_casual_attn(model):
        # Return the `is_causal` flag of the first submodule that defines
        # one; assume causal attention when nothing is found.
        for module in model.modules():
            if hasattr(module, 'is_causal'):
                return module.is_causal
        print_log(
            'Could not determine whether causal attention is used; '
            'FLOPs will be calculated assuming `is_casual = True`.',
            'current')
        return True

    @staticmethod
    def _get_batch_size_and_sequence_len(data_batch):
        # Flatten the (possibly nested) data batch and read (batch_size,
        # seq_len) from the first tensor found, which is assumed to be laid
        # out as (batch, seq, ...).
        data_list, _ = tree_flatten(data_batch)
        for data in data_list:
            if isinstance(data, torch.Tensor):
                return data.size(0), data.size(1)
        raise RuntimeError('No tensor found in the batch')

    @staticmethod
    def _guess_use_activation_checkpointing(model):
        for module in model.modules():
            if hasattr(module, 'gradient_checkpointing'):
                return module.gradient_checkpointing
        return False

    def before_run(self, runner) -> None:
        if is_model_wrapper(runner.model):
            model = runner.model.module
        else:
            model = runner.model
        self.use_activation_checkpointing = \
            (self.use_activation_checkpointing or
             self._guess_use_activation_checkpointing(model))
        self.hidden_size = self.hidden_size or model.config.hidden_size
        self.num_layers = self.num_layers or model.config.num_hidden_layers
        self.vocab_size = self.vocab_size or model.config.vocab_size
        self.mlp_ratio = self.mlp_ratio or (model.config.intermediate_size /
                                            model.config.hidden_size)
        self.mlp_ratio *= 1.5  # has gate_proj
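        # Worked example (illustrative numbers only): a Llama-7B-style
        # config has intermediate_size = 11008 and hidden_size = 4096, so
        # mlp_ratio = 11008 / 4096 = 2.6875, scaled by 1.5 to 4.03125 to
        # account for the extra gate_proj matmul.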
        self.is_casual = self.is_casual if self.is_casual is not None \
            else self._guess_is_casual_attn(model)

        use_varlen_attn = getattr(model, 'use_varlen_attn', False)
        if use_varlen_attn:
            print_log(
                'Variable-length Flash Attention is enabled, so the '
                'estimated FLOPs (and therefore TFLOPS) will be inflated: '
                'the formula assumes attention over the full sequence '
                'length of each packed sample.',
                'current',
                level=logging.WARNING)

        return

    def after_train_iter(self,
                         runner,
                         batch_idx: int,
                         data_batch: DATA_BATCH = None,
                         outputs: Optional[dict] = None) -> None:
        """Compute FLOPs and throughput following the Megatron-LM paper:
        https://deepakn94.github.io/assets/papers/megatron-sc21.pdf."""
        batch_size, sequence_len = self._get_batch_size_and_sequence_len(
            data_batch)
        sequence_parallel_size = get_sequence_parallel_world_size()
        # From here on, `sequence_len` is the per-GPU (sub-)sequence length.
        sequence_len /= sequence_parallel_size

        message_hub = runner.message_hub
        iter_time = message_hub.get_scalar('train/time').current()

        # We consider a language model with l transformer layers, hidden
        # size h, sequence length s, vocabulary size V, and training batch
        # size B.
        # An A_{m x k} x X_{k x n} matrix multiplication requires
        # 2 * m * k * n FLOPs (the factor of 2 accounts for multiplies and
        # adds).
        # Attention layer:
        # qkv_proj + o_proj: 8B * s * h^2
        # attn: 2B * s^2 * h (casual=False) and 2B * s^2 * h / 2 (casual=True)
        # MLP layer:
        # up_proj + down_proj + gate_proj: 4B * s * h^2 * mlp_ratio
        # (in Llama, mlp_ratio = intermediate_size / hidden_size * 1.5
        # because of gate_proj)
        # The backward pass requires double the FLOPs of the forward pass,
        # since gradients are computed with respect to both the input and
        # the weight tensors. When activation checkpointing is used, an
        # additional forward pass is needed before the backward pass (hence
        # the `3 + int(use_activation_checkpointing)` factor below).
        # Sequence parallelism also changes the attention FLOPs. Suppose the
        # sequence length on one GPU is s and the sequence parallel world
        # size is `sp_size`; then the total sequence length in the attention
        # calculation is `s * sp_size` and the number of attention heads per
        # GPU decreases to `num_heads / sp_size`. Hence, the FLOPs of the
        # attention calculation are:
        # 2B * (s * sp_size)^2 * (h / sp_size) (casual=False) and
        # 2B * (s * sp_size)^2 * (h / sp_size) / 2 (casual=True)
        flops_qkvo_proj = 8 * batch_size * sequence_len * self.hidden_size**2
        flops_attn = 4 * batch_size * sequence_len**2 * self.hidden_size * \
            sequence_parallel_size / (int(self.is_casual) + 1)
        flops_mlp = 4 * self.mlp_ratio * batch_size * sequence_len * \
            self.hidden_size**2
        flops_wo_head = (3 + int(self.use_activation_checkpointing)) * (
            flops_qkvo_proj + flops_attn + flops_mlp) * self.num_layers
        flops_head = 3 * 2 * batch_size * sequence_len * self.hidden_size * \
            self.vocab_size
        flops_per_iteration = flops_wo_head + flops_head
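
        # Worked example with purely illustrative, hypothetical numbers:
        # for h = 4096, s = 4096, B = 1, 32 layers, V = 32000,
        # mlp_ratio ~= 4.03, causal attention, activation checkpointing on
        # and no sequence parallelism, the per-layer forward FLOPs are
        # roughly 0.55e12 (qkv/o proj) + 0.14e12 (attn) + 1.11e12 (mlp)
        # ~= 1.8e12; multiplied by 4 and by 32 layers, plus ~3.2e12 for the
        # lm_head, this gives ~2.3e14 FLOPs per iteration. With an iteration
        # time of 1 s, the hook would report ~233 TFLOPS and 4096 tokens
        # per second per GPU.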

        avg_tflops_per_gpu = flops_per_iteration / 1e12 / (iter_time + 1e-12)
        tokens_per_sec_per_gpu = batch_size * sequence_len / (
            iter_time + 1e-12)

        message_hub.update_scalar('train/tflops', avg_tflops_per_gpu)
        message_hub.update_scalar('train/tokens_per_sec',
                                  tokens_per_sec_per_gpu)