Spaces:
Sleeping
Sleeping
Commit
·
d81ea2d
1
Parent(s):
5b67b19
Update arithmetic.py
Browse files- arithmetic.py +2 -2
arithmetic.py
CHANGED
@@ -3,7 +3,7 @@ import torch.nn.functional as F
|
|
3 |
|
4 |
from utils import limit_past, kl, entropy, bits2int, int2bits, is_sent_finish, num_same_from_beg
|
5 |
|
6 |
-
def encode_arithmetic(model, enc, message, context, finish_sent=False, device='
|
7 |
|
8 |
context = torch.tensor(context[-1022:], device=device, dtype=torch.long)
|
9 |
|
@@ -122,7 +122,7 @@ def encode_arithmetic(model, enc, message, context, finish_sent=False, device='c
|
|
122 |
|
123 |
return output[len(context):].tolist(), avg_NLL, avg_KL, words_per_bit
|
124 |
|
125 |
-
def decode_arithmetic(model, enc, text, context, device='
|
126 |
# inp is a list of token indices
|
127 |
# context is a list of token indices
|
128 |
inp = enc.encode(text)
|
|
|
3 |
|
4 |
from utils import limit_past, kl, entropy, bits2int, int2bits, is_sent_finish, num_same_from_beg
|
5 |
|
6 |
+
def encode_arithmetic(model, enc, message, context, finish_sent=False, device='cpu', temp=1.0, precision=16, topk=50000):
|
7 |
|
8 |
context = torch.tensor(context[-1022:], device=device, dtype=torch.long)
|
9 |
|
|
|
122 |
|
123 |
return output[len(context):].tolist(), avg_NLL, avg_KL, words_per_bit
|
124 |
|
125 |
+
def decode_arithmetic(model, enc, text, context, device='cpu', temp=1.0, precision=16, topk=50000):
|
126 |
# inp is a list of token indices
|
127 |
# context is a list of token indices
|
128 |
inp = enc.encode(text)
|