import argparse | |
import torch | |
from fvcore.nn import FlopCountAnalysis | |
def get_args(argv=None):
    """Parse the command-line options for the MAC-counting script.

    Args:
        argv: Optional list of argument strings to parse instead of the real
            command line. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, exactly as before; passing a list is mainly
            useful for tests.

    Returns:
        argparse.Namespace with a ``model_ckpt`` attribute holding the
        checkpoint path given on the command line.
    """
    parser = argparse.ArgumentParser(
        description="Count the multiply-accumulate operations of a saved model."
    )
    # Required: without it, args.model_ckpt would be None and the failure
    # would only surface later as a confusing error inside torch.load.
    parser.add_argument(
        "--model-ckpt",
        type=str,
        required=True,
        help="Path to a torch.save checkpoint whose 'model_ckpt' entry is the model.",
    )
    return parser.parse_args(argv)
def main():
    """Load a pickled model from a checkpoint and print its MAC count in millions.

    The checkpoint path comes from ``--model-ckpt``. The file is expected to
    hold a mapping whose ``"model_ckpt"`` entry is the model object itself.
    The model is analysed with a single dummy input of shape (1, 3, 224, 224)
    (presumably one RGB image at 224x224; adjust if the model expects
    otherwise).
    """
    args = get_args()

    # The checkpoint contains a fully pickled model object, not just a
    # state_dict, so it must be unpickled with weights_only=False. PyTorch
    # >= 2.6 defaults to weights_only=True and would refuse to load it.
    # SECURITY: unpickling can execute arbitrary code -- only use this on
    # checkpoint files you trust.
    checkpoint = torch.load(args.model_ckpt, map_location="cpu", weights_only=False)

    try:
        model = checkpoint["model_ckpt"]
    except (KeyError, TypeError) as err:
        raise SystemExit(
            f"{args.model_ckpt}: expected a checkpoint mapping with a 'model_ckpt' entry"
        ) from err

    model = model.cpu()
    model.eval()

    dummy_input = torch.ones((1, 3, 224, 224))
    analysis = FlopCountAnalysis(model, dummy_input)
    # FlopCountAnalysis is lazy: the model is traced when total() is called,
    # so the no-grad context must wrap that call to avoid building autograd
    # state during tracing.
    with torch.no_grad():
        # Per fvcore's documentation, one fused multiply-add is counted as one
        # "flop", so this total is effectively a MAC count.
        macs = analysis.total()

    print(f"MMACs = {macs / 1e6}")
if __name__ == "__main__":
    # Run the analysis only when executed as a script, never on import.
    main()