File size: 266 Bytes
0b32ad6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 |
import sys
import torch
import matplotlib.pyplot as plt
def normalize_abs(weights):
    """Return ``|weights|`` scaled so that all entries sum to 1.

    The normalization runs over *every* element of the tensor (not per row or
    column), exactly like ``weights.abs() / weights.abs().sum()``.

    Args:
        weights: A tensor of any shape.

    Returns:
        A tensor with the same shape as ``weights``. If every entry is zero
        (or the tensor is empty), the result is all zeros. This replaces the
        NaNs that a naive ``0 / 0`` would produce.
    """
    magnitudes = weights.abs()
    total = magnitudes.sum()
    if total == 0:
        # Avoid 0/0 -> NaN, which would make the saved plot silently empty.
        return torch.zeros_like(magnitudes)
    return magnitudes / total


def plot_classifier_weights(ckpt_path, imgname):
    """Plot the normalized absolute classifier weights stored in a checkpoint.

    Args:
        ckpt_path: Path to a file loadable by ``torch.load``. The loaded object
            must support ``ckpt['Classifier']['weight']``.
        imgname: Output image path, passed to ``savefig``.
    """
    # map_location="cpu" lets checkpoints saved on a GPU load on CPU-only hosts.
    # NOTE: torch.load unpickles the file; only point this at trusted checkpoints.
    ckpt = torch.load(ckpt_path, map_location="cpu")
    weights = ckpt['Classifier']['weight']
    norm = normalize_abs(weights)

    fig, ax = plt.subplots()
    # detach() so the conversion to NumPy also works for tensors tracked by autograd.
    ax.plot(norm.detach().cpu().numpy())
    fig.savefig(imgname)
    plt.close(fig)


def main(argv=None):
    """Command-line entry point: ``<script> CKPT_PATH IMAGE_PATH``.

    Args:
        argv: Argument list excluding the program name; defaults to
            ``sys.argv[1:]``. Extra arguments beyond the first two are ignored.

    Raises:
        SystemExit: If fewer than two arguments are supplied.
    """
    argv = sys.argv[1:] if argv is None else argv
    if len(argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} CKPT_PATH IMAGE_PATH")
    plot_classifier_weights(argv[0], argv[1])


if __name__ == "__main__":
    main()
|