# Copyright (c) OpenMMLab. All rights reserved.
import logging
from typing import Optional, Union

import torch
from mmengine import print_log
from mmengine.hooks import Hook
from mmengine.model.wrappers import is_model_wrapper
from torch.utils._pytree import tree_flatten

from xtuner.parallel.sequence import get_sequence_parallel_world_size

DATA_BATCH = Optional[Union[dict, tuple, list]]


class ThroughputHook(Hook):
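    """Record training throughput of each train iteration.

    The hook estimates the model FLOPs from the batch shape and the model
    config, then writes ``train/tflops`` (average TFLOPs per GPU) and
    ``train/tokens_per_sec`` (tokens per second per GPU) to the message hub,
    where ``LoggerHook`` picks them up.

    Illustrative usage (all arguments are optional and inferred from the
    model when omitted)::

        custom_hooks = [dict(type=ThroughputHook)]
    """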
    # The priority must sit between IterTimerHook (NORMAL, 50), which records
    # the iteration time this hook reads, and LoggerHook (BELOW_NORMAL, 60),
    # which logs the scalars this hook writes.
    priority = 55

    def __init__(self,
                 use_activation_checkpointing=None,
                 hidden_size=None,
                 num_layers=None,
                 vocab_size=None,
                 mlp_ratio=None,
                 is_casual=None):
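        """All arguments are optional hints for the FLOPs estimate.

        Any argument left as ``None`` is filled in automatically in
        ``before_run``: ``hidden_size``, ``num_layers``, ``vocab_size`` and
        ``mlp_ratio`` from ``model.config``; ``use_activation_checkpointing``
        and ``is_casual`` by probing the model's submodules.
        """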
        self.use_activation_checkpointing = use_activation_checkpointing
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.vocab_size = vocab_size
        self.mlp_ratio = mlp_ratio
        self.is_casual = is_casual

    @staticmethod
    def _guess_is_casual_attn(model):
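        """Guess the ``is_casual`` flag by probing submodules.

        Recent HuggingFace attention implementations expose an ``is_causal``
        attribute on their attention modules; the first one found is used.
        """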
        for module in model.modules():
            if hasattr(module, 'is_causal'):
                return module.is_causal
        print_log(
            'Cannot determine whether causal attention is used; '
            'FLOPs will be computed assuming `is_casual = True`.', 'current')
        return True

    @staticmethod
    def _get_batch_size_and_sequence_len(data_batch):
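        """Return ``(batch_size, sequence_len)`` of the first tensor found.

        The batch is flattened with ``tree_flatten``, so nesting does not
        matter. For an illustrative batch such as
        ``{'data': {'input_ids': torch.zeros(2, 4096, dtype=torch.long)}}``
        this returns ``(2, 4096)``.
        """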
        data_list, _ = tree_flatten(data_batch)
        for data in data_list:
            if isinstance(data, torch.Tensor):
                return data.size(0), data.size(1)
        raise RuntimeError('No tensor found in the batch')

    @staticmethod
    def _guess_use_activation_checkpointing(model):
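        """Guess whether activation (gradient) checkpointing is enabled.

        HuggingFace modules that support it carry a ``gradient_checkpointing``
        flag; the first such flag found is returned, otherwise ``False``.
        """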
        for module in model.modules():
            if hasattr(module, 'gradient_checkpointing'):
                return module.gradient_checkpointing
        return False

    def before_run(self, runner) -> None:
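        """Resolve every unspecified hyper-parameter from the model.

        Runs once before training so that ``after_train_iter`` only has to
        plug numbers into the FLOPs formula.
        """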
        if is_model_wrapper(runner.model):
            model = runner.model.module
        else:
            model = runner.model
        self.use_activation_checkpointing = \
            (self.use_activation_checkpointing or
             self._guess_use_activation_checkpointing(model))
        self.hidden_size = self.hidden_size or model.config.hidden_size
        self.num_layers = self.num_layers or model.config.num_hidden_layers
        self.vocab_size = self.vocab_size or model.config.vocab_size
        self.mlp_ratio = self.mlp_ratio or (model.config.intermediate_size /
                                            model.config.hidden_size)
        self.mlp_ratio *= 1.5  # has gate_proj
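        # SwiGLU-style MLPs (e.g. Llama) have gate_proj in addition to
        # up_proj/down_proj, i.e. three matmuls of cost 2B*s*h*(h*ratio)
        # instead of two, hence the extra 1.5x factor above.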
        self.is_casual = self.is_casual if self.is_casual is not None \
            else self._guess_is_casual_attn(model)

        use_varlen_attn = getattr(model, 'use_varlen_attn', False)
        if use_varlen_attn:
            print_log(
                'Using variable-length Flash Attention causes an inflation'
                ' in the FLOPs calculation.',
                'current',
                level=logging.WARNING)

        return

    def after_train_iter(self,
                         runner,
                         batch_idx: int,
                         data_batch: DATA_BATCH = None,
                         outputs: Optional[dict] = None) -> None:
        """Calculate FLOPs following the Megatron-LM paper:
        https://deepakn94.github.io/assets/papers/megatron-sc21.pdf."""
        batch_size, sequence_len = self._get_batch_size_and_sequence_len(
            data_batch)
        sequence_parallel_size = get_sequence_parallel_world_size()
        sequence_len /= sequence_parallel_size
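        # With sequence parallelism, the batch from the dataloader still holds
        # the full sequence; dividing by the sequence parallel world size gives
        # the number of tokens each rank actually processes.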

        message_hub = runner.message_hub
        iter_time = message_hub.get_scalar('train/time').current()

        # We consider a language model with l transformer layers,
        # hidden size h, sequence length s, vocabulary size V, and
        # training batch size B.
        # An A_{m x k} x X_{k x n} matrix multiplication requires
        # 2m * k * n FLOPs (the factor of 2 accounts for multiplies and adds).
        #
        # Attention Layer:
        #     qkv_proj + o_proj: 8B * s * h^2
        #     attn (QK^T and attn @ V):
        #         4B * s^2 * h (casual=False) and 4B * s^2 * h / 2 (casual=True)
        #
        # MLP Layer:
        #     up_proj + down_proj + gate_proj: 4B * s * h^2 * mlp_ratio
        #     (In Llama, mlp_ratio = intermediate_size / hidden_size * 1.5
        #     because of gate_proj.)
        #
        # The backward pass requires double the number of FLOPs since we
        # need to calculate the gradients with respect to both input and
        # weight tensors. In addition, if activation recomputation is enabled,
        # it requires an additional forward pass before the backward pass.
        #
        # Sequence parallelism changes the FLOPs of the attn part. Suppose
        # the sequence length on one GPU is s and the sequence parallel world
        # size is `sp_size`; then the total sequence length in the attention
        # calculation is `s * sp_size` and the number of attention heads
        # decreases to `num_heads / sp_size`. Hence the attn FLOPs are:
        #     4B * (s * sp_size)^2 * (h / sp_size) (casual=False) and
        #     4B * (s * sp_size)^2 * (h / sp_size) / 2 (casual=True)
        flops_qkvo_proj = 8 * batch_size * sequence_len * self.hidden_size**2
        flops_attn = 4 * batch_size * sequence_len**2 * self.hidden_size * \
            sequence_parallel_size / (int(self.is_casual) + 1)
        flops_mlp = 4 * self.mlp_ratio * batch_size * sequence_len * \
            self.hidden_size**2
        flops_wo_head = (3 + int(self.use_activation_checkpointing)) * (
            flops_qkvo_proj + flops_attn + flops_mlp) * self.num_layers
        flops_head = 3 * 2 * batch_size * sequence_len * self.hidden_size * \
            self.vocab_size
        flops_per_iteration = flops_wo_head + flops_head
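        # Rough sanity check (illustrative, Llama-2-7B-like numbers:
        # h=4096, l=32, V=32000, mlp_ratio=11008/4096*1.5, B=1, s=2048,
        # sp_size=1, casual=True, activation checkpointing enabled):
        #     flops_qkvo_proj ~ 2.7e11, flops_attn ~ 3.4e10, flops_mlp ~ 5.5e11
        #     flops_wo_head ~ 4 * 8.6e11 * 32 ~ 1.1e14, flops_head ~ 1.6e12
        # i.e. ~1.1e14 FLOPs per iteration, which agrees with the usual
        # ~8 * num_params * num_tokens rule of thumb when recomputation is on.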
        avg_tflops_per_gpu = flops_per_iteration / 1e12 / (iter_time + 1e-12)
        tokens_per_sec_per_gpu = batch_size * sequence_len / (
            iter_time + 1e-12)

        message_hub.update_scalar('train/tflops', avg_tflops_per_gpu)
        message_hub.update_scalar('train/tokens_per_sec',
                                  tokens_per_sec_per_gpu)