Spaces:
Sleeping
Sleeping
from __future__ import absolute_import | |
from __future__ import division | |
from __future__ import print_function | |
import collections | |
import torch | |
import torch.nn as nn | |
import numpy as np | |
import torch.optim as optim | |
import os | |
import torch.nn.functional as F | |
import six | |
from six.moves import cPickle | |
# Words that decode_sequence strips from the end of each caption when the
# REMOVE_BAD_ENDINGS environment variable is a non-zero integer.
bad_endings = (
    ['a', 'an', 'the', 'in', 'for', 'at', 'of', 'with', 'before', 'after',
     'on', 'upon', 'near', 'to', 'is', 'are', 'am']
    + ['UNK', 'has', 'and', 'more']
)
def pickle_load(f):
    """Load one pickled object from ``f``.

    On Python 3 the data is decoded with ``encoding='latin-1'`` so that
    pickles written by Python 2 (8-bit ``str`` payloads) can still be read.
    Only load trusted files: unpickling can execute arbitrary code.

    Parameters
    ----------
    f: file-like object
        Opened in binary mode.

    Returns
    -------
    The unpickled object.
    """
    load_kwargs = {'encoding': 'latin-1'} if six.PY3 else {}
    return cPickle.load(f, **load_kwargs)
def pickle_dump(obj, f):
    """Pickle ``obj`` into ``f``.

    On Python 3, protocol 2 (the newest protocol Python 2 can read) is used
    so the written file remains loadable from either interpreter.

    Parameters
    ----------
    obj: object to pickle
    f: file-like object
        Opened in binary write mode.
    """
    dump_kwargs = {'protocol': 2} if six.PY3 else {}
    return cPickle.dump(obj, f, **dump_kwargs)
# modified from https://github.com/facebookresearch/detectron2/blob/master/detectron2/utils/comm.py | |
def serialize_to_tensor(data):
    """Pickle ``data`` into a 1-D ``torch.uint8`` tensor on the CPU.

    ``deserialize`` performs the inverse conversion.

    Parameters
    ----------
    data: any picklable object

    Returns
    -------
    torch.Tensor
        1-D uint8 CPU tensor holding the pickled bytes.
    """
    buffer = cPickle.dumps(data)
    # np.frombuffer returns a read-only view of the immutable bytes object;
    # .copy() gives a writable array that the tensor can own. This avoids the
    # legacy torch.ByteStorage / torch.ByteTensor(storage) constructors, which
    # rely on the deprecated TypedStorage API in PyTorch 2.x.
    array = np.frombuffer(buffer, dtype=np.uint8).copy()
    return torch.from_numpy(array)  # from_numpy always yields a CPU tensor
def deserialize(tensor):
    """Inverse of ``serialize_to_tensor``: unpickle the bytes held in ``tensor``.

    The tensor is moved to the CPU first, so tensors on any device work.
    Only use on trusted data: unpickling can execute arbitrary code.
    """
    host_tensor = tensor.cpu()
    raw_bytes = host_tensor.numpy().tobytes()
    return cPickle.loads(raw_bytes)
def decode_sequence(ix_to_word, seq):
    """Convert a batch of word-index sequences into caption strings.

    Parameters
    ----------
    ix_to_word: dict
        Maps ``str(index)`` to the word for that index.
    seq: 2-D integer array of shape (N, D)
        A torch tensor or a numpy array with elements in 0 .. vocab_size.
        0 is the END token: decoding of a row stops at the first element
        that is not > 0.

    Returns
    -------
    list of str
        One caption per row. Words are joined by single spaces, then BPE
        continuation markers (``'@@ '``) are merged.

    If the environment variable ``REMOVE_BAD_ENDINGS`` is a non-zero integer,
    trailing words listed in the module-level ``bad_endings`` are stripped
    from each caption. A caption made up entirely of such words is left
    unchanged.
    """
    # ``.shape`` (unlike ``.size()``) works for both torch tensors and numpy arrays.
    N, D = seq.shape
    # Read the flag once per call rather than once per caption.
    remove_bad_endings = int(os.getenv('REMOVE_BAD_ENDINGS', '0'))
    bad_words = frozenset(bad_endings) if remove_bad_endings else frozenset()
    out = []
    for i in range(N):
        words = []
        for j in range(D):
            ix = seq[i, j]
            if not ix > 0:  # END token reached
                break
            words.append(ix_to_word[str(ix.item())])
        txt = ' '.join(words)
        if remove_bad_endings:
            tokens = txt.split(' ')
            keep = len(tokens)
            while keep > 0 and tokens[keep - 1] in bad_words:
                keep -= 1
            # If every token is a bad ending, the caption is kept as-is.
            if keep > 0:
                tokens = tokens[:keep]
            txt = ' '.join(tokens)
        out.append(txt.replace('@@ ', ''))
    return out
def save_checkpoint(opt, model, infos, optimizer, histories=None, append=''):
    """Write a training checkpoint into ``opt.checkpoint_path``.

    The directory is created if it does not exist. Files written, where
    ``<sfx>`` is ``'_' + append`` for a non-empty ``append`` and empty
    otherwise:

    * ``model<sfx>.pth``     - ``model.state_dict()`` via ``torch.save``
    * ``optimizer<sfx>.pth`` - ``optimizer.state_dict()`` via ``torch.save``
    * ``infos<sfx>.pkl``     - ``infos`` via ``pickle_dump``
    * ``histories<sfx>.pkl`` - ``histories`` via ``pickle_dump``, only when
      ``histories`` is truthy
    """
    suffix = '_' + append if len(append) > 0 else append
    ckpt_dir = opt.checkpoint_path
    if not os.path.isdir(ckpt_dir):
        os.makedirs(ckpt_dir)

    model_file = os.path.join(ckpt_dir, 'model%s.pth' % suffix)
    torch.save(model.state_dict(), model_file)
    print("model saved to {}".format(model_file))

    optimizer_file = os.path.join(ckpt_dir, 'optimizer%s.pth' % suffix)
    torch.save(optimizer.state_dict(), optimizer_file)

    def _dump_pickle(stem, payload):
        # Pickle ``payload`` to <ckpt_dir>/<stem><suffix>.pkl.
        with open(os.path.join(ckpt_dir, '%s%s.pkl' % (stem, suffix)), 'wb') as f:
            pickle_dump(payload, f)

    _dump_pickle('infos', infos)
    if histories:
        _dump_pickle('histories', histories)
def set_lr(optimizer, lr):
    """Assign learning rate ``lr`` to every parameter group of ``optimizer``."""
    for param_group in optimizer.param_groups:
        param_group.update(lr=lr)
def get_lr(optimizer):
    """Return the learning rate of the first parameter group of ``optimizer``.

    Returns ``None`` if the optimizer has no parameter groups.
    """
    return next((param_group['lr'] for param_group in optimizer.param_groups), None)
def build_optimizer(params, opt):
    """Construct the torch optimizer selected by ``opt.optim``.

    Parameters
    ----------
    params: iterable
        Parameters (or parameter-group dicts) to optimize.
    opt: options object
        Always reads ``optim``, ``learning_rate`` and ``weight_decay``.
        Depending on the choice it also reads ``optim_alpha``, ``optim_beta``
        and ``optim_epsilon``:

        * ``'rmsprop'``: RMSprop, ``alpha=optim_alpha``, ``eps=optim_epsilon``
        * ``'adagrad'``: Adagrad
        * ``'sgd'``: SGD without momentum
        * ``'sgdm'``: SGD, ``momentum=optim_alpha``
        * ``'sgdmom'``: SGD, ``momentum=optim_alpha``, Nesterov enabled
        * ``'adam'`` / ``'adamw'``: Adam / AdamW,
          ``betas=(optim_alpha, optim_beta)``, ``eps=optim_epsilon``

    Raises
    ------
    Exception
        If ``opt.optim`` is not one of the names above.
    """
    # Factories are lazy so only the chosen optimizer's options are read.
    factories = {
        'rmsprop': lambda: optim.RMSprop(
            params, lr=opt.learning_rate, alpha=opt.optim_alpha,
            eps=opt.optim_epsilon, weight_decay=opt.weight_decay),
        'adagrad': lambda: optim.Adagrad(
            params, lr=opt.learning_rate, weight_decay=opt.weight_decay),
        'sgd': lambda: optim.SGD(
            params, lr=opt.learning_rate, weight_decay=opt.weight_decay),
        'sgdm': lambda: optim.SGD(
            params, lr=opt.learning_rate, momentum=opt.optim_alpha,
            weight_decay=opt.weight_decay),
        'sgdmom': lambda: optim.SGD(
            params, lr=opt.learning_rate, momentum=opt.optim_alpha,
            weight_decay=opt.weight_decay, nesterov=True),
        'adam': lambda: optim.Adam(
            params, lr=opt.learning_rate,
            betas=(opt.optim_alpha, opt.optim_beta),
            eps=opt.optim_epsilon, weight_decay=opt.weight_decay),
        'adamw': lambda: optim.AdamW(
            params, lr=opt.learning_rate,
            betas=(opt.optim_alpha, opt.optim_beta),
            eps=opt.optim_epsilon, weight_decay=opt.weight_decay),
    }
    if opt.optim not in factories:
        raise Exception("bad option opt.optim: {}".format(opt.optim))
    return factories[opt.optim]()
def penalty_builder(penalty_config):
    """Build a length-penalty function from a config string.

    Parameters
    ----------
    penalty_config: str
        ``''`` for no penalty. Otherwise ``'<type>_<alpha>'``, where
        ``<type>`` is ``'wu'`` (``length_wu``) or ``'avg'``
        (``length_average``) and ``<alpha>`` parses as a float,
        e.g. ``'wu_0.9'``.

    Returns
    -------
    callable
        ``f(length, logprobs)`` returning the re-scored log-probabilities.

    Raises
    ------
    ValueError
        If the config is malformed or names an unknown penalty type.
        Previously an unknown type silently returned ``None``, which only
        failed later when the result was called.
    """
    if penalty_config == '':
        return lambda x, y: y
    pen_type, alpha = penalty_config.split('_')
    alpha = float(alpha)
    if pen_type == 'wu':
        return lambda x, y: length_wu(x, y, alpha)
    if pen_type == 'avg':
        return lambda x, y: length_average(x, y, alpha)
    raise ValueError('unknown length penalty type {!r} in config {!r}'.format(
        pen_type, penalty_config))
def length_wu(length, logprobs, alpha=0.):
    """Length-normalize ``logprobs`` with the GNMT penalty.

    NMT length re-ranking score from "Google's Neural Machine Translation
    System" :cite:`wu2016google`::

        logprobs / (((5 + length) ** alpha) / ((5 + 1) ** alpha))

    ``alpha=0`` leaves ``logprobs`` unchanged.
    """
    numerator = (5 + length) ** alpha
    denominator = (5 + 1) ** alpha
    penalty = numerator / denominator
    return logprobs / penalty
def length_average(length, logprobs, alpha=0.):
    """Return the per-token average of the summed log-probabilities.

    ``alpha`` is ignored; it is accepted so the signature matches
    ``length_wu``.
    """
    mean_logprob = logprobs / length
    return mean_logprob
class NoamOpt(object):
    """Optimizer wrapper implementing the "Noam" learning-rate schedule.

    On every ``step()`` the learning rate of all parameter groups is set to::

        factor * model_size ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)

    The rate grows linearly for ``warmup`` steps and then decays with the
    inverse square root of the step number. Non-dunder attributes not defined
    on the wrapper (``zero_grad``, ``param_groups``, ...) are delegated to the
    wrapped optimizer.
    """

    def __init__(self, model_size, factor, warmup, optimizer):
        self.optimizer = optimizer
        self._step = 0      # number of step() calls so far
        self.warmup = warmup
        self.factor = factor
        self.model_size = model_size
        self._rate = 0      # learning rate applied by the most recent step()

    def step(self):
        "Advance the schedule, apply the new rate to all groups, then update parameters."
        self._step += 1
        rate = self.rate()
        for p in self.optimizer.param_groups:
            p['lr'] = rate
        self._rate = rate
        self.optimizer.step()

    def rate(self, step=None):
        """Return the scheduled learning rate at ``step`` (default: current step).

        For ``step <= 0`` (before the first ``step()``), 0 is returned. This
        is the value of the warm-up term there, and it avoids the
        ZeroDivisionError that ``0 ** -0.5`` would raise.
        """
        if step is None:
            step = self._step
        if step <= 0:
            return 0.
        return self.factor * \
            (self.model_size ** (-0.5) *
             min(step ** (-0.5), step * self.warmup ** (-1.5)))

    def __getattr__(self, name):
        # Only reached when normal attribute lookup fails.
        # - ``optimizer`` is read from __dict__ so that an instance without it
        #   (e.g. one created by copy/pickle before its state is restored)
        #   raises AttributeError instead of recursing forever.
        # - Dunder names are never delegated: copy/pickle probe
        #   __getstate__/__setstate__ and must see this wrapper's own state.
        optimizer = self.__dict__.get('optimizer')
        if optimizer is None or (name.startswith('__') and name.endswith('__')):
            raise AttributeError("'{}' object has no attribute '{}'".format(
                type(self).__name__, name))
        return getattr(optimizer, name)

    def state_dict(self):
        """Return the wrapped optimizer's state dict plus the step counter under ``'_step'``."""
        state_dict = self.optimizer.state_dict()
        state_dict['_step'] = self._step
        return state_dict

    def load_state_dict(self, state_dict):
        """Restore from ``state_dict()`` output or from a plain optimizer state dict.

        The caller's dict is not modified.
        """
        if '_step' in state_dict:
            # Work on a shallow copy so the caller's checkpoint keeps '_step'.
            state_dict = dict(state_dict)
            self._step = state_dict.pop('_step')
        self.optimizer.load_state_dict(state_dict)
class ReduceLROnPlateau(object):
    """Optimizer wrapper driven by ``torch.optim.lr_scheduler.ReduceLROnPlateau``.

    ``step()`` only updates the parameters. The learning rate changes when a
    validation metric is passed to ``scheduler_step(val)``. ``current_lr``
    holds the learning rate of the first parameter group (via ``get_lr``).
    Non-dunder attributes not defined on the wrapper are delegated to the
    wrapped optimizer.
    """

    def __init__(self, optimizer, mode='min', factor=0.1, patience=10, verbose=False, threshold=0.0001, threshold_mode='rel', cooldown=0, min_lr=0, eps=1e-08):
        # Bind every scheduler argument by keyword. Recent PyTorch releases
        # deprecated ``verbose`` and moved it to the end of the signature, so
        # the previous positional call bound ``verbose`` to ``threshold``,
        # ``threshold`` to ``threshold_mode``, and so on.
        scheduler_kwargs = dict(mode=mode, factor=factor, patience=patience,
                                threshold=threshold, threshold_mode=threshold_mode,
                                cooldown=cooldown, min_lr=min_lr, eps=eps)
        if verbose:
            # Forward ``verbose`` only when requested; newer PyTorch warns
            # whenever the deprecated argument is passed explicitly.
            scheduler_kwargs['verbose'] = verbose
        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, **scheduler_kwargs)
        self.optimizer = optimizer
        self.current_lr = get_lr(optimizer)

    def step(self):
        "Update parameters; the learning rate is only changed by scheduler_step()."
        self.optimizer.step()

    def scheduler_step(self, val):
        """Feed validation metric ``val`` to the scheduler and refresh ``current_lr``."""
        self.scheduler.step(val)
        self.current_lr = get_lr(self.optimizer)

    def state_dict(self):
        return {'current_lr': self.current_lr,
                'scheduler_state_dict': self.scheduler.state_dict(),
                'optimizer_state_dict': self.optimizer.state_dict()}

    def load_state_dict(self, state_dict):
        """Restore from ``state_dict()`` output or from a plain optimizer state dict."""
        if 'current_lr' not in state_dict:
            # A plain optimizer state dict (presumably a checkpoint saved
            # without this wrapper).
            self.optimizer.load_state_dict(state_dict)
            # Re-apply the wrapper's current lr (initially the lr the
            # optimizer was constructed with) over the checkpoint's lr.
            set_lr(self.optimizer, self.current_lr)
        else:
            # A dict produced by this wrapper's state_dict().
            self.current_lr = state_dict['current_lr']
            self.scheduler.load_state_dict(state_dict['scheduler_state_dict'])
            self.optimizer.load_state_dict(state_dict['optimizer_state_dict'])
            # current_lr is informational here: the loaded optimizer state
            # already carries the learning rate.

    def rate(self, step=None):
        """Return the current learning rate.

        ``step`` is ignored and accepted only so the signature matches
        ``NoamOpt.rate``. The previous body was copied from NoamOpt and read
        attributes this class never defines, so it always raised
        AttributeError.
        """
        return self.current_lr

    def __getattr__(self, name):
        # Only reached when normal lookup fails. Read ``optimizer`` from
        # __dict__ to avoid infinite recursion on a half-built instance
        # (copy/pickle), and never delegate dunder probes such as
        # __getstate__/__setstate__ to the wrapped optimizer.
        optimizer = self.__dict__.get('optimizer')
        if optimizer is None or (name.startswith('__') and name.endswith('__')):
            raise AttributeError("'{}' object has no attribute '{}'".format(
                type(self).__name__, name))
        return getattr(optimizer, name)
def get_std_opt(model, optim_func='adam', factor=1, warmup=2000):
    """Wrap a fresh Adam/AdamW optimizer over ``model.parameters()`` in a ``NoamOpt``.

    Parameters
    ----------
    model: object exposing ``parameters()`` and ``d_model``
        ``d_model`` is used as the schedule's ``model_size``.
    optim_func: str
        ``'adam'`` or ``'adamw'``; any other value raises ``KeyError``.
    factor, warmup:
        Forwarded to ``NoamOpt``.

    The inner optimizer starts at ``lr=0`` (NoamOpt sets the rate on every
    step) with ``betas=(0.9, 0.98)`` and ``eps=1e-9``.
    """
    optimizer_classes = {'adam': torch.optim.Adam, 'adamw': torch.optim.AdamW}
    optimizer_cls = optimizer_classes[optim_func]
    model_size = model.d_model
    base_optimizer = optimizer_cls(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)
    return NoamOpt(model_size, factor, warmup, base_optimizer)