Xinyoumeng233hu commited on
Commit
d81ea2d
·
1 Parent(s): 5b67b19

Update arithmetic.py

Browse files
Files changed (1) hide show
  1. arithmetic.py +2 -2
arithmetic.py CHANGED
@@ -3,7 +3,7 @@ import torch.nn.functional as F
3
 
4
  from utils import limit_past, kl, entropy, bits2int, int2bits, is_sent_finish, num_same_from_beg
5
 
6
- def encode_arithmetic(model, enc, message, context, finish_sent=False, device='cuda', temp=1.0, precision=16, topk=50000):
7
 
8
  context = torch.tensor(context[-1022:], device=device, dtype=torch.long)
9
 
@@ -122,7 +122,7 @@ def encode_arithmetic(model, enc, message, context, finish_sent=False, device='c
122
 
123
  return output[len(context):].tolist(), avg_NLL, avg_KL, words_per_bit
124
 
125
- def decode_arithmetic(model, enc, text, context, device='cuda', temp=1.0, precision=16, topk=50000):
126
  # inp is a list of token indices
127
  # context is a list of token indices
128
  inp = enc.encode(text)
 
3
 
4
  from utils import limit_past, kl, entropy, bits2int, int2bits, is_sent_finish, num_same_from_beg
5
 
6
+ def encode_arithmetic(model, enc, message, context, finish_sent=False, device='cpu', temp=1.0, precision=16, topk=50000):
7
 
8
  context = torch.tensor(context[-1022:], device=device, dtype=torch.long)
9
 
 
122
 
123
  return output[len(context):].tolist(), avg_NLL, avg_KL, words_per_bit
124
 
125
+ def decode_arithmetic(model, enc, text, context, device='cpu', temp=1.0, precision=16, topk=50000):
126
  # inp is a list of token indices
127
  # context is a list of token indices
128
  inp = enc.encode(text)