yu-val-weiss committed
Commit: 995725e
Parent(s): c26f589
fix probability logic
blimp.py
CHANGED
@@ -296,21 +296,22 @@ def get_batch_probabilities(
     with torch.no_grad():
         outputs = model(**inputs)
 
-
+    logits = outputs.logits[..., :-1, :].contiguous()
+    labels = inputs.input_ids[..., 1:].contiguous()
 
-    #
-    log_probs = torch.nn.functional.log_softmax(
+    # compute log probabilities
+    log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
 
-    #
+    # get per-token probability
     token_log_probs = torch.gather(log_probs, 2, labels.unsqueeze(-1)).squeeze(-1)
 
     if batch_size > 1:
-        #
+        # mask padding tokens
        mask = (labels != tokenizer.pad_token_id).float()
        token_log_probs *= mask
 
     # sum log probabilities
-    sequence_log_probs =
+    sequence_log_probs = token_log_probs.sum(dim=1)
 
     probs.extend(sequence_log_probs.cpu().tolist())
 
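For context, the fixed scoring logic can be exercised end to end with a small causal LM. The sketch below is an illustrative reconstruction, not the surrounding code of blimp.py: the model choice ("gpt2"), the score_sentences helper, the pad-token setup, and the example sentence pair are all assumptions, but the shift / log_softmax / gather / mask / sum steps mirror the hunk above.

# Minimal sketch of the fixed probability logic; model name, helper name,
# and example sentences are illustrative assumptions, not from blimp.py.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # gpt2 has no pad token by default
model = AutoModelForCausalLM.from_pretrained("gpt2").eval()


def score_sentences(sentences: list[str]) -> list[float]:
    """Return the total log-probability of each sentence under the model."""
    inputs = tokenizer(sentences, return_tensors="pt", padding=True)

    with torch.no_grad():
        outputs = model(**inputs)

    # shift so that the logits at position i score the token at position i+1
    logits = outputs.logits[..., :-1, :].contiguous()
    labels = inputs.input_ids[..., 1:].contiguous()

    # per-token log-probabilities of the observed next tokens
    log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
    token_log_probs = torch.gather(log_probs, 2, labels.unsqueeze(-1)).squeeze(-1)

    # zero out padding positions so they do not affect the sum
    mask = (labels != tokenizer.pad_token_id).float()
    token_log_probs *= mask

    # one summed log-probability per sentence
    return token_log_probs.sum(dim=1).cpu().tolist()


# BLiMP-style usage: the grammatical sentence should receive the higher score.
good, bad = score_sentences(["The cats sleep.", "The cats sleeps."])
print(good > bad)

The decisive detail is the shift between logits and input_ids: a causal LM's output at position i is a distribution over the token at position i+1, so gathering labels from unshifted logits assigns each token the probability of the wrong position, which is what the "fix probability logic" commit addresses.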