Spaces:
Sleeping
Sleeping
File size: 1,080 Bytes
ed00004 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 |
from lightning_utilities.core.rank_zero import rank_zero_only
@rank_zero_only
def calculate_model_params(model):
    """Count total / trainable / non-trainable parameters and print a summary.

    Decorated with ``rank_zero_only``: in a distributed run this executes (and
    returns the dict) only on global rank 0; on other ranks it returns ``None``.

    Args:
        model: Any object exposing ``parameters()`` yielding tensors with
            ``numel()`` and ``requires_grad`` (e.g. a ``torch.nn.Module``).

    Returns:
        dict: Parameter counts keyed by ``"model/params/total"``,
        ``"model/params/trainable"`` and ``"model/params/non_trainable"``.
    """
    params = {}
    params["model/params/total"] = sum(p.numel() for p in model.parameters())
    params["model/params/trainable"] = sum(
        p.numel() for p in model.parameters() if p.requires_grad
    )
    # Derived arithmetically instead of a third full pass over the parameters.
    params["model/params/non_trainable"] = (
        params["model/params/total"] - params["model/params/trainable"]
    )
    print(f"Total params: {params['model/params/total']/1e6:.2f}M")
    print(f"Trainable params: {params['model/params/trainable']/1e6:.2f}M")
    print(f"Non-trainable params: {params['model/params/non_trainable']/1e6:.2f}M")
    return params
def print_dist(message):
    """Print *message* exactly once in a (possibly distributed) run.

    If ``torch.distributed`` has an initialized process group, only global
    rank 0 prints; otherwise (no group, or a torch build without distributed
    support) every caller prints normally.

    Args:
        message (str): The message to be printed.
    """
    import torch

    dist = torch.distributed
    # is_available() guards builds compiled without distributed support,
    # where calling is_initialized() is not safe.
    if dist.is_available() and dist.is_initialized():
        if dist.get_rank() == 0:
            print(message)
    else:
        print(message)
|