Spaces:
Sleeping
Sleeping
Commit
·
d6a5fdd
1
Parent(s):
d81ea2d
Update block_baseline.py
Browse files- block_baseline.py +2 -2
block_baseline.py
CHANGED
@@ -23,7 +23,7 @@ def get_bins(vocab_size, block_size):
|
|
23 |
|
24 |
return bin2words, words2bin
|
25 |
|
26 |
-
def encode_block(model, enc, message, context, block_size, bin2words, words2bin, finish_sent=False, device='
|
27 |
length = len(message)
|
28 |
|
29 |
context = torch.tensor(context[-1022:], device=device, dtype=torch.long)
|
@@ -96,7 +96,7 @@ def encode_block(model, enc, message, context, block_size, bin2words, words2bin,
|
|
96 |
|
97 |
return output[len(context):].tolist(), avg_NLL, avg_KL, words_per_bit
|
98 |
|
99 |
-
def decode_block(model, enc, text, context, block_size, bin2words, words2bin, device='
|
100 |
# inp is a list of token indices
|
101 |
# context is a list of token indices
|
102 |
inp = enc.encode(text)
|
|
|
23 |
|
24 |
return bin2words, words2bin
|
25 |
|
26 |
+
def encode_block(model, enc, message, context, block_size, bin2words, words2bin, finish_sent=False, device='cpu'):
|
27 |
length = len(message)
|
28 |
|
29 |
context = torch.tensor(context[-1022:], device=device, dtype=torch.long)
|
|
|
96 |
|
97 |
return output[len(context):].tolist(), avg_NLL, avg_KL, words_per_bit
|
98 |
|
99 |
+
def decode_block(model, enc, text, context, block_size, bin2words, words2bin, device='cpu'):
|
100 |
# inp is a list of token indices
|
101 |
# context is a list of token indices
|
102 |
inp = enc.encode(text)
|