ronald committed on
Commit 52b805e · 1 Parent(s): 4ebdc9c
Files changed (1)
  1. local_coh_ppl.py +2 -5
local_coh_ppl.py CHANGED
@@ -215,16 +215,13 @@ class LocalCohPPL(evaluate.Measurement):
             for pred in batch_sents:
                 ss = pred.split("\n")
                 sslens = [len(tokenizer(y,add_special_tokens=False,padding=False).input_ids) for y in ss]
-                offset = 0
+                offset = int(add_start_token)
                 sspos = [offset]
                 for sslen in sslens:
                     offset = min(offset + sslen,511)
                     sspos.append(offset)
                 sent_tok_lens.append(sspos)
 
-            print("[compute ppl] check ...")
-            pdb.set_trace()
-
             labels = encoded_texts
 
             with torch.no_grad():
@@ -232,7 +229,7 @@ class LocalCohPPL(evaluate.Measurement):
 
             shift_logits = out_logits[..., :-1, :].contiguous()
             shift_labels = labels[..., 1:].contiguous()
-            shift_attention_mask_batch = attn_mask[..., 1:].contiguous()
+            shift_attention_mask_batch = attn_masks[..., 1:].contiguous()
 
             loss_out = loss_fct(shift_logits.transpose(1, 2), shift_labels) * shift_attention_mask_batch
             perplexity_all = torch.exp(
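Context (not part of the commit): the first hunk removes leftover debugging (print / pdb.set_trace) and starts the per-sentence offsets at int(add_start_token), so that when a start token (e.g. BOS) is prepended to each encoded text the recorded sentence boundaries shift by one; the second hunk fixes the variable name so the shifted attention mask is built from attn_masks, presumably the name actually defined in the function. A minimal sketch of the patched offset logic, using hypothetical sentence token counts:

    # Sketch only (assumed values): how the patched offset computation places
    # sentence boundaries when a start token is prepended.
    add_start_token = True          # assumption: a BOS token is prepended during encoding
    sslens = [7, 12, 9]             # hypothetical per-sentence token counts

    offset = int(add_start_token)   # start at 1 so every boundary accounts for the BOS token
    sspos = [offset]
    for sslen in sslens:
        offset = min(offset + sslen, 511)   # clamp to the last valid position (max length 512)
        sspos.append(offset)

    print(sspos)                    # [1, 8, 20, 29]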