|
|
|
|
|
|
|
|
|
|
|
import fire |
|
import torch |
|
from deepspeed.accelerator import get_accelerator |
|
from deepspeed.profiling.flops_profiler import get_model_profile |
|
|
|
from llamafactory.chat import ChatModel |
|
|
|
|
|
def calculate_flops(
    model_name_or_path: str,
    batch_size: int = 1,
    seq_length: int = 256,
    flash_attn: str = "auto",
):
    r"""Profile a model with DeepSpeed's flops profiler and print the totals.

    The model is loaded through ``ChatModel`` using the ``"empty"`` template,
    inside the context of accelerator device 0. A dummy batch of all-ones
    token ids with shape ``(batch_size, seq_length)`` is then passed as both
    ``input_ids`` and ``labels``. DeepSpeed prints its detailed per-module
    profile, and the FLOPs, MACs and parameter totals are printed afterwards.

    Usage (via python-fire):
        python <script>.py --model_name_or_path path_to_model --batch_size 1 --seq_length 512

    Args:
        model_name_or_path: Model identifier or local path forwarded to ``ChatModel``.
        batch_size: Number of rows in the dummy input batch.
        seq_length: Number of tokens per row in the dummy input batch.
        flash_attn: Attention-implementation setting forwarded to ``ChatModel``.
    """
    loader_args = {
        "model_name_or_path": model_name_or_path,
        "template": "empty",
        "flash_attn": flash_attn,
    }

    with get_accelerator().device(0):
        # Keep the ChatModel wrapper referenced for the whole profiling run.
        chat_model = ChatModel(loader_args)
        model = chat_model.model

        # Token id 1 repeated everywhere; only the shape matters for profiling.
        dummy_ids = torch.ones((batch_size, seq_length), dtype=torch.long, device=model.device)
        # NOTE(review): supplying labels presumably makes the forward pass also
        # compute a loss, so its cost is included in the count — confirm.
        forward_kwargs = {"input_ids": dummy_ids, "labels": dummy_ids.clone()}

        flops, macs, params = get_model_profile(model, kwargs=forward_kwargs, print_profile=True, detailed=True)

        for caption, amount in (("FLOPs:", flops), ("MACs:", macs), ("Params:", params)):
            print(caption, amount)
|
|
|
|
|
if __name__ == "__main__":
    # python-fire turns calculate_flops into a CLI: each --flag on the command
    # line is mapped to the function parameter of the same name.
    fire.Fire(component=calculate_flops)
|
|