kevinwang676 committed on
Commit 638bd64
1 Parent(s): 57200a5

Create load_model.py

Files changed (1)
  1. load_model.py +936 -0
load_model.py ADDED
@@ -0,0 +1,936 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import ast
import collections
import contextlib
import inspect
import logging
import os
import re
import time
import traceback
from collections import OrderedDict
from pathlib import Path
from typing import Any, Dict, Optional, Union

import numpy as np
import torch
from fairseq.data import data_utils
from fairseq.dataclass.configs import CheckpointConfig
from fairseq.dataclass.utils import (
    convert_namespace_to_omegaconf,
    overwrite_args_by_name,
)
from fairseq.distributed.fully_sharded_data_parallel import FSDP, has_FSDP
from fairseq.file_io import PathManager
from fairseq.models import FairseqDecoder, FairseqEncoder
from omegaconf import DictConfig, OmegaConf, open_dict

logger = logging.getLogger(__name__)


def save_checkpoint(cfg: CheckpointConfig, trainer, epoch_itr, val_loss):
    from fairseq import meters

    # only one worker should attempt to create the required dir
    if trainer.data_parallel_rank == 0:
        os.makedirs(cfg.save_dir, exist_ok=True)

    prev_best = getattr(save_checkpoint, "best", val_loss)
    if val_loss is not None:
        best_function = max if cfg.maximize_best_checkpoint_metric else min
        save_checkpoint.best = best_function(val_loss, prev_best)

    if cfg.no_save:
        return None

    trainer.consolidate_optimizer()  # TODO(SS): do we need this if no_save_optimizer_state

    if not trainer.should_save_checkpoint_on_current_rank:
        if trainer.always_call_state_dict_during_save_checkpoint:
            trainer.state_dict()
        return None

    write_timer = meters.StopwatchMeter()
    write_timer.start()

    epoch = epoch_itr.epoch
    end_of_epoch = epoch_itr.end_of_epoch()
    updates = trainer.get_num_updates()

    logger.info(f"Preparing to save checkpoint for epoch {epoch} @ {updates} updates")

    def is_better(a, b):
        return a >= b if cfg.maximize_best_checkpoint_metric else a <= b

    suffix = trainer.checkpoint_suffix
    checkpoint_conds = collections.OrderedDict()
    checkpoint_conds["checkpoint{}{}.pt".format(epoch, suffix)] = (
        end_of_epoch and not cfg.no_epoch_checkpoints and epoch % cfg.save_interval == 0
    )
    checkpoint_conds["checkpoint_{}_{}{}.pt".format(epoch, updates, suffix)] = (
        not end_of_epoch
        and cfg.save_interval_updates > 0
        and updates % cfg.save_interval_updates == 0
    )
    checkpoint_conds["checkpoint_best{}.pt".format(suffix)] = val_loss is not None and (
        not hasattr(save_checkpoint, "best")
        or is_better(val_loss, save_checkpoint.best)
    )
    if val_loss is not None and cfg.keep_best_checkpoints > 0:
        worst_best = getattr(save_checkpoint, "best", None)
        chkpts = checkpoint_paths(
            cfg.save_dir,
            pattern=r"checkpoint\.best_{}_(\d+\.?\d*){}\.pt".format(
                cfg.best_checkpoint_metric, suffix
            ),
        )
        if len(chkpts) > 0:
            p = chkpts[-1] if cfg.maximize_best_checkpoint_metric else chkpts[0]
            worst_best = float(p.rsplit("_")[-1].replace("{}.pt".format(suffix), ""))
        # add random digits to resolve ties
        with data_utils.numpy_seed(epoch, updates, val_loss):
            rand_sfx = np.random.randint(0, cfg.keep_best_checkpoints)

        checkpoint_conds[
            "checkpoint.best_{}_{:.3f}{}{}.pt".format(
                cfg.best_checkpoint_metric, val_loss, rand_sfx, suffix
            )
        ] = worst_best is None or is_better(val_loss, worst_best)
    checkpoint_conds[
        "checkpoint_last{}.pt".format(suffix)
    ] = not cfg.no_last_checkpoints

    extra_state = {
        "train_iterator": epoch_itr.state_dict(),
        "val_loss": val_loss,
    }

    # Going forward, different tasks could expose an API like this to dump all
    # the checkpoint worthy attributes in a dictionary which then will be
    # merged with the parent dictionary to create the "extra_state". This
    # allows for an extensible yet simple design to checkpoint task level
    # attributes
    if hasattr(trainer.task, "get_checkpoint_dict"):
        extra_state = {**extra_state, **trainer.task.get_checkpoint_dict()}
        logger.info(f"State of {trainer.task.__class__.__name__} is ready to be persisted with the checkpoint")

    if hasattr(save_checkpoint, "best"):
        extra_state.update({"best": save_checkpoint.best})

    checkpoints = [
        os.path.join(cfg.save_dir, fn) for fn, cond in checkpoint_conds.items() if cond
    ]
    saved_cp = None
    if len(checkpoints) > 0 and trainer.should_save_checkpoint_on_current_rank:
        saved_cp = trainer.save_checkpoint(checkpoints[0], extra_state)
        for cp in checkpoints[1:]:
            if cfg.write_checkpoints_asynchronously:
                # TODO[ioPath]: Need to implement a delayed asynchronous
                # file copying/moving feature.
                logger.warning(
                    f"ioPath is not copying {checkpoints[0]} to {cp} "
                    "since async write mode is on."
                )
            else:
                assert PathManager.copy(
                    checkpoints[0], cp, overwrite=True
                ), f"Failed to copy {checkpoints[0]} to {cp}"

        write_timer.stop()
        logger.info(
            "Saved checkpoint {} (epoch {} @ {} updates, score {}) (writing took {} seconds)".format(
                checkpoints[0], epoch, updates, val_loss, write_timer.sum
            )
        )

    if (
        not end_of_epoch
        and cfg.keep_interval_updates > 0
        and trainer.should_save_checkpoint_on_current_rank
    ):
        # remove old checkpoints; checkpoints are sorted in descending order
        if cfg.keep_interval_updates_pattern == -1:
            checkpoints = checkpoint_paths(
                cfg.save_dir, pattern=r"checkpoint_\d+_(\d+){}\.pt".format(suffix)
            )
        else:
            checkpoints = checkpoint_paths(
                cfg.save_dir,
                pattern=r"checkpoint_\d+_(\d+){}\.pt".format(suffix),
                keep_match=True,
            )
            checkpoints = [
                x[0]
                for x in checkpoints
                if x[1] % cfg.keep_interval_updates_pattern != 0
            ]

        for old_chk in checkpoints[cfg.keep_interval_updates :]:
            if os.path.lexists(old_chk):
                os.remove(old_chk)
            elif PathManager.exists(old_chk):
                PathManager.rm(old_chk)

    if cfg.keep_last_epochs > 0 and trainer.should_save_checkpoint_on_current_rank:
        # remove old epoch checkpoints; checkpoints are sorted in descending order
        checkpoints = checkpoint_paths(
            cfg.save_dir, pattern=r"checkpoint(\d+){}\.pt".format(suffix)
        )
        for old_chk in checkpoints[cfg.keep_last_epochs :]:
            if os.path.lexists(old_chk):
                os.remove(old_chk)
            elif PathManager.exists(old_chk):
                PathManager.rm(old_chk)

    if cfg.keep_best_checkpoints > 0 and trainer.should_save_checkpoint_on_current_rank:
        # only keep the best N checkpoints according to validation metric
        checkpoints = checkpoint_paths(
            cfg.save_dir,
            pattern=r"checkpoint\.best_{}_(\d+\.?\d*){}\.pt".format(
                cfg.best_checkpoint_metric, suffix
            ),
        )
        if not cfg.maximize_best_checkpoint_metric:
            checkpoints = checkpoints[::-1]
        for old_chk in checkpoints[cfg.keep_best_checkpoints :]:
            if os.path.lexists(old_chk):
                os.remove(old_chk)
            elif PathManager.exists(old_chk):
                PathManager.rm(old_chk)

    return saved_cp


def load_checkpoint(cfg: CheckpointConfig, trainer, **passthrough_args):
    """
    Load a checkpoint and restore the training iterator.

    *passthrough_args* will be passed through to
    ``trainer.get_train_iterator``.
    """

    reset_optimizer = cfg.reset_optimizer
    reset_lr_scheduler = cfg.reset_lr_scheduler
    optimizer_overrides = ast.literal_eval(cfg.optimizer_overrides)
    reset_meters = cfg.reset_meters
    reset_dataloader = cfg.reset_dataloader

    if cfg.finetune_from_model is not None and (
        reset_optimizer or reset_lr_scheduler or reset_meters or reset_dataloader
    ):
        raise ValueError(
            "--finetune-from-model can not be set together with either --reset-optimizer"
            " or reset_lr_scheduler or reset_meters or reset_dataloader"
        )

    suffix = trainer.checkpoint_suffix
    if (
        cfg.restore_file == "checkpoint_last.pt"
    ):  # default value of restore_file is 'checkpoint_last.pt'
        checkpoint_path = os.path.join(
            cfg.save_dir, "checkpoint_last{}.pt".format(suffix)
        )
        first_launch = not PathManager.exists(checkpoint_path)
        if first_launch and getattr(cfg, "continue_once", None) is not None:
            checkpoint_path = cfg.continue_once
        elif cfg.finetune_from_model is not None and first_launch:
            # if there is no last checkpoint to restore, start the finetune from pretrained model
            # else just use usual logic to load checkpoint, e.g. restart from last checkpoint and etc.
            if PathManager.exists(cfg.finetune_from_model):
                checkpoint_path = cfg.finetune_from_model
                reset_optimizer = True
                reset_lr_scheduler = True
                reset_meters = True
                reset_dataloader = True
                logger.info(
                    f"loading pretrained model from {checkpoint_path}: "
                    "optimizer, lr scheduler, meters, dataloader will be reset"
                )
            else:
                raise ValueError(
                    f"--finetune-from-model {cfg.finetune_from_model} does not exist"
                )
    elif suffix is not None:
        checkpoint_path = cfg.restore_file.replace(".pt", suffix + ".pt")
    else:
        checkpoint_path = cfg.restore_file

    if cfg.restore_file != "checkpoint_last.pt" and cfg.finetune_from_model:
        raise ValueError(
            "--finetune-from-model and --restore-file (non-default value) "
            "can not be specified together: " + str(cfg)
        )

    extra_state = trainer.load_checkpoint(
        checkpoint_path,
        reset_optimizer,
        reset_lr_scheduler,
        optimizer_overrides,
        reset_meters=reset_meters,
    )

    if (
        extra_state is not None
        and "best" in extra_state
        and not reset_optimizer
        and not reset_meters
    ):
        save_checkpoint.best = extra_state["best"]

    if extra_state is not None and not reset_dataloader:
        # restore iterator from checkpoint
        itr_state = extra_state["train_iterator"]
        epoch_itr = trainer.get_train_iterator(
            epoch=itr_state["epoch"], load_dataset=True, **passthrough_args
        )
        epoch_itr.load_state_dict(itr_state)

        # Preload the checkpoint for the task
        task_cp_dict = extra_state.get(trainer.task.__class__.__name__, {})
        if task_cp_dict and hasattr(trainer.task, "set_checkpoint_dict"):
            trainer.task.set_checkpoint_dict(task_cp_dict)
    else:
        epoch_itr = trainer.get_train_iterator(
            epoch=1, load_dataset=True, **passthrough_args
        )

    trainer.lr_step(epoch_itr.epoch)

    return extra_state, epoch_itr


def load_checkpoint_to_cpu(path, arg_overrides=None, load_on_all_ranks=False):
    """Loads a checkpoint to CPU (with upgrading for backward compatibility).

    If doing single-GPU training or if the checkpoint is only being loaded by at
    most one process on each node (current default behavior is for only rank 0
    to read the checkpoint from disk), load_on_all_ranks should be False to
    avoid errors from torch.distributed not having been initialized or
    torch.distributed.barrier() hanging.

    If all processes on each node may be loading the checkpoint
    simultaneously, load_on_all_ranks should be set to True to avoid I/O
    conflicts.

    There's currently no support for > 1 but < all processes loading the
    checkpoint on each node.
    """
    local_path = PathManager.get_local_path(path)
    # The locally cached file returned by get_local_path() may be stale for
    # remote files that are periodically updated/overwritten (ex:
    # checkpoint_last.pt) - so we remove the local copy, sync across processes
    # (if needed), and then download a fresh copy.
    if local_path != path and PathManager.path_requires_pathmanager(path):
        try:
            os.remove(local_path)
        except FileNotFoundError:
            # With potentially multiple processes removing the same file, the
            # file being missing is benign (missing_ok isn't available until
            # Python 3.8).
            pass
        if load_on_all_ranks:
            torch.distributed.barrier()
        local_path = PathManager.get_local_path(path)

    with open(local_path, "rb") as f:
        state = torch.load(f, map_location=torch.device("cpu"))

    if "args" in state and state["args"] is not None and arg_overrides is not None:
        args = state["args"]
        for arg_name, arg_val in arg_overrides.items():
            setattr(args, arg_name, arg_val)

    if "cfg" in state and state["cfg"] is not None:

        # hack to be able to set Namespace in dict config. this should be removed when we update to newer
        # omegaconf version that supports object flags, or when we migrate all existing models
        from omegaconf import __version__ as oc_version
        from omegaconf import _utils

        if oc_version < "2.2":
            old_primitive = _utils.is_primitive_type
            _utils.is_primitive_type = lambda _: True

            state["cfg"] = OmegaConf.create(state["cfg"])

            _utils.is_primitive_type = old_primitive
            OmegaConf.set_struct(state["cfg"], True)
        else:
            state["cfg"] = OmegaConf.create(state["cfg"], flags={"allow_objects": True})

        if arg_overrides is not None:
            overwrite_args_by_name(state["cfg"], arg_overrides)

    state = _upgrade_state_dict(state)
    return state


def load_model_ensemble(
    filenames,
    arg_overrides: Optional[Dict[str, Any]] = None,
    task=None,
    strict=True,
    suffix="",
    num_shards=1,
    state=None,
):
    """Loads an ensemble of models.

    Args:
        filenames (List[str]): checkpoint files to load
        arg_overrides (Dict[str,Any], optional): override model args that
            were used during model training
        task (fairseq.tasks.FairseqTask, optional): task to use for loading
    """
    assert not (
        strict and num_shards > 1
    ), "Cannot load state dict with strict=True and checkpoint shards > 1"
    ensemble, args, _task = load_model_ensemble_and_task(
        filenames,
        arg_overrides,
        task,
        strict,
        suffix,
        num_shards,
        state,
    )
    return ensemble, args


def get_maybe_sharded_checkpoint_filename(
    filename: str, suffix: str, shard_idx: int, num_shards: int
) -> str:
    orig_filename = filename
    filename = filename.replace(".pt", suffix + ".pt")
    fsdp_filename = filename[:-3] + f"-shard{shard_idx}.pt"
    model_parallel_filename = orig_filename[:-3] + f"_part{shard_idx}.pt"
    if PathManager.exists(fsdp_filename):
        return fsdp_filename
    elif num_shards > 1:
        return model_parallel_filename
    else:
        return filename


def load_model_ensemble_and_task(
    filenames,
    arg_overrides: Optional[Dict[str, Any]] = None,
    task=None,
    strict=True,
    suffix="",
    num_shards=1,
    state=None,
):
    assert state is None or len(filenames) == 1

    from fairseq import tasks

    assert not (
        strict and num_shards > 1
    ), "Cannot load state dict with strict=True and checkpoint shards > 1"
    ensemble = []
    cfg = None
    for filename in filenames:
        orig_filename = filename
        model_shard_state = {"shard_weights": [], "shard_metadata": []}
        assert num_shards > 0
        st = time.time()
        for shard_idx in range(num_shards):
            filename = get_maybe_sharded_checkpoint_filename(
                orig_filename, suffix, shard_idx, num_shards
            )

            if not PathManager.exists(filename):
                raise IOError("Model file not found: {}".format(filename))
            if state is None:
                state = load_checkpoint_to_cpu(filename, arg_overrides)
            if "args" in state and state["args"] is not None:
                cfg = convert_namespace_to_omegaconf(state["args"])
            elif "cfg" in state and state["cfg"] is not None:
                cfg = state["cfg"]
            else:
                raise RuntimeError(
                    f"Neither args nor cfg exist in state keys = {state.keys()}"
                )

            if task is None:
                task = tasks.setup_task(cfg.task, from_checkpoint=True)

            if "task_state" in state:
                task.load_state_dict(state["task_state"])

            argspec = inspect.getfullargspec(task.build_model)

            if "fsdp_metadata" in state and num_shards > 1:
                model_shard_state["shard_weights"].append(state["model"])
                model_shard_state["shard_metadata"].append(state["fsdp_metadata"])
                # check FSDP import before the code goes too far
                if not has_FSDP:
                    raise ImportError(
                        "Cannot find FullyShardedDataParallel. "
                        "Please install fairscale with: pip install fairscale"
                    )
                if shard_idx == num_shards - 1:
                    consolidated_model_state = FSDP.consolidate_shard_weights(
                        shard_weights=model_shard_state["shard_weights"],
                        shard_metadata=model_shard_state["shard_metadata"],
                    )
                    if "from_checkpoint" in argspec.args:
                        model = task.build_model(cfg.model, from_checkpoint=True)
                    else:
                        model = task.build_model(cfg.model)
                    if (
                        "optimizer_history" in state
                        and len(state["optimizer_history"]) > 0
                        and "num_updates" in state["optimizer_history"][-1]
                    ):
                        model.set_num_updates(
                            state["optimizer_history"][-1]["num_updates"]
                        )
                    model.load_state_dict(
                        consolidated_model_state, strict=strict, model_cfg=cfg.model
                    )
            else:
                # model parallel checkpoint or unsharded checkpoint
                # support old external tasks

                if "from_checkpoint" in argspec.args:
                    model = task.build_model(cfg.model, from_checkpoint=True)
                else:
                    model = task.build_model(cfg.model)
                if (
                    "optimizer_history" in state
                    and len(state["optimizer_history"]) > 0
                    and "num_updates" in state["optimizer_history"][-1]
                ):
                    model.set_num_updates(state["optimizer_history"][-1]["num_updates"])
                model.load_state_dict(
                    state["model"], strict=strict, model_cfg=cfg.model
                )

            # reset state so it gets loaded for the next model in ensemble
            state = None
            if shard_idx % 10 == 0 and shard_idx > 0:
                elapsed = time.time() - st
                logger.info(
                    f"Loaded {shard_idx} shards in {elapsed:.2f}s, {elapsed / (shard_idx+1):.2f}s/shard"
                )

        # build model for ensemble
        ensemble.append(model)
    return ensemble, cfg, task


def load_model_ensemble_and_task_from_hf_hub(
    model_id,
    cache_dir: Optional[str] = None,
    arg_overrides: Optional[Dict[str, Any]] = None,
    **kwargs: Any,
):
    try:
        from huggingface_hub import snapshot_download
    except ImportError:
        raise ImportError(
            "You need to install huggingface_hub to use `load_from_hf_hub`. "
            "See https://pypi.org/project/huggingface-hub/ for installation."
        )

    library_name = "fairseq"
    cache_dir = cache_dir or (Path.home() / ".cache" / library_name).as_posix()
    cache_dir = snapshot_download(
        model_id, cache_dir=cache_dir, library_name=library_name, **kwargs
    )

    _arg_overrides = arg_overrides or {}
    _arg_overrides["data"] = cache_dir
    return load_model_ensemble_and_task(
        [p.as_posix() for p in Path(cache_dir).glob("*.pt")],
        arg_overrides=_arg_overrides,
    )


def checkpoint_paths(path, pattern=r"checkpoint(\d+)\.pt", keep_match=False):
    """Retrieves all checkpoints found in `path` directory.

    Checkpoints are identified by matching filename to the specified pattern. If
    the pattern contains groups, the result will be sorted by the first group in
    descending order.
    """
    pt_regexp = re.compile(pattern)
    files = PathManager.ls(path)

    entries = []
    for i, f in enumerate(files):
        m = pt_regexp.fullmatch(f)
        if m is not None:
            idx = float(m.group(1)) if len(m.groups()) > 0 else i
            entries.append((idx, m.group(0)))
    if keep_match:
        return [(os.path.join(path, x[1]), x[0]) for x in sorted(entries, reverse=True)]
    else:
        return [os.path.join(path, x[1]) for x in sorted(entries, reverse=True)]


def torch_persistent_save(obj, filename, async_write: bool = False):
    if async_write:
        with PathManager.opena(filename, "wb") as f:
            _torch_persistent_save(obj, f)
    else:
        if PathManager.supports_rename(filename):
            # do atomic save
            with PathManager.open(filename + ".tmp", "wb") as f:
                _torch_persistent_save(obj, f)
            PathManager.rename(filename + ".tmp", filename)
        else:
            # fallback to non-atomic save
            with PathManager.open(filename, "wb") as f:
                _torch_persistent_save(obj, f)


def _torch_persistent_save(obj, f):
    if isinstance(f, str):
        with PathManager.open(f, "wb") as h:
            torch_persistent_save(obj, h)
        return
    for i in range(3):
        try:
            return torch.save(obj, f)
        except Exception:
            if i == 2:
                logger.error(traceback.format_exc())
                raise
            else:
                time.sleep(2.5)


def _upgrade_state_dict(state):
    """Helper for upgrading old model checkpoints."""

    # add optimizer_history
    if "optimizer_history" not in state:
        state["optimizer_history"] = [
            {"criterion_name": "CrossEntropyCriterion", "best_loss": state["best_loss"]}
        ]
        state["last_optimizer_state"] = state["optimizer"]
        del state["optimizer"]
        del state["best_loss"]
    # move extra_state into sub-dictionary
    if "epoch" in state and "extra_state" not in state:
        state["extra_state"] = {
            "epoch": state["epoch"],
            "batch_offset": state["batch_offset"],
            "val_loss": state["val_loss"],
        }
        del state["epoch"]
        del state["batch_offset"]
        del state["val_loss"]
    # reduce optimizer history's memory usage (only keep the last state)
    if "optimizer" in state["optimizer_history"][-1]:
        state["last_optimizer_state"] = state["optimizer_history"][-1]["optimizer"]
        for optim_hist in state["optimizer_history"]:
            del optim_hist["optimizer"]
    # record the optimizer class name
    if "optimizer_name" not in state["optimizer_history"][-1]:
        state["optimizer_history"][-1]["optimizer_name"] = "FairseqNAG"
    # move best_loss into lr_scheduler_state
    if "lr_scheduler_state" not in state["optimizer_history"][-1]:
        state["optimizer_history"][-1]["lr_scheduler_state"] = {
            "best": state["optimizer_history"][-1]["best_loss"]
        }
        del state["optimizer_history"][-1]["best_loss"]
    # keep track of number of updates
    if "num_updates" not in state["optimizer_history"][-1]:
        state["optimizer_history"][-1]["num_updates"] = 0
    # use stateful training data iterator
    if "train_iterator" not in state["extra_state"]:
        state["extra_state"]["train_iterator"] = {
            "epoch": state["extra_state"].get("epoch", 0),
            "iterations_in_epoch": state["extra_state"].get("batch_offset", 0),
        }

    # backward compatibility, cfg updates
    if "args" in state and state["args"] is not None:
        # old model checkpoints may not have separate source/target positions
        if hasattr(state["args"], "max_positions") and not hasattr(
            state["args"], "max_source_positions"
        ):
            state["args"].max_source_positions = state["args"].max_positions
            state["args"].max_target_positions = state["args"].max_positions
        # default to translation task
        if not hasattr(state["args"], "task"):
            state["args"].task = "translation"
        # --raw-text and --lazy-load are deprecated
        if getattr(state["args"], "raw_text", False):
            state["args"].dataset_impl = "raw"
        elif getattr(state["args"], "lazy_load", False):
            state["args"].dataset_impl = "lazy"
        # epochs start at 1
        if state["extra_state"]["train_iterator"] is not None:
            state["extra_state"]["train_iterator"]["epoch"] = max(
                state["extra_state"]["train_iterator"].get("epoch", 1), 1
            )
        # --remove-bpe ==> --postprocess
        if hasattr(state["args"], "remove_bpe"):
            state["args"].post_process = state["args"].remove_bpe
        # --min-lr ==> --stop-min-lr
        if hasattr(state["args"], "min_lr"):
            state["args"].stop_min_lr = state["args"].min_lr
            del state["args"].min_lr
        # binary_cross_entropy / kd_binary_cross_entropy => wav2vec criterion
        if hasattr(state["args"], "criterion") and state["args"].criterion in [
            "binary_cross_entropy",
            "kd_binary_cross_entropy",
        ]:
            state["args"].criterion = "wav2vec"
        # remove log_keys if it's None (criteria will supply a default value of [])
        if hasattr(state["args"], "log_keys") and state["args"].log_keys is None:
            delattr(state["args"], "log_keys")
        # speech_pretraining => audio pretraining
        if (
            hasattr(state["args"], "task")
            and state["args"].task == "speech_pretraining"
        ):
            state["args"].task = "audio_pretraining"
        # audio_cpc => wav2vec
        if hasattr(state["args"], "arch") and state["args"].arch == "audio_cpc":
            state["args"].arch = "wav2vec"
        # convert legacy float learning rate to List[float]
        if hasattr(state["args"], "lr") and isinstance(state["args"].lr, float):
            state["args"].lr = [state["args"].lr]
        # convert task data arg to a string instead of List[string]
        if (
            hasattr(state["args"], "data")
            and isinstance(state["args"].data, list)
            and len(state["args"].data) > 0
        ):
            state["args"].data = state["args"].data[0]

        state["cfg"] = convert_namespace_to_omegaconf(state["args"])

    if "cfg" in state and state["cfg"] is not None:
        cfg = state["cfg"]
        with open_dict(cfg):
            # any upgrades for Hydra-based configs
            if (
                "task" in cfg
                and "eval_wer_config" in cfg.task
                and isinstance(cfg.task.eval_wer_config.print_alignment, bool)
            ):
                cfg.task.eval_wer_config.print_alignment = "hard"
            if "generation" in cfg and isinstance(cfg.generation.print_alignment, bool):
                cfg.generation.print_alignment = (
                    "hard" if cfg.generation.print_alignment else None
                )
            if (
                "model" in cfg
                and "w2v_args" in cfg.model
                and cfg.model.w2v_args is not None
                and (
                    hasattr(cfg.model.w2v_args, "task") or "task" in cfg.model.w2v_args
                )
                and hasattr(cfg.model.w2v_args.task, "eval_wer_config")
                and cfg.model.w2v_args.task.eval_wer_config is not None
                and isinstance(
                    cfg.model.w2v_args.task.eval_wer_config.print_alignment, bool
                )
            ):
                cfg.model.w2v_args.task.eval_wer_config.print_alignment = "hard"

    return state


def prune_state_dict(state_dict, model_cfg: Optional[DictConfig]):
    """Prune the given state_dict if desired for LayerDrop
    (https://arxiv.org/abs/1909.11556).

    Training with LayerDrop allows models to be robust to pruning at inference
    time. This function prunes state_dict to allow smaller models to be loaded
    from a larger model and re-maps the existing state_dict for this to occur.

    It's called by functions that load models from checkpoints and does not
    need to be called directly.
    """
    arch = None
    if model_cfg is not None:
        arch = (
            model_cfg._name
            if isinstance(model_cfg, DictConfig)
            else getattr(model_cfg, "arch", None)
        )

    if not model_cfg or arch is None or arch == "ptt_transformer":
        # args should not be none, but don't crash if it is.
        return state_dict

    encoder_layers_to_keep = getattr(model_cfg, "encoder_layers_to_keep", None)
    decoder_layers_to_keep = getattr(model_cfg, "decoder_layers_to_keep", None)

    if not encoder_layers_to_keep and not decoder_layers_to_keep:
        return state_dict

    # apply pruning
    logger.info(
        "Pruning model to specified layer configuration - this works best if the model was trained with LayerDrop"
    )

    def create_pruning_pass(layers_to_keep, layer_name):
        keep_layers = sorted(
            int(layer_string) for layer_string in layers_to_keep.split(",")
        )
        mapping_dict = {}
        for i in range(len(keep_layers)):
            mapping_dict[str(keep_layers[i])] = str(i)

        regex = re.compile(r"^{layer}.*\.layers\.(\d+)".format(layer=layer_name))
        return {"substitution_regex": regex, "mapping_dict": mapping_dict}

    pruning_passes = []
    if encoder_layers_to_keep:
        pruning_passes.append(create_pruning_pass(encoder_layers_to_keep, "encoder"))
    if decoder_layers_to_keep:
        pruning_passes.append(create_pruning_pass(decoder_layers_to_keep, "decoder"))

    new_state_dict = {}
    for layer_name in state_dict.keys():
        match = re.search(r"\.layers\.(\d+)\.", layer_name)
        # if layer has no number in it, it is a supporting layer, such as an
        # embedding
        if not match:
            new_state_dict[layer_name] = state_dict[layer_name]
            continue

        # otherwise, layer should be pruned.
        original_layer_number = match.group(1)
        # figure out which mapping dict to replace from
        for pruning_pass in pruning_passes:
            if original_layer_number in pruning_pass["mapping_dict"] and pruning_pass[
                "substitution_regex"
            ].search(layer_name):
                new_layer_number = pruning_pass["mapping_dict"][original_layer_number]
                substitution_match = pruning_pass["substitution_regex"].search(
                    layer_name
                )
                new_state_key = (
                    layer_name[: substitution_match.start(1)]
                    + new_layer_number
                    + layer_name[substitution_match.end(1) :]
                )
                new_state_dict[new_state_key] = state_dict[layer_name]

    # Since layers are now pruned, *_layers_to_keep are no longer needed.
    # This is more of "It would make it work fix" rather than a proper fix.
    if isinstance(model_cfg, DictConfig):
        context = open_dict(model_cfg)
    else:
        context = contextlib.ExitStack()
    with context:
        if hasattr(model_cfg, "encoder_layers_to_keep"):
            model_cfg.encoder_layers_to_keep = None
        if hasattr(model_cfg, "decoder_layers_to_keep"):
            model_cfg.decoder_layers_to_keep = None

    return new_state_dict


def load_pretrained_component_from_model(
    component: Union[FairseqEncoder, FairseqDecoder],
    checkpoint: str,
    strict: bool = True,
):
    """
    Load a pretrained FairseqEncoder or FairseqDecoder from checkpoint into the
    provided `component` object. If state_dict fails to load, there may be a
    mismatch in the architecture of the corresponding `component` found in the
    `checkpoint` file.
    """
    if not PathManager.exists(checkpoint):
        raise IOError("Model file not found: {}".format(checkpoint))
    state = load_checkpoint_to_cpu(checkpoint)
    if isinstance(component, FairseqEncoder):
        component_type = "encoder"
    elif isinstance(component, FairseqDecoder):
        component_type = "decoder"
    else:
        raise ValueError(
            "component to load must be either a FairseqEncoder or "
            "FairseqDecoder. Loading other component types are not supported."
        )
    component_state_dict = OrderedDict()
    for key in state["model"].keys():
        if key.startswith(component_type):
            # encoder.input_layers.0.0.weight --> input_layers.0.0.weight
            component_subkey = key[len(component_type) + 1 :]
            component_state_dict[component_subkey] = state["model"][key]
    component.load_state_dict(component_state_dict, strict=strict)
    return component


def verify_checkpoint_directory(save_dir: str) -> None:
    if not os.path.exists(save_dir):
        os.makedirs(save_dir, exist_ok=True)
    temp_file_path = os.path.join(save_dir, "dummy")
    try:
        with open(temp_file_path, "w"):
            pass
    except OSError as e:
        logger.warning(
            "Unable to access checkpoint save directory: {}".format(save_dir)
        )
        raise e
    else:
        os.remove(temp_file_path)


def save_ema_as_checkpoint(src_path, dst_path):
    state = load_ema_from_checkpoint(src_path)
    torch_persistent_save(state, dst_path)


def load_ema_from_checkpoint(fpath):
    """Loads exponential moving averaged (EMA) checkpoint from input and
    returns a model with ema weights.

    Args:
      fpath: A string path of checkpoint to load from.

    Returns:
      A dict of string keys mapping to various values. The 'model' key
      from the returned dict should correspond to an OrderedDict mapping
      string parameter names to torch Tensors.
    """
    params_dict = collections.OrderedDict()
    new_state = None

    with PathManager.open(fpath, "rb") as f:
        new_state = torch.load(
            f,
            map_location=(
                lambda s, _: torch.serialization.default_restore_location(s, "cpu")
            ),
        )

        # EMA model is stored in a separate "extra state"
        model_params = new_state["extra_state"]["ema"]

        for key in list(model_params.keys()):
            p = model_params[key]
            if isinstance(p, torch.HalfTensor):
                p = p.float()
            if key not in params_dict:
                params_dict[key] = p.clone()
                # NOTE: clone() is needed in case of p is a shared parameter
            else:
                raise ValueError("Key {} is repeated in EMA model params.".format(key))

        if len(params_dict) == 0:
            raise ValueError(
                f"Input checkpoint path '{fpath}' does not contain "
                "ema model weights, is this model trained with EMA?"
            )

    new_state["model"] = params_dict
    return new_state
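
For reference, here is a minimal usage sketch (not part of the committed file) of the ensemble-loading helper defined above. It assumes this module is importable as `load_model` and uses hypothetical paths for the checkpoint and its data directory; the `data` override simply points the task at a local folder containing the dictionaries the checkpoint expects.

from load_model import load_model_ensemble_and_task  # this file, assumed importable

# Hypothetical paths: a local fairseq checkpoint plus its data-bin directory.
models, cfg, task = load_model_ensemble_and_task(
    ["/path/to/checkpoint_best.pt"],
    arg_overrides={"data": "/path/to/data-bin"},
)
model = models[0]
model.eval()
print(type(model).__name__, sum(p.numel() for p in model.parameters()))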
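Similarly, a hedged sketch of loading directly from the Hugging Face Hub via `load_model_ensemble_and_task_from_hf_hub`. The model id below is only illustrative; any Hub repo that actually contains fairseq `*.pt` checkpoints should work, and the `huggingface_hub` package must be installed. Entries in `arg_overrides` override matching fields in the checkpoint's saved config via `load_checkpoint_to_cpu()`.

from load_model import load_model_ensemble_and_task_from_hf_hub  # this file, assumed importable

# Illustrative model id; replace with a repo holding fairseq checkpoints.
models, cfg, task = load_model_ensemble_and_task_from_hf_hub(
    "facebook/fastspeech2-en-ljspeech",
    arg_overrides={"fp16": False},
)
model = models[0]
model.eval()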