yu-val-weiss committed
Commit 995725e · 1 Parent(s): c26f589

fix probability logic

Files changed (1)
  1. blimp.py +7 -6
blimp.py CHANGED
@@ -296,21 +296,22 @@ def get_batch_probabilities(
         with torch.no_grad():
             outputs = model(**inputs)
 
-        labels = inputs.input_ids
+        logits = outputs.logits[..., :-1, :].contiguous()
+        labels = inputs.input_ids[..., 1:].contiguous()
 
-        # Compute log probabilities
-        log_probs = torch.nn.functional.log_softmax(outputs.logits, dim=-1)
+        # compute log probabilities
+        log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
 
-        # Get probability of each actual token
+        # get per-token probability
         token_log_probs = torch.gather(log_probs, 2, labels.unsqueeze(-1)).squeeze(-1)
 
         if batch_size > 1:
-            # Create attention mask for padding
+            # mask padding tokens
             mask = (labels != tokenizer.pad_token_id).float()
             token_log_probs *= mask
 
         # sum log probabilities
-        sequence_log_probs = (token_log_probs).sum(dim=1)
+        sequence_log_probs = token_log_probs.sum(dim=1)
 
         probs.extend(sequence_log_probs.cpu().tolist())
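The fix addresses a causal-LM off-by-one: the logits at position t are the model's distribution over the token at position t + 1, so the old code scored each token against the distribution conditioned on that same token. Shifting (dropping the last logit and the first label) pairs each prediction with its actual target. Below is a minimal self-contained sketch of the corrected alignment, not part of the commit; the toy shapes, the random logits standing in for real model output, and the pad id of 0 are assumptions for illustration.

import torch

torch.manual_seed(0)

batch, seq_len, vocab_size = 2, 5, 10
pad_token_id = 0  # assumed pad id for this toy example

# second sequence is right-padded with the pad id
input_ids = torch.tensor([[3, 7, 2, 5, 9],
                          [4, 1, 8, 0, 0]])
logits = torch.randn(batch, seq_len, vocab_size)  # stands in for outputs.logits

# a causal LM emits, at position t, a distribution over the token at t + 1,
# so drop the last logit (it predicts past the end) and the first label
# (nothing predicts it)
shifted_logits = logits[..., :-1, :].contiguous()
labels = input_ids[..., 1:].contiguous()

log_probs = torch.nn.functional.log_softmax(shifted_logits, dim=-1)
token_log_probs = torch.gather(log_probs, 2, labels.unsqueeze(-1)).squeeze(-1)

# zero out padded positions before summing, as in the patched code
mask = (labels != pad_token_id).float()
sequence_log_probs = (token_log_probs * mask).sum(dim=1)
print(sequence_log_probs)  # one summed log probability per sequence

With batch_size == 1 there is no padding in the batch, which is why the patched blimp.py only applies the mask on the batched path.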