File size: 2,406 Bytes
01bb3bb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
from __future__ import print_function, absolute_import
import json
import os
import sys
# import moxing as mox
import os.path as osp
import shutil

import torch
from torch.nn import Parameter

from .osutils import mkdir_if_missing

from config import get_args
# Parse the command line once, at import time: every helper below reads
# ``global_args.run_on_remote`` to choose between local file access and the
# ``moxing`` (mox) file API. Note that importing this module therefore
# consumes ``sys.argv``.
global_args = get_args(sys.argv[1:])

if global_args.run_on_remote:
  # Imported only for remote runs, so local runs do not depend on moxing.
  import moxing as mox


def read_json(fpath):
  """Parse the JSON file at ``fpath`` and return the decoded object."""
  with open(fpath, 'r') as handle:
    return json.load(handle)


def write_json(obj, fpath):
  """Serialize ``obj`` as pretty-printed JSON to ``fpath``.

  The parent directory of ``fpath`` is created first (via
  ``mkdir_if_missing``) when ``fpath`` has one. Output uses a 4-space indent
  and ``(',', ': ')`` separators, so no trailing spaces follow commas.

  Args:
    obj: any object accepted by ``json.dump``.
    fpath: destination file path; an existing file is overwritten.
  """
  dir_name = osp.dirname(fpath)
  # A bare filename such as 'out.json' has an empty dirname. The current
  # directory already exists, so don't ask the helper to create ''.
  if dir_name:
    mkdir_if_missing(dir_name)
  with open(fpath, 'w') as f:
    json.dump(obj, f, indent=4, separators=(',', ': '))


def save_checkpoint(state, is_best, fpath='checkpoint.pth.tar'):
  """Save ``state`` with ``torch.save`` and optionally keep a best-model copy.

  Args:
    state: object handed to ``torch.save`` (the structure is not checked
      here).
    is_best: when true, also write ``model_best.pth.tar`` in the same
      directory as ``fpath``.
    fpath: destination path. When ``global_args.run_on_remote`` is set it is
      written through ``mox.file``; otherwise it is a local filesystem path.
  """
  print('=> saving checkpoint ', fpath)
  dir_name = osp.dirname(fpath)
  best_path = osp.join(dir_name, 'model_best.pth.tar')
  if global_args.run_on_remote:
    # An empty dir_name (bare filename) needs no directory creation.
    if dir_name and not mox.file.exists(dir_name):
      mox.file.make_dirs(dir_name)
      print('=> making dir ', dir_name)
    # Serialize to a local scratch file first, then copy it to the
    # destination (and the best-model path) via mox.
    local_path = "local_checkpoint.pth.tar"
    torch.save(state, local_path)
    mox.file.copy(local_path, fpath)
    if is_best:
      mox.file.copy(local_path, best_path)
  else:
    # With the default fpath the dirname is ''. The current directory already
    # exists, so skip the helper instead of asking it to create ''.
    if dir_name:
      mkdir_if_missing(dir_name)
    torch.save(state, fpath)
    if is_best:
      shutil.copy(fpath, best_path)


def load_checkpoint(fpath, map_location=None):
  """Load and return a checkpoint written with ``torch.save``.

  Args:
    fpath: checkpoint path. On remote runs (``global_args.run_on_remote``)
      it is read after ``mox.file.shift('os', 'mox')``. On local runs it must
      be an existing file.
    map_location: forwarded to ``torch.load``. The default ``None`` keeps
      the previous behavior; pass e.g. ``'cpu'`` to load a checkpoint that
      was saved from GPU tensors on a machine without a GPU.

  Returns:
    Whatever ``torch.load`` returns for ``fpath``.

  Raises:
    ValueError: on local runs, when ``fpath`` is not an existing file.
  """
  if global_args.run_on_remote:
    # NOTE(review): shift('os', 'mox') appears to route file access through
    # mox so torch.load can read remote paths; it is never shifted back here.
    mox.file.shift('os', 'mox')
    checkpoint = torch.load(fpath, map_location=map_location)
    print("=> Loaded checkpoint '{}'".format(fpath))
    return checkpoint

  if not osp.isfile(fpath):
    raise ValueError("=> No checkpoint found at '{}'".format(fpath))
  checkpoint = torch.load(fpath, map_location=map_location)
  print("=> Loaded checkpoint '{}'".format(fpath))
  return checkpoint


def copy_state_dict(state_dict, model, strip=None):
  """Copy compatible entries of ``state_dict`` into ``model`` and return it.

  Each source key first has the ``strip`` prefix removed, if it starts with
  it. Keys the model does not have are ignored, and entries whose size
  differs from the model's are reported and skipped. ``Parameter`` values
  are unwrapped to their ``.data`` tensor. The matching entries are copied in
  place into the tensors returned by ``model.state_dict()``. Target keys that
  received no value are printed at the end.

  Args:
    state_dict: mapping from key to tensor/``Parameter``.
    model: module whose ``state_dict()`` tensors are overwritten.
    strip: optional key prefix to remove (e.g. a wrapper prefix).

  Returns:
    The same ``model`` object.
  """
  target = model.state_dict()
  loaded = set()
  for key, value in state_dict.items():
    if strip is not None and key.startswith(strip):
      key = key[len(strip):]
    if key not in target:
      continue
    source = value.data if isinstance(value, Parameter) else value
    dest = target[key]
    if source.size() != dest.size():
      print('mismatch:', key, source.size(), dest.size())
      continue
    dest.copy_(source)
    loaded.add(key)

  # Report model entries that nothing in state_dict filled.
  unloaded = set(target.keys()) - loaded
  if unloaded:
    print("missing keys in state_dict:", unloaded)

  return model