lmzjms's picture
Upload 1162 files
0b32ad6 verified
import sys
import torch
import matplotlib.pyplot as plt
def normalized_abs_weights(weights):
    """Return the absolute values of *weights* scaled so they sum to 1.

    The tensor is detached from any autograd graph and moved to the CPU so
    the result can be converted with ``.numpy()`` directly.

    Args:
        weights: A ``torch.Tensor`` of any shape.

    Returns:
        A CPU tensor with the same shape as *weights* holding
        ``|weights| / sum(|weights|)``. If every entry is zero the division
        is 0/0 and the result is NaN, as in plain tensor arithmetic.
    """
    # detach(): Tensor.numpy() refuses tensors that require grad.
    magnitudes = weights.detach().abs()
    return (magnitudes / magnitudes.sum()).cpu()


def main(argv=None):
    """Plot the normalized absolute classifier weights stored in a checkpoint.

    Args:
        argv: Argument vector laid out like ``sys.argv`` (program name,
            checkpoint path, output image path). Defaults to ``sys.argv``.

    Raises:
        SystemExit: If fewer than two positional arguments are supplied.
    """
    argv = sys.argv if argv is None else argv
    if len(argv) < 3:
        prog = argv[0] if argv else "script"
        raise SystemExit(f"usage: {prog} CKPT_PATH OUTPUT_IMAGE")
    ckpt_path, imgname = argv[1], argv[2]

    # NOTE(review): torch.load unpickles the file; only run this on
    # checkpoints from a trusted source.
    # map_location="cpu" lets checkpoints saved from a GPU load on machines
    # without one; the data is only needed on the CPU for plotting.
    ckpt = torch.load(ckpt_path, map_location="cpu")
    weights = ckpt['Classifier']['weight']

    norm = normalized_abs_weights(weights)
    # NOTE(review): if the weight is 2-D, pyplot draws one line per column
    # -- confirm that this is the intended view.
    plt.plot(norm.numpy())
    plt.savefig(imgname)


if __name__ == "__main__":
    main()