File size: 946 Bytes
a5c5b03 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 |
import numpy as np
import torch
def print_arch(model, model_name='model'):
    """Print the architecture of *model*, then its trainable-parameter count.

    Args:
        model: the object to describe; it is passed straight to ``print`` and
            then to ``num_params``.
        model_name: label prefixed to both printed lines.
    """
    header = f"| {model_name} Arch: "
    print(header, model)
    num_params(model, model_name=model_name)
def num_params(model, print_out=True, model_name="model"):
    """Return the number of trainable parameters of *model*, in millions.

    Only parameters with ``requires_grad=True`` are counted.

    Args:
        model: object exposing ``.parameters()`` (e.g. a ``torch.nn.Module``).
        print_out: if True, also print the count rounded to 3 decimals.
        model_name: label used in the printed line.

    Returns:
        The trainable-parameter count divided by 1,000,000 (``0.0`` when there
        are none).
    """
    n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    parameters = n_trainable / 1_000_000
    if print_out:
        # Pure f-string: mixing an f-string with `%` formatting broke whenever
        # model_name itself contained a '%' character.
        print(f"| {model_name} Trainable Parameters: {parameters:.3f}M")
    return parameters
def get_device_of_model(model):
    """Return the device of *model*.

    The device of the first parameter is used. For a module without parameters,
    the device of its first buffer is used instead.

    Args:
        model: a ``torch.nn.Module`` (needs ``.parameters()`` and ``.buffers()``).

    Returns:
        The ``torch.device`` of the first parameter (or buffer).

    Raises:
        ValueError: if *model* has neither parameters nor buffers (previously a
            bare ``StopIteration`` escaped, which is silently misinterpreted
            inside generators/loops).
    """
    tensor = next(iter(model.parameters()), None)
    if tensor is None:
        tensor = next(iter(model.buffers()), None)
    if tensor is None:
        raise ValueError("get_device_of_model: model has no parameters or buffers")
    return tensor.device
def _set_requires_grad(model, flag):
    """Set ``requires_grad`` to *flag* on *model* (shared by the public helpers).

    - ``torch.nn.Module``: applied to every parameter.
    - ``torch.Tensor`` (or any non-iterable object): the attribute is set directly.
    - Other iterables (e.g. a list of modules, ``model.parameters()``): applied
      recursively to each element.
    """
    if isinstance(model, torch.nn.Module):
        for p in model.parameters():
            p.requires_grad = flag
    elif isinstance(model, (torch.Tensor, str, bytes)) or not hasattr(model, "__iter__"):
        # Tensors are iterable, so they must be caught before the iterable branch;
        # str/bytes are excluded to avoid endless recursion on single characters.
        model.requires_grad = flag
    else:
        for item in model:
            _set_requires_grad(item, flag)


def requires_grad(model):
    """Enable gradients for *model* (module, tensor, or iterable of them)."""
    _set_requires_grad(model, True)


def not_requires_grad(model):
    """Disable gradients for *model* (module, tensor, or iterable of them)."""
    _set_requires_grad(model, False)
|