# imagenet-benchmark / measure_mac.py
# Uploaded by andreysher to the imagenet-benchmark repository (revision 4d679c2, 540 bytes).
import argparse
import torch
from fvcore.nn import FlopCountAnalysis
def get_args(argv=None):
    """Parse command-line arguments for the MAC-counting script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default), argparse reads ``sys.argv[1:]`` as before.

    Returns:
        argparse.Namespace with ``model_ckpt``, the path to the checkpoint file.
    """
    parser = argparse.ArgumentParser(
        description="Report the MAC count of a model stored in a torch checkpoint."
    )
    # Required: without a path, torch.load(None) would fail much later with an
    # unhelpful error; argparse now reports the missing option up front.
    parser.add_argument(
        "--model-ckpt",
        type=str,
        required=True,
        help="Path to a checkpoint whose 'model_ckpt' entry holds the model object.",
    )
    return parser.parse_args(argv)
def _load_checkpoint(path):
    """Load a checkpoint file onto the CPU, permitting full (pickled) objects.

    ``main`` expects ``checkpoint["model_ckpt"]`` to be the model object itself,
    not just a state dict. Since torch 2.6, ``torch.load`` defaults to
    ``weights_only=True``, which refuses to unpickle such objects. So
    ``weights_only=False`` is passed explicitly whenever the installed torch
    accepts that keyword. Older versions without the keyword already unpickle
    fully.

    SECURITY: full unpickling can execute arbitrary code; only load
    checkpoints from trusted sources.
    """
    import inspect

    load_kwargs = {"map_location": "cpu"}
    if "weights_only" in inspect.signature(torch.load).parameters:
        load_kwargs["weights_only"] = False
    return torch.load(path, **load_kwargs)


def main():
    """Load the model named by ``--model-ckpt`` and print its MAC count in millions.

    The model is taken from the checkpoint's ``"model_ckpt"`` entry, moved to
    the CPU, put in eval mode, and analysed with fvcore on a single all-ones
    input of shape (1, 3, 224, 224).
    """
    args = get_args()
    checkpoint = _load_checkpoint(args.model_ckpt)
    model = checkpoint["model_ckpt"].cpu()
    model.eval()
    dummy_input = torch.ones((1, 3, 224, 224))
    # fvcore counts one fused multiply-add as one "flop", so total() is a MAC
    # count. Gradients are not needed. total() stays inside no_grad too,
    # because fvcore may defer tracing the model until the first query.
    with torch.no_grad():
        analysis = FlopCountAnalysis(model, dummy_input)
        total_macs = analysis.total()
    print(f"MMACs = {total_macs/1e6}")
if __name__ == "__main__":
    # Script entry point; main() returns None, so the process exits with status 0.
    raise SystemExit(main())