OMG-LLaVA/xtuner/engine/hooks/throughput_hook.py

# Copyright (c) OpenMMLab. All rights reserved.
import logging
from typing import Optional, Union

import torch
from mmengine import print_log
from mmengine.hooks import Hook
from mmengine.model.wrappers import is_model_wrapper
from torch.utils._pytree import tree_flatten

from xtuner.parallel.sequence import get_sequence_parallel_world_size

DATA_BATCH = Optional[Union[dict, tuple, list]]


class ThroughputHook(Hook):
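    """Log the training throughput of each iteration.

    The hook estimates the FLOPs of one training iteration from the batch
    shape and the model config (see ``after_train_iter``), then writes
    ``train/tflops`` and ``train/tokens_per_sec`` to the runner's message
    hub so they show up in the training logs.
    """
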
    # priority must be higher than LoggerHook (BELOW_NORMAL, 60) and lower
    # than IterTimerHook (NORMAL, 50), so that the iteration time is already
    # recorded when this hook runs and the throughput scalars are updated
    # before LoggerHook logs them
    priority = 55

def __init__(self,
use_activation_checkpointing=None,
hidden_size=None,
num_layers=None,
vocab_size=None,
mlp_ratio=None,
is_casual=None):
self.use_activation_checkpointing = use_activation_checkpointing
self.hidden_size = hidden_size
self.num_layers = num_layers
self.vocab_size = vocab_size
self.mlp_ratio = mlp_ratio
self.is_casual = is_casual

    @staticmethod
def _guess_is_casual_attn(model):
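        """Return the ``is_causal`` flag of the first submodule that defines
        it; assume causal attention when no submodule exposes the flag."""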
for module in model.modules():
if hasattr(module, 'is_causal'):
return module.is_causal
        print_log(
            'Unable to determine whether causal attention is used; '
            'FLOPs will be calculated assuming `causal = True`.', 'current')
return True

    @staticmethod
def _get_batch_size_and_sequence_len(data_batch):
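        """Return ``(batch_size, sequence_len)`` read from the first tensor
        found in the (possibly nested) data batch."""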
data_list, _ = tree_flatten(data_batch)
for data in data_list:
if isinstance(data, torch.Tensor):
return data.size(0), data.size(1)
raise RuntimeError('No tensor found in the batch')

    @staticmethod
def _guess_use_activation_checkpointing(model):
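        """Return the ``gradient_checkpointing`` flag of the first submodule
        that defines it, defaulting to False."""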
for module in model.modules():
if hasattr(module, 'gradient_checkpointing'):
return module.gradient_checkpointing
return False

    def before_run(self, runner) -> None:
if is_model_wrapper(runner.model):
model = runner.model.module
else:
model = runner.model
self.use_activation_checkpointing = \
(self.use_activation_checkpointing or
self._guess_use_activation_checkpointing(model))
self.hidden_size = self.hidden_size or model.config.hidden_size
self.num_layers = self.num_layers or model.config.num_hidden_layers
self.vocab_size = self.vocab_size or model.config.vocab_size
self.mlp_ratio = self.mlp_ratio or (model.config.intermediate_size /
model.config.hidden_size)
self.mlp_ratio *= 1.5 # has gate_proj
self.is_casual = self.is_casual if self.is_casual is not None \
else self._guess_is_casual_attn(model)
use_varlen_attn = getattr(model, 'use_varlen_attn', False)
if use_varlen_attn:
print_log(
                'Variable-length Flash Attention is enabled, so the FLOPs '
                'calculation below will overestimate the actual FLOPs.',
'current',
level=logging.WARNING)
return

    def after_train_iter(self,
runner,
batch_idx: int,
data_batch: DATA_BATCH = None,
outputs: Optional[dict] = None) -> None:
"""Calc flops based on the paper of Megatron
https://deepakn94.github.io/assets/papers/megatron-sc21.pdf."""
batch_size, sequence_len = self._get_batch_size_and_sequence_len(
data_batch)
sequence_parallel_size = get_sequence_parallel_world_size()
sequence_len /= sequence_parallel_size
message_hub = runner.message_hub
iter_time = message_hub.get_scalar('train/time').current()
        # We consider a language model with l transformer layers,
        # hidden size h, sequence length s, vocabulary size V, and
        # training batch size B.
        # An A_{m x k} x X_{k x n} matrix multiplication requires
        # 2 * m * k * n FLOPs (factor of 2 needed to account for
        # multiplies and adds).
# Attention Layer:
# qkv_proj + o_proj: 8B * s * h^2
        # attn: 2B * s^2 * h (causal=False) and 2B * s^2 * h / 2 (causal=True)
# MLP Layer:
# up_proj + down_proj + gate_proj: 4B * s * h^2 * mlp_ratio
# (In Llama mlp_ratio = intermediate_size / hidden_size * 1.5
# (has gate_proj))
# The backward pass requires double the number of FLOPs since we
# need to calculate the gradients with respect to both input and
# weight tensors. In addition, we are using activation recomputation,
# which requires an additional forward pass before the backward pass.
        # Sequence parallelism also affects the FLOPs calculation in attn.
        # Suppose the sequence length on one GPU is s and the sequence
        # parallel world size is `sp_size`. Then the total sequence length
        # in the attention calculation is `s * sp_size` and the number of
        # attention heads decreases to `num_heads / sp_size`. Hence, the
        # FLOPs of the attn calculation are:
        # 2B * (s * sp_size)^2 * (h / sp_size) (causal=False) and
        # 2B * (s * sp_size)^2 * (h / sp_size) / 2 (causal=True)
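        # Putting it together (s = per-GPU sequence length), the per-iteration
        # estimate computed below is:
        #   flops_iter = (3 + ckpt) * num_layers
        #                * (8*B*s*h^2
        #                   + 4*B*s^2*h*sp_size / (causal + 1)
        #                   + 4*mlp_ratio*B*s*h^2)
        #                + 6*B*s*h*V
        # where ckpt is 1 with activation checkpointing and 0 without.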
flops_qkvo_proj = 8 * batch_size * sequence_len * self.hidden_size**2
flops_attn = 4 * batch_size * sequence_len**2 * self.hidden_size * \
sequence_parallel_size / (int(self.is_casual) + 1)
flops_mlp = 4 * self.mlp_ratio * batch_size * sequence_len * \
self.hidden_size**2
flops_wo_head = (3 + int(self.use_activation_checkpointing)) * (
flops_qkvo_proj + flops_attn + flops_mlp) * self.num_layers
flops_head = 3 * 2 * batch_size * sequence_len * self.hidden_size * \
self.vocab_size
flops_per_iteration = flops_wo_head + flops_head
avg_tflops_per_gpu = flops_per_iteration / 1e12 / (iter_time + 1e-12)
tokens_per_sec_per_gpu = batch_size * sequence_len / (
iter_time + 1e-12)
message_hub.update_scalar('train/tflops', avg_tflops_per_gpu)
message_hub.update_scalar('train/tokens_per_sec',
tokens_per_sec_per_gpu)
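

# Illustrative usage sketch (assuming the hook is registered through
# MMEngine's `custom_hooks` in an xtuner training config):
#
#   from xtuner.engine.hooks import ThroughputHook
#
#   custom_hooks = [
#       # all arguments are optional; unset ones are inferred from the model
#       # config in `before_run`
#       dict(type=ThroughputHook),
#   ]
#
# LoggerHook will then include the `train/tflops` and `train/tokens_per_sec`
# scalars from the message hub in the iteration logs.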