# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is licensed under a Creative Commons
# Attribution-NonCommercial-ShareAlike 4.0 International License.
# You should have received a copy of the license along with this
# work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
import os
import re
import socket
import torch
import torch.distributed
from . import training_stats
_sync_device = None
#----------------------------------------------------------------------------
def init():
    global _sync_device
    if not torch.distributed.is_initialized():
        # Set up reasonable defaults for env:// initialization if they are
        # not provided by the launching environment, so that single-process
        # runs work without torchrun.
        if 'MASTER_ADDR' not in os.environ:
            os.environ['MASTER_ADDR'] = 'localhost'
        if 'MASTER_PORT' not in os.environ:
            # Ask the OS for a free TCP port. SO_REUSEADDR must be set
            # before bind() to take effect.
            s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            s.bind(('', 0))
            os.environ['MASTER_PORT'] = str(s.getsockname()[1])
            s.close()
        if 'RANK' not in os.environ:
            os.environ['RANK'] = '0'
        if 'LOCAL_RANK' not in os.environ:
            os.environ['LOCAL_RANK'] = '0'
        if 'WORLD_SIZE' not in os.environ:
            os.environ['WORLD_SIZE'] = '1'
        # NCCL for CUDA; fall back to Gloo on Windows, where NCCL is unavailable.
        backend = 'gloo' if os.name == 'nt' else 'nccl'
        torch.distributed.init_process_group(backend=backend, init_method='env://')
        torch.cuda.set_device(int(os.environ.get('LOCAL_RANK', '0')))
    # Synchronize training_stats across ranks only when running multiple processes.
    _sync_device = torch.device('cuda') if get_world_size() > 1 else None
    training_stats.init_multiprocessing(rank=get_rank(), sync_device=_sync_device)
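# Usage sketch (illustrative, not part of this module): a training script
# typically calls init() once at startup, before using any other helper in
# this file. The script name and import path below are assumptions.
#
#     # Single process:  python train.py
#     # Multi GPU:       torchrun --nproc_per_node=8 train.py
#     from torch_utils import distributed as dist
#     dist.init()
#     dist.print0(f'Running on {dist.get_world_size()} process(es), rank {dist.get_rank()}.')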
#----------------------------------------------------------------------------
def get_rank():
    return torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
#----------------------------------------------------------------------------
def get_world_size():
    return torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1
#----------------------------------------------------------------------------
def should_stop():
    # Job-control hook: return True to make the caller stop training early.
    # This reference implementation never requests a stop.
    return False
#----------------------------------------------------------------------------
def should_suspend():
    # Job-control hook: return True to make the caller suspend and checkpoint.
    # Always False here.
    return False
#----------------------------------------------------------------------------
def request_suspend():
    # Job-control hook for explicitly requesting a suspend; no-op here.
    pass
#----------------------------------------------------------------------------
def update_progress(cur, total):
    # Job-control hook for reporting progress (cur out of total) to an
    # external scheduler; no-op here.
    pass
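# Usage sketch (illustrative): how a training loop might consult these
# hooks; the loop variables are assumptions.
#
#     while not should_stop():
#         ...run one training iteration...
#         update_progress(cur_nimg, total_nimg)
#         if should_suspend():
#             ...save a checkpoint and exit...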
#----------------------------------------------------------------------------
def print0(*args, **kwargs):
    # Print only on rank 0 so multi-process runs do not duplicate console output.
    if get_rank() == 0:
        print(*args, **kwargs)
#----------------------------------------------------------------------------
class CheckpointIO:
    # Bundles a set of named state objects (modules, optimizers, plain dicts,
    # or anything picklable) and saves/restores them as a single .pt file.
    def __init__(self, **kwargs):
        self._state_objs = kwargs

    def save(self, pt_path, verbose=True):
        if verbose:
            print0(f'Saving {pt_path} ... ', end='', flush=True)
        # Collect the state of each object via the most specific mechanism it supports.
        data = dict()
        for name, obj in self._state_objs.items():
            if obj is None:
                data[name] = None
            elif isinstance(obj, dict):
                data[name] = obj
            elif hasattr(obj, 'state_dict'):
                data[name] = obj.state_dict()
            elif hasattr(obj, '__getstate__'):
                data[name] = obj.__getstate__()
            elif hasattr(obj, '__dict__'):
                data[name] = obj.__dict__
            else:
                raise ValueError(f'Invalid state object of type {type(obj).__name__}')
        # Only rank 0 writes the file, avoiding redundant or conflicting
        # writes in multi-process runs.
        if get_rank() == 0:
            torch.save(data, pt_path)
        if verbose:
            print0('done')
    def load(self, pt_path, verbose=True):
        if verbose:
            print0(f'Loading {pt_path} ... ', end='', flush=True)
        data = torch.load(pt_path, map_location=torch.device('cpu'))
        # Restore each object using the counterpart of the mechanism used in save().
        for name, obj in self._state_objs.items():
            if obj is None:
                pass
            elif isinstance(obj, dict):
                obj.clear()
                obj.update(data[name])
            elif hasattr(obj, 'load_state_dict'):
                obj.load_state_dict(data[name])
            elif hasattr(obj, '__setstate__'):
                obj.__setstate__(data[name])
            elif hasattr(obj, '__dict__'):
                obj.__dict__.clear()
                obj.__dict__.update(data[name])
            else:
                raise ValueError(f'Invalid state object of type {type(obj).__name__}')
        if verbose:
            print0('done')
    def load_latest(self, run_dir, pattern=r'training-state-(\d+)\.pt', verbose=True):
        # Find the checkpoint with the highest iteration number and load it.
        # Returns the loaded path, or None if no matching checkpoint exists.
        fnames = [entry.name for entry in os.scandir(run_dir) if entry.is_file() and re.fullmatch(pattern, entry.name)]
        if len(fnames) == 0:
            return None
        pt_path = os.path.join(run_dir, max(fnames, key=lambda x: int(re.fullmatch(pattern, x).group(1))))
        self.load(pt_path, verbose=verbose)
        return pt_path
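# Usage sketch (illustrative): bundling a network and optimizer into a
# CheckpointIO, resuming from the newest checkpoint if one exists, and
# saving periodically. The object names and run directory are assumptions.
#
#     ckpt = CheckpointIO(net=net, optimizer=optimizer, state=dict(cur_nimg=0))
#     ckpt.load_latest('runs/00000-example')  # returns None if nothing saved yet
#     ...
#     ckpt.save(f'runs/00000-example/training-state-{cur_nimg:08d}.pt')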
#----------------------------------------------------------------------------