import os
import torch
import torch.distributed as dist
from torch.nn.utils.rnn import pad_sequence
def collate_fn(batch):
    """Collate variable-length audio samples into a padded batch.

    Assumes each item is a dict of the form
    {'audio': {'array': Tensor, 'fft_mag': Tensor, 'features': ...}, 'label': int},
    with 'array' and 'fft_mag' already torch tensors (e.g. the dataset is
    set to torch format).
    """
    # Extract audio arrays and FFT data from the batch of dictionaries
    audio_arrays = [item['audio']['array'] for item in batch]
    fft_arrays = [item['audio']['fft_mag'] for item in batch]
    # cwt_arrays = [torch.tensor(item['audio']['cwt_mag']) for item in batch]
    features = [item['audio']['features'] for item in batch]
    # features_arr = torch.stack([item['audio']['features_arr'] for item in batch])
    labels = [torch.tensor(item['label']) for item in batch]
    # Pad both sequences to the longest example in the batch
    padded_audio = pad_sequence(audio_arrays, batch_first=True, padding_value=0)
    padded_fft = pad_sequence(fft_arrays, batch_first=True, padding_value=0)
    # padded_features = pad_sequence(features_arr, batch_first=True, padding_value=0)
    # Return a dictionary with the same structure as the input items
    return {
        'audio': {
            'array': padded_audio,
            'fft_mag': padded_fft,
            'features': features,
            # 'features_arr': features_arr,
            # 'cwt_mag': padded_cwt,
        },
        'label': torch.stack(labels),
    }
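
# Minimal usage sketch for collate_fn (illustrative, not part of the original
# script): it builds a tiny synthetic batch matching the dictionary layout
# expected above and pads it through a DataLoader. The tensor lengths and the
# 'rms' feature key are made-up assumptions.
def _demo_collate():
    from torch.utils.data import DataLoader

    # Two variable-length "audio" items, as the dataset would yield them.
    dummy = [
        {'audio': {'array': torch.randn(16000),
                   'fft_mag': torch.randn(8001),
                   'features': {'rms': 0.1}},
         'label': 0},
        {'audio': {'array': torch.randn(12000),
                   'fft_mag': torch.randn(6001),
                   'features': {'rms': 0.2}},
         'label': 1},
    ]
    loader = DataLoader(dummy, batch_size=2, collate_fn=collate_fn)
    batch = next(iter(loader))
    # Both items are padded to the longest sequence in the batch.
    print(batch['audio']['array'].shape)    # torch.Size([2, 16000])
    print(batch['audio']['fft_mag'].shape)  # torch.Size([2, 8001])
    print(batch['label'])                   # tensor([0, 1])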
class Container(object):
    """A container class that can be used to store arbitrary attributes."""

    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)

    def load_dict(self, d):
        # Only fill in attributes that are not already set, so values passed
        # explicitly to __init__ take precedence over the loaded dict.
        for key, value in d.items():
            if getattr(self, key, None) is None:
                setattr(self, key, value)

    def print_attributes(self):
        for key, value in vars(self).items():
            print(f"{key}: {value}")

    def get_dict(self):
        return self.__dict__
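
# Brief usage sketch for Container (the config values are made-up examples):
# load_dict only fills attributes that are still unset, so keyword arguments
# act as overrides for values loaded later from a dict.
def _container_example():
    cfg = Container(lr=1e-3, batch_size=32)
    cfg.load_dict({'lr': 1e-4, 'epochs': 10})  # lr keeps its explicit 1e-3
    cfg.print_attributes()
    # lr: 0.001
    # batch_size: 32
    # epochs: 10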
def setup():
    """Set up the distributed training environment under SLURM.

    Expects WORLD_SIZE, SLURM_PROCID and SLURM_JOBID in the environment
    (MASTER_ADDR and MASTER_PORT must also be set for the NCCL rendezvous).
    """
    world_size = int(os.environ["WORLD_SIZE"])
    rank = int(os.environ["SLURM_PROCID"])
    jobid = int(os.environ["SLURM_JOBID"])
    gpus_per_node = torch.cuda.device_count()
    print('jobid ', jobid)
    print('gpus per node ', gpus_per_node)
    print(f"Hello from rank {rank} of {world_size} where there are"
          f" {gpus_per_node} allocated GPUs per node.", flush=True)
    # Initialize the process group
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    if rank == 0:
        print(f"Group initialized? {dist.is_initialized()}", flush=True)
    # Map the global rank to a GPU index on this node
    local_rank = rank % gpus_per_node
    torch.cuda.set_device(local_rank)
    print(f"rank: {rank}, local_rank: {local_rank}")
    return local_rank, world_size, gpus_per_node
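
# Sketch of how setup() would typically be used in a training entry point
# (the placeholder model and cleanup call are illustrative assumptions, not
# taken from the original script): each SLURM process initializes the group,
# moves its model to its local GPU, and wraps it in DistributedDataParallel.
def _main_sketch():
    from torch.nn.parallel import DistributedDataParallel as DDP

    local_rank, world_size, gpus_per_node = setup()
    model = torch.nn.Linear(10, 2).to(local_rank)  # placeholder model
    ddp_model = DDP(model, device_ids=[local_rank])
    # ... training loop using ddp_model ...
    dist.destroy_process_group()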