Xinyoumeng233hu commited on
Commit
d6a5fdd
·
1 Parent(s): d81ea2d

Update block_baseline.py

Browse files
Files changed (1) hide show
  1. block_baseline.py +2 -2
block_baseline.py CHANGED
@@ -23,7 +23,7 @@ def get_bins(vocab_size, block_size):
23
 
24
  return bin2words, words2bin
25
 
26
- def encode_block(model, enc, message, context, block_size, bin2words, words2bin, finish_sent=False, device='cuda'):
27
  length = len(message)
28
 
29
  context = torch.tensor(context[-1022:], device=device, dtype=torch.long)
@@ -96,7 +96,7 @@ def encode_block(model, enc, message, context, block_size, bin2words, words2bin,
96
 
97
  return output[len(context):].tolist(), avg_NLL, avg_KL, words_per_bit
98
 
99
- def decode_block(model, enc, text, context, block_size, bin2words, words2bin, device='cuda'):
100
  # inp is a list of token indices
101
  # context is a list of token indices
102
  inp = enc.encode(text)
 
23
 
24
  return bin2words, words2bin
25
 
26
+ def encode_block(model, enc, message, context, block_size, bin2words, words2bin, finish_sent=False, device='cpu'):
27
  length = len(message)
28
 
29
  context = torch.tensor(context[-1022:], device=device, dtype=torch.long)
 
96
 
97
  return output[len(context):].tolist(), avg_NLL, avg_KL, words_per_bit
98
 
99
+ def decode_block(model, enc, text, context, block_size, bin2words, words2bin, device='cpu'):
100
  # inp is a list of token indices
101
  # context is a list of token indices
102
  inp = enc.encode(text)