shreyasvaidya's picture
Upload folder using huggingface_hub
01bb3bb verified
from __future__ import print_function, absolute_import
import json
import os
import sys
# import moxing as mox
import os.path as osp
import shutil
import torch
from torch.nn import Parameter
from .osutils import mkdir_if_missing
from config import get_args
# Command-line options are parsed once, at import time, and shared by the
# checkpoint helpers below through ``global_args``.
global_args = get_args(sys.argv[1:])

# ``moxing`` is imported only when the run is flagged as remote, so it does
# not have to be installed otherwise.
if global_args.run_on_remote:
    import moxing as mox
def read_json(fpath):
    """Read the JSON file at ``fpath`` and return the decoded object.

    Args:
        fpath: Path of the JSON file to read.

    Returns:
        Whatever ``json.load`` produces for the file's contents.
    """
    with open(fpath, 'r') as handle:
        return json.load(handle)
def write_json(obj, fpath):
    """Serialize ``obj`` to ``fpath`` as JSON indented by four spaces.

    The parent directory of ``fpath`` is created through
    ``mkdir_if_missing`` when the path has a directory component.

    Args:
        obj: JSON-serializable object to write.
        fpath: Destination file path. A bare file name is written relative
            to the current working directory.
    """
    parent_dir = osp.dirname(fpath)
    # ``osp.dirname`` yields '' for a bare file name. There is no directory
    # to create in that case, and handing '' to a makedirs-style helper
    # would raise instead of being a no-op.
    if parent_dir:
        mkdir_if_missing(parent_dir)
    with open(fpath, 'w') as f:
        json.dump(obj, f, indent=4, separators=(',', ': '))
def save_checkpoint(state, is_best, fpath='checkpoint.pth.tar'):
    """Save ``state`` with ``torch.save`` and optionally mark it as the best.

    When ``global_args.run_on_remote`` is set, the checkpoint is first written
    to a local file named ``local_checkpoint.pth.tar`` and then copied to
    ``fpath`` with ``mox.file.copy``. Otherwise it is written straight to
    ``fpath``. In both modes, if ``is_best`` is true, a copy named
    ``model_best.pth.tar`` is placed in the same directory as ``fpath``.

    Args:
        state: Object to serialize with ``torch.save``.
        is_best: When true, also store a ``model_best.pth.tar`` copy.
        fpath: Destination path of the checkpoint.
    """
    print('=> saving checkpoint ', fpath)
    dir_name = osp.dirname(fpath)
    best_path = osp.join(dir_name, 'model_best.pth.tar')

    if global_args.run_on_remote:
        # ``dir_name`` is '' for a bare file name (including the default
        # ``fpath``); there is no directory to check or create then.
        if dir_name and not mox.file.exists(dir_name):
            mox.file.make_dirs(dir_name)
            print('=> makding dir ', dir_name)
        # Serialize locally first, then copy the file to the destination.
        # NOTE(review): the local staging file is left on disk and is
        # overwritten by the next save.
        local_path = "local_checkpoint.pth.tar"
        torch.save(state, local_path)
        mox.file.copy(local_path, fpath)
        if is_best:
            mox.file.copy(local_path, best_path)
    else:
        # Same guard as above: skip directory creation when ``fpath`` has no
        # directory component, so the default argument works.
        if dir_name:
            mkdir_if_missing(dir_name)
        torch.save(state, fpath)
        if is_best:
            shutil.copy(fpath, best_path)
def load_checkpoint(fpath):
    """Load a checkpoint from ``fpath`` with ``torch.load``.

    When ``global_args.run_on_remote`` is set, ``mox.file.shift('os', 'mox')``
    is called before loading and no existence check is made. Otherwise the
    path must be an existing regular file.

    Args:
        fpath: Path of the checkpoint to load.

    Returns:
        The object returned by ``torch.load``.

    Raises:
        ValueError: In local mode, when ``fpath`` is not an existing file.
    """
    # NOTE(review): ``torch.load`` unpickles the file; only load checkpoints
    # from trusted sources.
    if global_args.run_on_remote:
        # NOTE(review): presumably reroutes ``os``-level file access through
        # moxing so the remote path can be opened - confirm in moxing docs.
        mox.file.shift('os', 'mox')
        loaded = torch.load(fpath)
        print("=> Loaded checkpoint '{}'".format(fpath))
        return loaded

    if not osp.isfile(fpath):
        raise ValueError("=> No checkpoint found at '{}'".format(fpath))
    loaded = torch.load(fpath)
    print("=> Loaded checkpoint '{}'".format(fpath))
    return loaded
def copy_state_dict(state_dict, model, strip=None):
    """Copy matching entries of ``state_dict`` into ``model`` in place.

    An entry is copied when its (optionally prefix-stripped) name exists in
    ``model.state_dict()`` and its size equals the target's size. Entries
    with unknown names are skipped silently; size mismatches are printed and
    skipped. Target keys that received no value are printed at the end.

    Args:
        state_dict: Mapping from parameter name to tensor or ``Parameter``.
        model: Module whose ``state_dict()`` tensors are overwritten.
        strip: Optional prefix removed from the start of each source name
            that begins with it.

    Returns:
        The same ``model`` object that was passed in.
    """
    target = model.state_dict()
    loaded = set()

    for key, value in state_dict.items():
        if strip is not None and key.startswith(strip):
            key = key[len(strip):]
        if key not in target:
            continue

        dest = target[key]
        # Unwrap Parameters so the raw tensor is what gets copied.
        source = value.data if isinstance(value, Parameter) else value
        if source.size() != dest.size():
            print('mismatch:', key, source.size(), dest.size())
            continue

        dest.copy_(source)
        loaded.add(key)

    not_loaded = set(target.keys()) - loaded
    if not_loaded:
        print("missing keys in state_dict:", not_loaded)
    return model