Spaces:
Runtime error
Runtime error
ronald
committed on
Commit
·
52b805e
1
Parent(s):
4ebdc9c
init
Browse files - local_coh_ppl.py +2 -5
local_coh_ppl.py
CHANGED
@@ -215,16 +215,13 @@ class LocalCohPPL(evaluate.Measurement):
|
|
215 |
for pred in batch_sents:
|
216 |
ss = pred.split("\n")
|
217 |
sslens = [len(tokenizer(y,add_special_tokens=False,padding=False).input_ids) for y in ss]
|
218 |
-
offset =
|
219 |
sspos = [offset]
|
220 |
for sslen in sslens:
|
221 |
offset = min(offset + sslen,511)
|
222 |
sspos.append(offset)
|
223 |
sent_tok_lens.append(sspos)
|
224 |
|
225 |
-
print("[compute ppl] check ...")
|
226 |
-
pdb.set_trace()
|
227 |
-
|
228 |
labels = encoded_texts
|
229 |
|
230 |
with torch.no_grad():
|
@@ -232,7 +229,7 @@ class LocalCohPPL(evaluate.Measurement):
|
|
232 |
|
233 |
shift_logits = out_logits[..., :-1, :].contiguous()
|
234 |
shift_labels = labels[..., 1:].contiguous()
|
235 |
-
shift_attention_mask_batch =
|
236 |
|
237 |
loss_out = loss_fct(shift_logits.transpose(1, 2), shift_labels) * shift_attention_mask_batch
|
238 |
perplexity_all = torch.exp(
|
|
|
215 |
for pred in batch_sents:
|
216 |
ss = pred.split("\n")
|
217 |
sslens = [len(tokenizer(y,add_special_tokens=False,padding=False).input_ids) for y in ss]
|
218 |
+
offset = int(add_start_token)
|
219 |
sspos = [offset]
|
220 |
for sslen in sslens:
|
221 |
offset = min(offset + sslen,511)
|
222 |
sspos.append(offset)
|
223 |
sent_tok_lens.append(sspos)
|
224 |
|
|
|
|
|
|
|
225 |
labels = encoded_texts
|
226 |
|
227 |
with torch.no_grad():
|
|
|
229 |
|
230 |
shift_logits = out_logits[..., :-1, :].contiguous()
|
231 |
shift_labels = labels[..., 1:].contiguous()
|
232 |
+
shift_attention_mask_batch = attn_masks[..., 1:].contiguous()
|
233 |
|
234 |
loss_out = loss_fct(shift_logits.transpose(1, 2), shift_labels) * shift_attention_mask_batch
|
235 |
perplexity_all = torch.exp(
|