ahsanMah committed on
Commit 3f1e960 · 1 Parent(s): be66f33

adding files for building model

dnnlib/__init__.py ADDED
@@ -0,0 +1,8 @@
+ # Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ #
+ # This work is licensed under a Creative Commons
+ # Attribution-NonCommercial-ShareAlike 4.0 International License.
+ # You should have received a copy of the license along with this
+ # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
+
+ from .util import EasyDict, make_cache_dir_path
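
A minimal usage sketch of the two names exported above, `EasyDict` and `make_cache_dir_path` (values are illustrative, not part of the commit):

    import dnnlib

    # Attribute-style access on a plain dict subclass.
    cfg = dnnlib.EasyDict(lr=1e-4, batch_size=64)
    cfg.ema_beta = 0.999              # equivalent to cfg['ema_beta'] = 0.999
    print(cfg.lr, cfg['batch_size'])  # -> 0.0001 64

    # Resolve a per-user cache directory, e.g. ~/.cache/dnnlib/downloads.
    print(dnnlib.make_cache_dir_path('downloads'))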
dnnlib/util.py ADDED
@@ -0,0 +1,485 @@
+ # Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ #
+ # This work is licensed under a Creative Commons
+ # Attribution-NonCommercial-ShareAlike 4.0 International License.
+ # You should have received a copy of the license along with this
+ # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
+
+ """Miscellaneous utility classes and functions."""
+
+ import ctypes
+ import fnmatch
+ import importlib
+ import inspect
+ import numpy as np
+ import os
+ import shutil
+ import sys
+ import types
+ import io
+ import pickle
+ import re
+ import requests
+ import html
+ import hashlib
+ import glob
+ import tempfile
+ import urllib
+ import urllib.parse
+ import uuid
+
+ from typing import Any, Callable, BinaryIO, List, Tuple, Union, Optional
+
+ # Util classes
+ # ------------------------------------------------------------------------------------------
+
+
+ class EasyDict(dict):
+     """Convenience class that behaves like a dict but allows access with the attribute syntax."""
+
+     def __getattr__(self, name: str) -> Any:
+         try:
+             return self[name]
+         except KeyError:
+             raise AttributeError(name)
+
+     def __setattr__(self, name: str, value: Any) -> None:
+         self[name] = value
+
+     def __delattr__(self, name: str) -> None:
+         del self[name]
+
+
+ class Logger(object):
+     """Redirect stderr to stdout, optionally print stdout to a file, and optionally force flushing on both stdout and the file."""
+
+     def __init__(self, file_name: Optional[str] = None, file_mode: str = "w", should_flush: bool = True):
+         self.file = None
+
+         if file_name is not None:
+             self.file = open(file_name, file_mode)
+
+         self.should_flush = should_flush
+         self.stdout = sys.stdout
+         self.stderr = sys.stderr
+
+         sys.stdout = self
+         sys.stderr = self
+
+     def __enter__(self) -> "Logger":
+         return self
+
+     def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
+         self.close()
+
+     def write(self, text: Union[str, bytes]) -> None:
+         """Write text to stdout (and a file) and optionally flush."""
+         if isinstance(text, bytes):
+             text = text.decode()
+         if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash
+             return
+
+         if self.file is not None:
+             self.file.write(text)
+
+         self.stdout.write(text)
+
+         if self.should_flush:
+             self.flush()
+
+     def flush(self) -> None:
+         """Flush written text to both stdout and a file, if open."""
+         if self.file is not None:
+             self.file.flush()
+
+         self.stdout.flush()
+
+     def close(self) -> None:
+         """Flush, close possible files, and remove stdout/stderr mirroring."""
+         self.flush()
+
+         # if using multiple loggers, prevent closing in wrong order
+         if sys.stdout is self:
+             sys.stdout = self.stdout
+         if sys.stderr is self:
+             sys.stderr = self.stderr
+
+         if self.file is not None:
+             self.file.close()
+             self.file = None
+
+
+ # Cache directories
+ # ------------------------------------------------------------------------------------------
+
+ _dnnlib_cache_dir = None
+
+ def set_cache_dir(path: str) -> None:
+     global _dnnlib_cache_dir
+     _dnnlib_cache_dir = path
+
+ def make_cache_dir_path(*paths: str) -> str:
+     if _dnnlib_cache_dir is not None:
+         return os.path.join(_dnnlib_cache_dir, *paths)
+     if 'DNNLIB_CACHE_DIR' in os.environ:
+         return os.path.join(os.environ['DNNLIB_CACHE_DIR'], *paths)
+     if 'HOME' in os.environ:
+         return os.path.join(os.environ['HOME'], '.cache', 'dnnlib', *paths)
+     if 'USERPROFILE' in os.environ:
+         return os.path.join(os.environ['USERPROFILE'], '.cache', 'dnnlib', *paths)
+     return os.path.join(tempfile.gettempdir(), '.cache', 'dnnlib', *paths)
+
+ # Small util functions
+ # ------------------------------------------------------------------------------------------
+
+
+ def format_time(seconds: Union[int, float]) -> str:
+     """Convert the seconds to a human readable string with days, hours, minutes and seconds."""
+     s = int(np.rint(seconds))
+
+     if s < 60:
+         return "{0}s".format(s)
+     elif s < 60 * 60:
+         return "{0}m {1:02}s".format(s // 60, s % 60)
+     elif s < 24 * 60 * 60:
+         return "{0}h {1:02}m {2:02}s".format(s // (60 * 60), (s // 60) % 60, s % 60)
+     else:
+         return "{0}d {1:02}h {2:02}m".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24, (s // 60) % 60)
+
+
+ def format_time_brief(seconds: Union[int, float]) -> str:
+     """Convert the seconds to a brief human readable string with days, hours, minutes and seconds."""
+     s = int(np.rint(seconds))
+
+     if s < 60:
+         return "{0}s".format(s)
+     elif s < 60 * 60:
+         return "{0}m {1:02}s".format(s // 60, s % 60)
+     elif s < 24 * 60 * 60:
+         return "{0}h {1:02}m".format(s // (60 * 60), (s // 60) % 60)
+     else:
+         return "{0}d {1:02}h".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24)
+
+
+ def tuple_product(t: Tuple) -> Any:
+     """Calculate the product of the tuple elements."""
+     result = 1
+
+     for v in t:
+         result *= v
+
+     return result
+
+
+ _str_to_ctype = {
+     "uint8": ctypes.c_ubyte,
+     "uint16": ctypes.c_uint16,
+     "uint32": ctypes.c_uint32,
+     "uint64": ctypes.c_uint64,
+     "int8": ctypes.c_byte,
+     "int16": ctypes.c_int16,
+     "int32": ctypes.c_int32,
+     "int64": ctypes.c_int64,
+     "float32": ctypes.c_float,
+     "float64": ctypes.c_double
+ }
+
+
+ def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]:
+     """Given a type name string (or an object having a __name__ attribute), return matching Numpy and ctypes types that have the same size in bytes."""
+     type_str = None
+
+     if isinstance(type_obj, str):
+         type_str = type_obj
+     elif hasattr(type_obj, "__name__"):
+         type_str = type_obj.__name__
+     elif hasattr(type_obj, "name"):
+         type_str = type_obj.name
+     else:
+         raise RuntimeError("Cannot infer type name from input")
+
+     assert type_str in _str_to_ctype.keys()
+
+     my_dtype = np.dtype(type_str)
+     my_ctype = _str_to_ctype[type_str]
+
+     assert my_dtype.itemsize == ctypes.sizeof(my_ctype)
+
+     return my_dtype, my_ctype
+
+
+ def is_pickleable(obj: Any) -> bool:
+     try:
+         with io.BytesIO() as stream:
+             pickle.dump(obj, stream)
+         return True
+     except:
+         return False
+
+
+ # Functionality to import modules/objects by name, and call functions by name
+ # ------------------------------------------------------------------------------------------
+
+ def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]:
+     """Searches for the underlying module behind the name of some python object.
+     Returns the module and the object name (original name with module part removed)."""
+
+     # allow convenience shorthands, substitute them by full names
+     obj_name = re.sub("^np.", "numpy.", obj_name)
+     obj_name = re.sub("^tf.", "tensorflow.", obj_name)
+
+     # list alternatives for (module_name, local_obj_name)
+     parts = obj_name.split(".")
+     name_pairs = [(".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)]
+
+     # try each alternative in turn
+     for module_name, local_obj_name in name_pairs:
+         try:
+             module = importlib.import_module(module_name) # may raise ImportError
+             get_obj_from_module(module, local_obj_name) # may raise AttributeError
+             return module, local_obj_name
+         except:
+             pass
+
+     # maybe some of the modules themselves contain errors?
+     for module_name, _local_obj_name in name_pairs:
+         try:
+             importlib.import_module(module_name) # may raise ImportError
+         except ImportError:
+             if not str(sys.exc_info()[1]).startswith("No module named '" + module_name + "'"):
+                 raise
+
+     # maybe the requested attribute is missing?
+     for module_name, local_obj_name in name_pairs:
+         try:
+             module = importlib.import_module(module_name) # may raise ImportError
+             get_obj_from_module(module, local_obj_name) # may raise AttributeError
+         except ImportError:
+             pass
+
+     # we are out of luck, but we have no idea why
+     raise ImportError(obj_name)
+
+
+ def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any:
+     """Traverses the object name and returns the last (rightmost) python object."""
+     if obj_name == '':
+         return module
+     obj = module
+     for part in obj_name.split("."):
+         obj = getattr(obj, part)
+     return obj
+
+
+ def get_obj_by_name(name: str) -> Any:
+     """Finds the python object with the given name."""
+     module, obj_name = get_module_from_obj_name(name)
+     return get_obj_from_module(module, obj_name)
+
+
+ def call_func_by_name(*args, func_name: Union[str, Callable], **kwargs) -> Any:
+     """Finds the python object with the given name and calls it as a function."""
+     assert func_name is not None
+     func_obj = get_obj_by_name(func_name) if isinstance(func_name, str) else func_name
+     assert callable(func_obj)
+     return func_obj(*args, **kwargs)
+
+
+ def construct_class_by_name(*args, class_name: Union[str, type], **kwargs) -> Any:
+     """Finds the python class with the given name and constructs it with the given arguments."""
+     return call_func_by_name(*args, func_name=class_name, **kwargs)
+
+
+ def get_module_dir_by_obj_name(obj_name: str) -> str:
+     """Get the directory path of the module containing the given object name."""
+     module, _ = get_module_from_obj_name(obj_name)
+     return os.path.dirname(inspect.getfile(module))
+
+
+ def is_top_level_function(obj: Any) -> bool:
+     """Determine whether the given object is a top-level function, i.e., defined at module scope using 'def'."""
+     return callable(obj) and obj.__name__ in sys.modules[obj.__module__].__dict__
+
+
+ def get_top_level_function_name(obj: Any) -> str:
+     """Return the fully-qualified name of a top-level function."""
+     assert is_top_level_function(obj)
+     module = obj.__module__
+     if module == '__main__':
+         fname = sys.modules[module].__file__
+         assert fname is not None
+         module = os.path.splitext(os.path.basename(fname))[0]
+     return module + "." + obj.__name__
+
+
+ # File system helpers
+ # ------------------------------------------------------------------------------------------
+
+ def list_dir_recursively_with_ignore(dir_path: str, ignores: Optional[List[str]] = None, add_base_to_relative: bool = False) -> List[Tuple[str, str]]:
+     """List all files recursively in a given directory while ignoring given file and directory names.
+     Returns list of tuples containing both absolute and relative paths."""
+     assert os.path.isdir(dir_path)
+     base_name = os.path.basename(os.path.normpath(dir_path))
+
+     if ignores is None:
+         ignores = []
+
+     result = []
+
+     for root, dirs, files in os.walk(dir_path, topdown=True):
+         for ignore_ in ignores:
+             dirs_to_remove = [d for d in dirs if fnmatch.fnmatch(d, ignore_)]
+
+             # dirs need to be edited in-place
+             for d in dirs_to_remove:
+                 dirs.remove(d)
+
+             files = [f for f in files if not fnmatch.fnmatch(f, ignore_)]
+
+         absolute_paths = [os.path.join(root, f) for f in files]
+         relative_paths = [os.path.relpath(p, dir_path) for p in absolute_paths]
+
+         if add_base_to_relative:
+             relative_paths = [os.path.join(base_name, p) for p in relative_paths]
+
+         assert len(absolute_paths) == len(relative_paths)
+         result += zip(absolute_paths, relative_paths)
+
+     return result
+
+
+ def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None:
+     """Takes in a list of tuples of (src, dst) paths and copies files.
+     Will create all necessary directories."""
+     for file in files:
+         target_dir_name = os.path.dirname(file[1])
+
+         # will create all intermediate-level directories
+         os.makedirs(target_dir_name, exist_ok=True)
+         shutil.copyfile(file[0], file[1])
+
+
+ # URL helpers
+ # ------------------------------------------------------------------------------------------
+
+ def is_url(obj: Any, allow_file_urls: bool = False) -> bool:
+     """Determine whether the given object is a valid URL string."""
+     if not isinstance(obj, str) or not "://" in obj:
+         return False
+     if allow_file_urls and obj.startswith('file://'):
+         return True
+     try:
+         res = urllib.parse.urlparse(obj)
+         if not res.scheme or not res.netloc or not "." in res.netloc:
+             return False
+         res = urllib.parse.urlparse(urllib.parse.urljoin(obj, "/"))
+         if not res.scheme or not res.netloc or not "." in res.netloc:
+             return False
+     except:
+         return False
+     return True
+
+ # Note on static typing: a better API would be to split 'open_url' into 'open_url' and
+ # 'download_url' with separate return types (BinaryIO, str). As the `return_filename=True`
+ # case is somewhat uncommon, we just pretend like this function never returns a string
+ # and type-ignore the return value for those cases.
+ def open_url(url: str, cache_dir: Optional[str] = None, num_attempts: int = 10, verbose: bool = True, return_filename: bool = False, cache: bool = True) -> BinaryIO:
+     """Download the given URL and return a binary-mode file object to access the data."""
+     assert num_attempts >= 1
+     assert not (return_filename and (not cache))
+
+     # Doesn't look like a URL scheme, so interpret it as a local filename.
+     if not re.match('^[a-z]+://', url):
+         return url if return_filename else open(url, "rb") # type: ignore
+
+     # Handle file URLs. This code handles unusual file:// patterns that
+     # arise on Windows:
+     #
+     # file:///c:/foo.txt
+     #
+     # which would translate to a local '/c:/foo.txt' filename that's
+     # invalid. Drop the forward slash for such pathnames.
+     #
+     # If you touch this code path, you should test it on both Linux and
+     # Windows.
+     #
+     # Some internet resources suggest using urllib.request.url2pathname()
+     # but that converts forward slashes to backslashes and this causes
+     # its own set of problems.
+     if url.startswith('file://'):
+         filename = urllib.parse.urlparse(url).path
+         if re.match(r'^/[a-zA-Z]:', filename):
+             filename = filename[1:]
+         return filename if return_filename else open(filename, "rb") # type: ignore
+
+     assert is_url(url)
+
+     # Lookup from cache.
+     if cache_dir is None:
+         cache_dir = make_cache_dir_path('downloads')
+
+     url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest()
+     if cache:
+         cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*"))
+         if len(cache_files) == 1:
+             filename = cache_files[0]
+             return filename if return_filename else open(filename, "rb") # type: ignore
+
+     # Download.
+     url_name = None
+     url_data = None
+     with requests.Session() as session:
+         if verbose:
+             print("Downloading %s ..." % url, end="", flush=True)
+         for attempts_left in reversed(range(num_attempts)):
+             try:
+                 with session.get(url) as res:
+                     res.raise_for_status()
+                     if len(res.content) == 0:
+                         raise IOError("No data received")
+
+                     if len(res.content) < 8192:
+                         content_str = res.content.decode("utf-8")
+                         if "download_warning" in res.headers.get("Set-Cookie", ""):
+                             links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link]
+                             if len(links) == 1:
+                                 url = urllib.parse.urljoin(url, links[0])
+                                 raise IOError("Google Drive virus checker nag")
+                         if "Google Drive - Quota exceeded" in content_str:
+                             raise IOError("Google Drive download quota exceeded -- please try again later")
+
+                     match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", ""))
+                     url_name = match[1] if match else url
+                     url_data = res.content
+                     if verbose:
+                         print(" done")
+                     break
+             except KeyboardInterrupt:
+                 raise
+             except:
+                 if not attempts_left:
+                     if verbose:
+                         print(" failed")
+                     raise
+                 if verbose:
+                     print(".", end="", flush=True)
+
+     assert url_data is not None
+
+     # Save to cache.
+     if cache:
+         assert url_name is not None
+         safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name)
+         safe_name = safe_name[:min(len(safe_name), 128)]
+         cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name)
+         temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name)
+         os.makedirs(cache_dir, exist_ok=True)
+         with open(temp_file, "wb") as f:
+             f.write(url_data)
+         os.replace(temp_file, cache_file) # atomic
+         if return_filename:
+             return cache_file # type: ignore
+
+     # Return data as file object.
+     assert not return_filename
+     return io.BytesIO(url_data)
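
A short sketch of the most commonly used helpers in util.py (the local path is a placeholder and assumes the sketch is run from the repository root):

    import dnnlib.util as util

    # get_obj_by_name() resolves dotted names; 'np.' expands to 'numpy.'.
    pi = util.get_obj_by_name('np.pi')

    # construct_class_by_name() resolves a fully-qualified class name and
    # instantiates it with the given arguments.
    counter = util.construct_class_by_name('abracadabra', class_name='collections.Counter')

    # open_url() accepts plain local paths as well as http(s):// URLs;
    # downloads are retried and cached under make_cache_dir_path('downloads').
    with util.open_url('dnnlib/util.py') as f:  # local path, no download
        first_line = f.readline()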
torch_utils/__init__.py ADDED
@@ -0,0 +1,8 @@
+ # Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ #
+ # This work is licensed under a Creative Commons
+ # Attribution-NonCommercial-ShareAlike 4.0 International License.
+ # You should have received a copy of the license along with this
+ # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
+
+ # empty
torch_utils/distributed.py ADDED
@@ -0,0 +1,140 @@
+ # Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ #
+ # This work is licensed under a Creative Commons
+ # Attribution-NonCommercial-ShareAlike 4.0 International License.
+ # You should have received a copy of the license along with this
+ # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
+
+ import os
+ import re
+ import socket
+ import torch
+ import torch.distributed
+ from . import training_stats
+
+ _sync_device = None
+
+ #----------------------------------------------------------------------------
+
+ def init():
+     global _sync_device
+
+     if not torch.distributed.is_initialized():
+         # Set up some reasonable defaults for env-based distributed init if
+         # not set by the running environment.
+         if 'MASTER_ADDR' not in os.environ:
+             os.environ['MASTER_ADDR'] = 'localhost'
+         if 'MASTER_PORT' not in os.environ:
+             s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+             s.bind(('', 0))
+             s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+             os.environ['MASTER_PORT'] = str(s.getsockname()[1])
+             s.close()
+         if 'RANK' not in os.environ:
+             os.environ['RANK'] = '0'
+         if 'LOCAL_RANK' not in os.environ:
+             os.environ['LOCAL_RANK'] = '0'
+         if 'WORLD_SIZE' not in os.environ:
+             os.environ['WORLD_SIZE'] = '1'
+         backend = 'gloo' if os.name == 'nt' else 'nccl'
+         torch.distributed.init_process_group(backend=backend, init_method='env://')
+         torch.cuda.set_device(int(os.environ.get('LOCAL_RANK', '0')))
+
+     _sync_device = torch.device('cuda') if get_world_size() > 1 else None
+     training_stats.init_multiprocessing(rank=get_rank(), sync_device=_sync_device)
+
+ #----------------------------------------------------------------------------
+
+ def get_rank():
+     return torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
+
+ #----------------------------------------------------------------------------
+
+ def get_world_size():
+     return torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1
+
+ #----------------------------------------------------------------------------
+
+ def should_stop():
+     return False
+
+ #----------------------------------------------------------------------------
+
+ def should_suspend():
+     return False
+
+ #----------------------------------------------------------------------------
+
+ def request_suspend():
+     pass
+
+ #----------------------------------------------------------------------------
+
+ def update_progress(cur, total):
+     pass
+
+ #----------------------------------------------------------------------------
+
+ def print0(*args, **kwargs):
+     if get_rank() == 0:
+         print(*args, **kwargs)
+
+ #----------------------------------------------------------------------------
+
+ class CheckpointIO:
+     def __init__(self, **kwargs):
+         self._state_objs = kwargs
+
+     def save(self, pt_path, verbose=True):
+         if verbose:
+             print0(f'Saving {pt_path} ... ', end='', flush=True)
+         data = dict()
+         for name, obj in self._state_objs.items():
+             if obj is None:
+                 data[name] = None
+             elif isinstance(obj, dict):
+                 data[name] = obj
+             elif hasattr(obj, 'state_dict'):
+                 data[name] = obj.state_dict()
+             elif hasattr(obj, '__getstate__'):
+                 data[name] = obj.__getstate__()
+             elif hasattr(obj, '__dict__'):
+                 data[name] = obj.__dict__
+             else:
+                 raise ValueError(f'Invalid state object of type {type(obj).__name__}')
+         if get_rank() == 0:
+             torch.save(data, pt_path)
+         if verbose:
+             print0('done')
+
+     def load(self, pt_path, verbose=True):
+         if verbose:
+             print0(f'Loading {pt_path} ... ', end='', flush=True)
+         data = torch.load(pt_path, map_location=torch.device('cpu'))
+         for name, obj in self._state_objs.items():
+             if obj is None:
+                 pass
+             elif isinstance(obj, dict):
+                 obj.clear()
+                 obj.update(data[name])
+             elif hasattr(obj, 'load_state_dict'):
+                 obj.load_state_dict(data[name])
+             elif hasattr(obj, '__setstate__'):
+                 obj.__setstate__(data[name])
+             elif hasattr(obj, '__dict__'):
+                 obj.__dict__.clear()
+                 obj.__dict__.update(data[name])
+             else:
+                 raise ValueError(f'Invalid state object of type {type(obj).__name__}')
+         if verbose:
+             print0('done')
+
+     def load_latest(self, run_dir, pattern=r'training-state-(\d+).pt', verbose=True):
+         fnames = [entry.name for entry in os.scandir(run_dir) if entry.is_file() and re.fullmatch(pattern, entry.name)]
+         if len(fnames) == 0:
+             return None
+         pt_path = os.path.join(run_dir, max(fnames, key=lambda x: float(re.fullmatch(pattern, x).group(1))))
+         self.load(pt_path, verbose=verbose)
+         return pt_path
+
+ #----------------------------------------------------------------------------
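
A sketch of the intended CheckpointIO round trip (module, optimizer, and paths are illustrative; it assumes torch.load can unpickle the saved dict on your PyTorch version):

    import torch
    from torch_utils.distributed import CheckpointIO

    net = torch.nn.Linear(4, 4)
    opt = torch.optim.Adam(net.parameters())
    state = dict(step=0)

    # Each keyword becomes one entry in the saved dict; objects exposing
    # state_dict()/load_state_dict() are serialized through those hooks,
    # and only rank 0 actually writes the file.
    ckpt = CheckpointIO(net=net, opt=opt, state=state)
    ckpt.save('training-state-0000000.pt')

    # Later, or in a resumed run: restore everything in place, or pick up
    # the highest-numbered 'training-state-*.pt' in a run directory.
    ckpt.load('training-state-0000000.pt')
    # ckpt.load_latest('runs/00000')  # returns None if nothing matches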
torch_utils/misc.py ADDED
@@ -0,0 +1,277 @@
+ # Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ #
+ # This work is licensed under a Creative Commons
+ # Attribution-NonCommercial-ShareAlike 4.0 International License.
+ # You should have received a copy of the license along with this
+ # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
+
+ import re
+ import contextlib
+ import functools
+ import numpy as np
+ import torch
+ import warnings
+ import dnnlib
+
+ #----------------------------------------------------------------------------
+ # Re-seed torch & numpy random generators based on the given arguments.
+
+ def set_random_seed(*args):
+     seed = hash(args) % (1 << 31)
+     torch.manual_seed(seed)
+     np.random.seed(seed)
+
+ #----------------------------------------------------------------------------
+ # Cached construction of constant tensors. Avoids CPU=>GPU copy when the
+ # same constant is used multiple times.
+
+ _constant_cache = dict()
+
+ def constant(value, shape=None, dtype=None, device=None, memory_format=None):
+     value = np.asarray(value)
+     if shape is not None:
+         shape = tuple(shape)
+     if dtype is None:
+         dtype = torch.get_default_dtype()
+     if device is None:
+         device = torch.device('cpu')
+     if memory_format is None:
+         memory_format = torch.contiguous_format
+
+     key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format)
+     tensor = _constant_cache.get(key, None)
+     if tensor is None:
+         tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device)
+         if shape is not None:
+             tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape))
+         tensor = tensor.contiguous(memory_format=memory_format)
+         _constant_cache[key] = tensor
+     return tensor
+
+ #----------------------------------------------------------------------------
+ # Variant of constant() that inherits dtype and device from the given
+ # reference tensor by default.
+
+ def const_like(ref, value, shape=None, dtype=None, device=None, memory_format=None):
+     if dtype is None:
+         dtype = ref.dtype
+     if device is None:
+         device = ref.device
+     return constant(value, shape=shape, dtype=dtype, device=device, memory_format=memory_format)
+
+ #----------------------------------------------------------------------------
+ # Cached construction of temporary tensors in pinned CPU memory.
+
+ @functools.lru_cache(None)
+ def pinned_buf(shape, dtype):
+     return torch.empty(shape, dtype=dtype).pin_memory()
+
+ #----------------------------------------------------------------------------
+ # Symbolic assert.
+
+ try:
+     symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access
+ except AttributeError:
+     symbolic_assert = torch.Assert # 1.7.0
+
+ #----------------------------------------------------------------------------
+ # Context manager to temporarily suppress known warnings in torch.jit.trace().
+ # Note: Cannot use catch_warnings because of https://bugs.python.org/issue29672
+
+ @contextlib.contextmanager
+ def suppress_tracer_warnings():
+     flt = ('ignore', None, torch.jit.TracerWarning, None, 0)
+     warnings.filters.insert(0, flt)
+     yield
+     warnings.filters.remove(flt)
+
+ #----------------------------------------------------------------------------
+ # Assert that the shape of a tensor matches the given list of integers.
+ # None indicates that the size of a dimension is allowed to vary.
+ # Performs symbolic assertion when used in torch.jit.trace().
+
+ def assert_shape(tensor, ref_shape):
+     if tensor.ndim != len(ref_shape):
+         raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}')
+     for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)):
+         if ref_size is None:
+             pass
+         elif isinstance(ref_size, torch.Tensor):
+             with suppress_tracer_warnings(): # as_tensor results are registered as constants
+                 symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}')
+         elif isinstance(size, torch.Tensor):
+             with suppress_tracer_warnings(): # as_tensor results are registered as constants
+                 symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}')
+         elif size != ref_size:
+             raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}')
+
+ #----------------------------------------------------------------------------
+ # Function decorator that calls torch.autograd.profiler.record_function().
+
+ def profiled_function(fn):
+     def decorator(*args, **kwargs):
+         with torch.autograd.profiler.record_function(fn.__name__):
+             return fn(*args, **kwargs)
+     decorator.__name__ = fn.__name__
+     return decorator
+
+ #----------------------------------------------------------------------------
+ # Sampler for torch.utils.data.DataLoader that loops over the dataset
+ # indefinitely, shuffling items as it goes.
+
+ class InfiniteSampler(torch.utils.data.Sampler):
+     def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, start_idx=0):
+         assert len(dataset) > 0
+         assert num_replicas > 0
+         assert 0 <= rank < num_replicas
+         warnings.filterwarnings('ignore', '`data_source` argument is not used and will be removed')
+         super().__init__(dataset)
+         self.dataset_size = len(dataset)
+         self.start_idx = start_idx + rank
+         self.stride = num_replicas
+         self.shuffle = shuffle
+         self.seed = seed
+
+     def __iter__(self):
+         idx = self.start_idx
+         epoch = None
+         while True:
+             if epoch != idx // self.dataset_size:
+                 epoch = idx // self.dataset_size
+                 order = np.arange(self.dataset_size)
+                 if self.shuffle:
+                     np.random.RandomState(hash((self.seed, epoch)) % (1 << 31)).shuffle(order)
+             yield int(order[idx % self.dataset_size])
+             idx += self.stride
+
+ #----------------------------------------------------------------------------
+ # Utilities for operating with torch.nn.Module parameters and buffers.
+
+ def params_and_buffers(module):
+     assert isinstance(module, torch.nn.Module)
+     return list(module.parameters()) + list(module.buffers())
+
+ def named_params_and_buffers(module):
+     assert isinstance(module, torch.nn.Module)
+     return list(module.named_parameters()) + list(module.named_buffers())
+
+ @torch.no_grad()
+ def copy_params_and_buffers(src_module, dst_module, require_all=False):
+     assert isinstance(src_module, torch.nn.Module)
+     assert isinstance(dst_module, torch.nn.Module)
+     src_tensors = dict(named_params_and_buffers(src_module))
+     for name, tensor in named_params_and_buffers(dst_module):
+         assert (name in src_tensors) or (not require_all)
+         if name in src_tensors:
+             tensor.copy_(src_tensors[name])
+
+ #----------------------------------------------------------------------------
+ # Context manager for easily enabling/disabling DistributedDataParallel
+ # synchronization.
+
+ @contextlib.contextmanager
+ def ddp_sync(module, sync):
+     assert isinstance(module, torch.nn.Module)
+     if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel):
+         yield
+     else:
+         with module.no_sync():
+             yield
+
+ #----------------------------------------------------------------------------
+ # Check DistributedDataParallel consistency across processes.
+
+ def check_ddp_consistency(module, ignore_regex=None):
+     assert isinstance(module, torch.nn.Module)
+     for name, tensor in named_params_and_buffers(module):
+         fullname = type(module).__name__ + '.' + name
+         if ignore_regex is not None and re.fullmatch(ignore_regex, fullname):
+             continue
+         tensor = tensor.detach()
+         if tensor.is_floating_point():
+             tensor = torch.nan_to_num(tensor)
+         other = tensor.clone()
+         torch.distributed.broadcast(tensor=other, src=0)
+         assert (tensor == other).all(), fullname
+
+ #----------------------------------------------------------------------------
+ # Print summary table of module hierarchy.
+
+ @torch.no_grad()
+ def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True):
+     assert isinstance(module, torch.nn.Module)
+     assert not isinstance(module, torch.jit.ScriptModule)
+     assert isinstance(inputs, (tuple, list))
+
+     # Register hooks.
+     entries = []
+     nesting = [0]
+     def pre_hook(_mod, _inputs):
+         nesting[0] += 1
+     def post_hook(mod, _inputs, outputs):
+         nesting[0] -= 1
+         if nesting[0] <= max_nesting:
+             outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs]
+             outputs = [t for t in outputs if isinstance(t, torch.Tensor)]
+             entries.append(dnnlib.EasyDict(mod=mod, outputs=outputs))
+     hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()]
+     hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()]
+
+     # Run module.
+     outputs = module(*inputs)
+     for hook in hooks:
+         hook.remove()
+
+     # Identify unique outputs, parameters, and buffers.
+     tensors_seen = set()
+     for e in entries:
+         e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen]
+         e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen]
+         e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen]
+         tensors_seen |= {id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs}
+
+     # Filter out redundant entries.
+     if skip_redundant:
+         entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)]
+
+     # Construct table.
+     rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Output shape', 'Datatype']]
+     rows += [['---'] * len(rows[0])]
+     param_total = 0
+     buffer_total = 0
+     submodule_names = {mod: name for name, mod in module.named_modules()}
+     for e in entries:
+         name = '<top-level>' if e.mod is module else submodule_names[e.mod]
+         param_size = sum(t.numel() for t in e.unique_params)
+         buffer_size = sum(t.numel() for t in e.unique_buffers)
+         output_shapes = [str(list(t.shape)) for t in e.outputs]
+         output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs]
+         rows += [[
+             name + (':0' if len(e.outputs) >= 2 else ''),
+             str(param_size) if param_size else '-',
+             str(buffer_size) if buffer_size else '-',
+             (output_shapes + ['-'])[0],
+             (output_dtypes + ['-'])[0],
+         ]]
+         for idx in range(1, len(e.outputs)):
+             rows += [[name + f':{idx}', '-', '-', output_shapes[idx], output_dtypes[idx]]]
+         param_total += param_size
+         buffer_total += buffer_size
+     rows += [['---'] * len(rows[0])]
+     rows += [['Total', str(param_total), str(buffer_total), '-', '-']]
+
+     # Print table.
+     widths = [max(len(cell) for cell in column) for column in zip(*rows)]
+     print()
+     for row in rows:
+         print('  '.join(cell + ' ' * (width - len(cell)) for cell, width in zip(row, widths)))
+     print()
+
+ #----------------------------------------------------------------------------
+ # Tile a batch of images into a 2D grid.
+
+ def tile_images(x, w, h):
+     assert x.ndim == 4 # NCHW => CHW
+     return x.reshape(h, w, *x.shape[1:]).permute(2, 0, 3, 1, 4).reshape(x.shape[1], h * x.shape[2], w * x.shape[3])
+
+ #----------------------------------------------------------------------------
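
A few of the helpers above in action (shapes and values are illustrative):

    import torch
    from torch_utils import misc

    x = torch.randn(8, 3, 32, 32)
    misc.assert_shape(x, [None, 3, 32, 32])  # None: batch size may vary

    # constant() caches the tensor, so repeated calls with the same value
    # avoid a fresh construction (and CPU=>GPU copy when device='cuda').
    c = misc.constant(0.5, shape=[1, 3, 1, 1])

    # InfiniteSampler loops forever, so the DataLoader never exhausts.
    ds = torch.utils.data.TensorDataset(torch.arange(10))
    loader = torch.utils.data.DataLoader(
        ds, batch_size=4, sampler=misc.InfiniteSampler(ds, shuffle=True, seed=0))
    batch = next(iter(loader))  # can be drawn indefinitely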
torch_utils/persistence.py ADDED
@@ -0,0 +1,257 @@
+ # Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ #
+ # This work is licensed under a Creative Commons
+ # Attribution-NonCommercial-ShareAlike 4.0 International License.
+ # You should have received a copy of the license along with this
+ # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
+
+ """Facilities for pickling Python code alongside other data.
+
+ The pickled code is automatically imported into a separate Python module
+ during unpickling. This way, any previously exported pickles will remain
+ usable even if the original code is no longer available, or if the current
+ version of the code is not consistent with what was originally pickled."""
+
+ import sys
+ import pickle
+ import io
+ import inspect
+ import copy
+ import uuid
+ import types
+ import functools
+ import dnnlib
+
+ #----------------------------------------------------------------------------
+
+ _version            = 6         # internal version number
+ _decorators         = set()     # {decorator_class, ...}
+ _import_hooks       = []        # [hook_function, ...]
+ _module_to_src_dict = dict()    # {module: src, ...}
+ _src_to_module_dict = dict()    # {src: module, ...}
+
+ #----------------------------------------------------------------------------
+
+ def persistent_class(orig_class):
+     r"""Class decorator that extends a given class to save its source code
+     when pickled.
+
+     Example:
+
+         from torch_utils import persistence
+
+         @persistence.persistent_class
+         class MyNetwork(torch.nn.Module):
+             def __init__(self, num_inputs, num_outputs):
+                 super().__init__()
+                 self.fc = MyLayer(num_inputs, num_outputs)
+                 ...
+
+         @persistence.persistent_class
+         class MyLayer(torch.nn.Module):
+             ...
+
+     When pickled, any instance of `MyNetwork` and `MyLayer` will save its
+     source code alongside other internal state (e.g., parameters, buffers,
+     and submodules). This way, any previously exported pickle will remain
+     usable even if the class definitions have been modified or are no
+     longer available.
+
+     The decorator saves the source code of the entire Python module
+     containing the decorated class. It does *not* save the source code of
+     any imported modules. Thus, the imported modules must be available
+     during unpickling, also including `torch_utils.persistence` itself.
+
+     It is ok to call functions defined in the same module from the
+     decorated class. However, if the decorated class depends on other
+     classes defined in the same module, they must be decorated as well.
+     This is illustrated in the above example in the case of `MyLayer`.
+
+     It is also possible to employ the decorator just-in-time before
+     calling the constructor. For example:
+
+         cls = MyLayer
+         if want_to_make_it_persistent:
+             cls = persistence.persistent_class(cls)
+         layer = cls(num_inputs, num_outputs)
+
+     As an additional feature, the decorator also keeps track of the
+     arguments that were used to construct each instance of the decorated
+     class. The arguments can be queried via `obj.init_args` and
+     `obj.init_kwargs`, and they are automatically pickled alongside other
+     object state. This feature can be disabled on a per-instance basis
+     by setting `self._record_init_args = False` in the constructor.
+
+     A typical use case is to first unpickle a previous instance of a
+     persistent class, and then upgrade it to use the latest version of
+     the source code:
+
+         with open('old_pickle.pkl', 'rb') as f:
+             old_net = pickle.load(f)
+         new_net = MyNetwork(*old_net.init_args, **old_net.init_kwargs)
+         misc.copy_params_and_buffers(old_net, new_net, require_all=True)
+     """
+     assert isinstance(orig_class, type)
+     if is_persistent(orig_class):
+         return orig_class
+
+     assert orig_class.__module__ in sys.modules
+     orig_module = sys.modules[orig_class.__module__]
+     orig_module_src = _module_to_src(orig_module)
+
+     @functools.wraps(orig_class, updated=())
+     class Decorator(orig_class):
+         _orig_module_src = orig_module_src
+         _orig_class_name = orig_class.__name__
+
+         def __init__(self, *args, **kwargs):
+             super().__init__(*args, **kwargs)
+             record_init_args = getattr(self, '_record_init_args', True)
+             self._init_args = copy.deepcopy(args) if record_init_args else None
+             self._init_kwargs = copy.deepcopy(kwargs) if record_init_args else None
+             assert orig_class.__name__ in orig_module.__dict__
+             _check_pickleable(self.__reduce__())
+
+         @property
+         def init_args(self):
+             assert self._init_args is not None
+             return copy.deepcopy(self._init_args)
+
+         @property
+         def init_kwargs(self):
+             assert self._init_kwargs is not None
+             return dnnlib.EasyDict(copy.deepcopy(self._init_kwargs))
+
+         def __reduce__(self):
+             fields = list(super().__reduce__())
+             fields += [None] * max(3 - len(fields), 0)
+             if fields[0] is not _reconstruct_persistent_obj:
+                 meta = dict(type='class', version=_version, module_src=self._orig_module_src, class_name=self._orig_class_name, state=fields[2])
+                 fields[0] = _reconstruct_persistent_obj # reconstruct func
+                 fields[1] = (meta,) # reconstruct args
+                 fields[2] = None # state dict
+             return tuple(fields)
+
+     _decorators.add(Decorator)
+     return Decorator
+
+ #----------------------------------------------------------------------------
+
+ def is_persistent(obj):
+     r"""Test whether the given object or class is persistent, i.e.,
+     whether it will save its source code when pickled.
+     """
+     try:
+         if obj in _decorators:
+             return True
+     except TypeError:
+         pass
+     return type(obj) in _decorators # pylint: disable=unidiomatic-typecheck
+
+ #----------------------------------------------------------------------------
+
+ def import_hook(hook):
+     r"""Register an import hook that is called whenever a persistent object
+     is being unpickled. A typical use case is to patch the pickled source
+     code to avoid errors and inconsistencies when the API of some imported
+     module has changed.
+
+     The hook should have the following signature:
+
+         hook(meta) -> modified meta
+
+     `meta` is an instance of `dnnlib.EasyDict` with the following fields:
+
+         type:       Type of the persistent object, e.g. `'class'`.
+         version:    Internal version number of `torch_utils.persistence`.
+         module_src: Original source code of the Python module.
+         class_name: Class name in the original Python module.
+         state:      Internal state of the object.
+
+     Example:
+
+         @persistence.import_hook
+         def wreck_my_network(meta):
+             if meta.class_name == 'MyNetwork':
+                 print('MyNetwork is being imported. I will wreck it!')
+                 meta.module_src = meta.module_src.replace("True", "False")
+             return meta
+     """
+     assert callable(hook)
+     _import_hooks.append(hook)
+
+ #----------------------------------------------------------------------------
+
+ def _reconstruct_persistent_obj(meta):
+     r"""Hook that is called internally by the `pickle` module to unpickle
+     a persistent object.
+     """
+     meta = dnnlib.EasyDict(meta)
+     meta.state = dnnlib.EasyDict(meta.state)
+     for hook in _import_hooks:
+         meta = hook(meta)
+         assert meta is not None
+
+     assert meta.version == _version
+     module = _src_to_module(meta.module_src)
+
+     assert meta.type == 'class'
+     orig_class = module.__dict__[meta.class_name]
+     decorator_class = persistent_class(orig_class)
+     obj = decorator_class.__new__(decorator_class)
+
+     setstate = getattr(obj, '__setstate__', None)
+     if callable(setstate):
+         setstate(meta.state) # pylint: disable=not-callable
+     else:
+         obj.__dict__.update(meta.state)
+     return obj
+
+ #----------------------------------------------------------------------------
+
+ def _module_to_src(module):
+     r"""Query the source code of a given Python module.
+     """
+     src = _module_to_src_dict.get(module, None)
+     if src is None:
+         src = inspect.getsource(module)
+         _module_to_src_dict[module] = src
+         _src_to_module_dict[src] = module
+     return src
+
+ def _src_to_module(src):
+     r"""Get or create a Python module for the given source code.
+     """
+     module = _src_to_module_dict.get(src, None)
+     if module is None:
+         module_name = "_imported_module_" + uuid.uuid4().hex
+         module = types.ModuleType(module_name)
+         sys.modules[module_name] = module
+         _module_to_src_dict[module] = src
+         _src_to_module_dict[src] = module
+         exec(src, module.__dict__) # pylint: disable=exec-used
+     return module
+
+ #----------------------------------------------------------------------------
+
+ def _check_pickleable(obj):
+     r"""Check that the given object is pickleable, raising an exception if
+     it is not. This function is expected to be considerably more efficient
+     than actually pickling the object.
+     """
+     def recurse(obj):
+         if isinstance(obj, (list, tuple, set)):
+             return [recurse(x) for x in obj]
+         if isinstance(obj, dict):
+             return [[recurse(x), recurse(y)] for x, y in obj.items()]
+         if isinstance(obj, (str, int, float, bool, bytes, bytearray)):
+             return None # Python primitive types are pickleable.
+         if f'{type(obj).__module__}.{type(obj).__name__}' in ['numpy.ndarray', 'torch.Tensor', 'torch.nn.parameter.Parameter']:
+             return None # NumPy arrays and PyTorch tensors are pickleable.
+         if is_persistent(obj):
+             return None # Persistent objects are pickleable, by virtue of the constructor check.
+         return obj
+     with io.BytesIO() as f:
+         pickle.dump(recurse(obj), f)
+
+ #----------------------------------------------------------------------------
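
A same-process round trip showing what the decorator buys (run it from a file, not a REPL, since the decorator captures the module's source with inspect.getsource(); the class is illustrative):

    import pickle
    import torch
    from torch_utils import persistence

    @persistence.persistent_class
    class MyNetwork(torch.nn.Module):
        def __init__(self, num_inputs, num_outputs):
            super().__init__()
            self.fc = torch.nn.Linear(num_inputs, num_outputs)

    net = MyNetwork(8, 4)
    data = pickle.dumps(net)    # module source travels inside the pickle
    clone = pickle.loads(data)  # rebuilt via _reconstruct_persistent_obj()
    print(clone.init_args, clone.init_kwargs)  # (8, 4) {}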
torch_utils/training_stats.py ADDED
@@ -0,0 +1,283 @@
+ # Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ #
+ # This work is licensed under a Creative Commons
+ # Attribution-NonCommercial-ShareAlike 4.0 International License.
+ # You should have received a copy of the license along with this
+ # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
+
+ """Facilities for reporting and collecting training statistics across
+ multiple processes and devices. The interface is designed to minimize
+ synchronization overhead as well as the amount of boilerplate in user
+ code."""
+
+ import re
+ import numpy as np
+ import torch
+ import dnnlib
+
+ from . import misc
+
+ #----------------------------------------------------------------------------
+
+ _num_moments    = 3             # [num_scalars, sum_of_scalars, sum_of_squares]
+ _reduce_dtype   = torch.float32 # Data type to use for initial per-tensor reduction.
+ _counter_dtype  = torch.float64 # Data type to use for the internal counters.
+ _rank           = 0             # Rank of the current process.
+ _sync_device    = None          # Device to use for multiprocess communication. None = single-process.
+ _sync_called    = False         # Has _sync() been called yet?
+ _counters       = dict()        # Running counters on each device, updated by report(): name => device => torch.Tensor
+ _cumulative     = dict()        # Cumulative counters on the CPU, updated by _sync(): name => torch.Tensor
+
+ #----------------------------------------------------------------------------
+
+ def init_multiprocessing(rank, sync_device):
+     r"""Initializes `torch_utils.training_stats` for collecting statistics
+     across multiple processes.
+
+     This function must be called after
+     `torch.distributed.init_process_group()` and before `Collector.update()`.
+     The call is not necessary if multi-process collection is not needed.
+
+     Args:
+         rank:        Rank of the current process.
+         sync_device: PyTorch device to use for inter-process
+                      communication, or None to disable multi-process
+                      collection. Typically `torch.device('cuda', rank)`.
+     """
+     global _rank, _sync_device
+     assert not _sync_called
+     _rank = rank
+     _sync_device = sync_device
+
+ #----------------------------------------------------------------------------
+
+ @misc.profiled_function
+ def report(name, value):
+     r"""Broadcasts the given set of scalars to all interested instances of
+     `Collector`, across device and process boundaries. NaNs and Infs are
+     ignored.
+
+     This function is expected to be extremely cheap and can be safely
+     called from anywhere in the training loop, loss function, or inside a
+     `torch.nn.Module`.
+
+     Warning: The current implementation expects the set of unique names to
+     be consistent across processes. Please make sure that `report()` is
+     called at least once for each unique name by each process, and in the
+     same order. If a given process has no scalars to broadcast, it can do
+     `report(name, [])` (empty list).
+
+     Args:
+         name:  Arbitrary string specifying the name of the statistic.
+                Averages are accumulated separately for each unique name.
+         value: Arbitrary set of scalars. Can be a list, tuple,
+                NumPy array, PyTorch tensor, or Python scalar.
+
+     Returns:
+         The same `value` that was passed in.
+     """
+     if name not in _counters:
+         _counters[name] = dict()
+
+     elems = torch.as_tensor(value)
+     if elems.numel() == 0:
+         return value
+
+     elems = elems.detach().flatten().to(_reduce_dtype)
+     square = elems.square()
+     finite = square.isfinite()
+     moments = torch.stack([
+         finite.sum(dtype=_reduce_dtype),
+         torch.where(finite, elems, 0).sum(),
+         torch.where(finite, square, 0).sum(),
+     ])
+     assert moments.ndim == 1 and moments.shape[0] == _num_moments
+     moments = moments.to(_counter_dtype)
+
+     device = moments.device
+     if device not in _counters[name]:
+         _counters[name][device] = torch.zeros_like(moments)
+     _counters[name][device].add_(moments)
+     return value
+
+ #----------------------------------------------------------------------------
+
+ def report0(name, value):
+     r"""Broadcasts the given set of scalars by the first process (`rank = 0`),
+     but ignores any scalars provided by the other processes.
+     See `report()` for further details.
+     """
+     report(name, value if _rank == 0 else [])
+     return value
+
+ #----------------------------------------------------------------------------
+
+ class Collector:
+     r"""Collects the scalars broadcasted by `report()` and `report0()` and
+     computes their long-term averages (mean and standard deviation) over
+     user-defined periods of time.
+
+     The averages are first collected into internal counters that are not
+     directly visible to the user. They are then copied to the user-visible
+     state as a result of calling `update()` and can then be queried using
+     `mean()`, `std()`, `as_dict()`, etc. Calling `update()` also resets the
+     internal counters for the next round, so that the user-visible state
+     effectively reflects averages collected between the last two calls to
+     `update()`.
+
+     Args:
+         regex:         Regular expression defining which statistics to
+                        collect. The default is to collect everything.
+         keep_previous: Whether to retain the previous averages if no
+                        scalars were collected on a given round
+                        (default: False).
+     """
+     def __init__(self, regex='.*', keep_previous=False):
+         self._regex = re.compile(regex)
+         self._keep_previous = keep_previous
+         self._cumulative = dict()
+         self._moments = dict()
+         self.update()
+         self._moments.clear()
+
+     def names(self):
+         r"""Returns the names of all statistics broadcasted so far that
+         match the regular expression specified at construction time.
+         """
+         return [name for name in _counters if self._regex.fullmatch(name)]
+
+     def update(self):
+         r"""Copies current values of the internal counters to the
+         user-visible state and resets them for the next round.
+
+         If `keep_previous=True` was specified at construction time, the
+         operation is skipped for statistics that have received no scalars
+         since the last update, retaining their previous averages.
+
+         This method performs a number of GPU-to-CPU transfers and one
+         `torch.distributed.all_reduce()`. It is intended to be called
+         periodically in the main training loop, typically once every
+         N training steps.
+         """
+         if not self._keep_previous:
+             self._moments.clear()
+         for name, cumulative in _sync(self.names()):
+             if name not in self._cumulative:
+                 self._cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
+             delta = cumulative - self._cumulative[name]
+             self._cumulative[name].copy_(cumulative)
+             if float(delta[0]) != 0:
+                 self._moments[name] = delta
+
+     def _get_delta(self, name):
+         r"""Returns the raw moments that were accumulated for the given
+         statistic between the last two calls to `update()`, or zero if
+         no scalars were collected.
+         """
+         assert self._regex.fullmatch(name)
+         if name not in self._moments:
+             self._moments[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
+         return self._moments[name]
+
+     def num(self, name):
+         r"""Returns the number of scalars that were accumulated for the given
+         statistic between the last two calls to `update()`, or zero if
+         no scalars were collected.
+         """
+         delta = self._get_delta(name)
+         return int(delta[0])
+
+     def mean(self, name):
+         r"""Returns the mean of the scalars that were accumulated for the
+         given statistic between the last two calls to `update()`, or NaN if
+         no scalars were collected.
+         """
+         delta = self._get_delta(name)
+         if int(delta[0]) == 0:
+             return float('nan')
+         return float(delta[1] / delta[0])
+
+     def std(self, name):
+         r"""Returns the standard deviation of the scalars that were
+         accumulated for the given statistic between the last two calls to
+         `update()`, or NaN if no scalars were collected.
+         """
+         delta = self._get_delta(name)
+         if int(delta[0]) == 0 or not np.isfinite(float(delta[1])):
+             return float('nan')
+         if int(delta[0]) == 1:
+             return float(0)
+         mean = float(delta[1] / delta[0])
+         raw_var = float(delta[2] / delta[0])
+         return np.sqrt(max(raw_var - np.square(mean), 0))
+
+     def as_dict(self):
+         r"""Returns the averages accumulated between the last two calls to
+         `update()` as a `dnnlib.EasyDict`. The contents are as follows:
+
+             dnnlib.EasyDict(
+                 NAME = dnnlib.EasyDict(num=FLOAT, mean=FLOAT, std=FLOAT),
+                 ...
+             )
+         """
+         stats = dnnlib.EasyDict()
+         for name in self.names():
+             stats[name] = dnnlib.EasyDict(num=self.num(name), mean=self.mean(name), std=self.std(name))
+         return stats
+
+     def __getitem__(self, name):
+         r"""Convenience getter.
+         `collector[name]` is a synonym for `collector.mean(name)`.
+         """
+         return self.mean(name)
+
+ #----------------------------------------------------------------------------
+
+ def _sync(names):
+     r"""Synchronize the global cumulative counters across devices and
+     processes. Called internally by `Collector.update()`.
+     """
+     if len(names) == 0:
+         return []
+     global _sync_called
+     _sync_called = True
+
+     # Check that all ranks have the same set of names.
+     if _sync_device is not None:
+         value = hash(tuple(tuple(ord(char) for char in name) for name in names))
+         other = torch.as_tensor(value, dtype=torch.int64, device=_sync_device)
+         torch.distributed.broadcast(tensor=other, src=0)
+         if value != int(other.cpu()):
+             raise ValueError('Training statistics are inconsistent between ranks')
+
+     # Collect deltas within current rank.
+     deltas = []
+     device = _sync_device if _sync_device is not None else torch.device('cpu')
+     for name in names:
+         delta = torch.zeros([_num_moments], dtype=_counter_dtype, device=device)
+         for counter in _counters[name].values():
+             delta.add_(counter.to(device))
+             counter.copy_(torch.zeros_like(counter))
+         deltas.append(delta)
+     deltas = torch.stack(deltas)
+
+     # Sum deltas across ranks.
+     if _sync_device is not None:
+         torch.distributed.all_reduce(deltas)
+
+     # Update cumulative values.
+     deltas = deltas.cpu()
+     for idx, name in enumerate(names):
+         if name not in _cumulative:
+             _cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
+         _cumulative[name].add_(deltas[idx])
+
+     # Return name-value pairs.
+     return [(name, _cumulative[name]) for name in names]
+
+ #----------------------------------------------------------------------------
+ # Convenience.
+
+ default_collector = Collector()
+
+ #----------------------------------------------------------------------------
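
A single-process sketch of the report()/Collector workflow (names, loop, and values are illustrative):

    import torch
    from torch_utils import training_stats

    collector = training_stats.Collector(regex='Loss/.*')

    for step in range(100):
        loss = torch.rand([16])                   # stand-in for real losses
        training_stats.report('Loss/loss', loss)  # cheap, no synchronization

        if step % 10 == 9:
            collector.update()  # one all_reduce() in multi-process runs
            print(f'step {step}: loss = {collector.mean("Loss/loss"):.4f}')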