Spaces:
Running
Running
File size: 2,275 Bytes
e8ca4ee |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 |
import os
import sys
import time
import json
import torch
# Only do the function below if verbose
def logger(verbose):
def log(*msg):
if verbose: print(*msg)
return log
last_time = time.time()
begin_time = last_time
def progress_bar(current, total, msg=None):
global last_time, begin_time
if current == 0:
begin_time = time.time() # Reset for new bar.
cur_time = time.time()
step_time = cur_time - last_time
last_time = cur_time
tot_time = cur_time - begin_time
L = []
L.append(' Step: %s' % format_time(step_time))
L.append(' | Tot: %s' % format_time(tot_time))
if msg:
L.append(' | ' + msg)
msg = ''.join(L)
sys.stdout.write(msg)
sys.stdout.write('\r')
#if current < total-1:
#
#else:
#sys.stdout.write('\n')
sys.stdout.flush()
def format_time(seconds):
days = int(seconds / 3600/24)
seconds = seconds - days*3600*24
hours = int(seconds / 3600)
seconds = seconds - hours*3600
minutes = int(seconds / 60)
seconds = seconds - minutes*60
secondsf = int(seconds)
seconds = seconds - secondsf
millis = int(seconds*1000)
f = ''
i = 1
if days > 0:
f += str(days) + 'D'
i += 1
if hours > 0 and i <= 2:
f += str(hours) + 'h'
i += 1
if minutes > 0 and i <= 2:
f += str(minutes) + 'm'
i += 1
if secondsf > 0 and i <= 2:
f += str(secondsf) + 's'
i += 1
if millis > 0 and i <= 2:
f += str(millis) + 'ms'
i += 1
if f == '':
f = '0ms'
return f
def save_vocab(data_loader, vocab_filename="config/vocab.json"):
"""Save vocabulary to a JSON file."""
vocab = {
'char_idx': data_loader.char_idx,
'char_list': data_loader.char_list
}
with open(vocab_filename, 'w') as f:
json.dump(vocab, f)
def load_vocab(vocab_filename='config/vocab.json'):
with open(vocab_filename, 'r') as f:
vocab = json.load(f)
return vocab['char_idx'], vocab['char_list']
def seq_to_tensor(seq, char_idx):
"""Convert sequence to tensor."""
out = torch.zeros(len(seq)).long()
for i, c in enumerate(seq):
out[i] = char_idx.index(c)
return out |