ronald committed
Commit a958a99 · 1 Parent(s): b2af956
Files changed (1):
  1. local_coh_ppl.py  +24 -7
local_coh_ppl.py CHANGED

@@ -165,7 +165,14 @@ class LocalCohPPL(evaluate.Measurement):
         tokenizer.add_special_tokens({"pad_token": existing_special_tokens[0]})

         model.config.max_length = 512 if "scibert" in model_id else model.config.max_length
-        max_tokenized_len = model.config.max_length - 1
+        if add_start_token:
+            # leave room for <BOS> token to be added:
+            assert (
+                tokenizer.bos_token is not None
+            ), "Input model must already have a BOS token if using add_start_token=True. Please use a different model, or set add_start_token=False"
+            max_tokenized_len = model.config.max_length - 1
+        else:
+            max_tokenized_len = model.config.max_length

         loss_fct = CrossEntropyLoss(reduction="none")

@@ -187,11 +194,21 @@ class LocalCohPPL(evaluate.Measurement):

             encoded_texts = encodings["input_ids"]
             attn_masks = encodings["attention_mask"]
-            bos_tokens_tensor = torch.tensor([[tokenizer.bos_token_id]] * encoded_texts.size(dim=0)).to(device)
-            encoded_texts = torch.cat([bos_tokens_tensor, encoded_texts], dim=1)
-            attn_masks = torch.cat(
-                [torch.ones(bos_tokens_tensor.size(), dtype=torch.int64).to(device), attn_masks], dim=1
-            )
+
+            # check that each input is long enough:
+            if add_start_token:
+                assert torch.all(torch.ge(attn_masks.sum(1), 1)), "Each input text must be at least one token long."
+            else:
+                assert torch.all(
+                    torch.ge(attn_masks.sum(1), 2)
+                ), "When add_start_token=False, each input text must be at least two tokens long. Run with add_start_token=True if inputting strings of only one token, and remove all empty input strings."
+
+            if add_start_token:
+                bos_tokens_tensor = torch.tensor([[tokenizer.bos_token_id]] * encoded_texts.size(dim=0)).to(device)
+                encoded_texts = torch.cat([bos_tokens_tensor, encoded_texts], dim=1)
+                attn_masks = torch.cat(
+                    [torch.ones(bos_tokens_tensor.size(), dtype=torch.int64).to(device), attn_masks], dim=1
+                )

             # tokenize by sentence
             for pred in batch_sents:

@@ -210,7 +227,7 @@ class LocalCohPPL(evaluate.Measurement):
             labels = encoded_texts

             with torch.no_grad():
-                out_logits = model(encoded_batch, attention_mask=attn_mask).logits
+                out_logits = model(encoded_texts, attention_mask=attn_masks).logits

             shift_logits = out_logits[..., :-1, :].contiguous()
             shift_labels = labels[..., 1:].contiguous()
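
For readers following the change, here is a minimal, self-contained sketch of what the new `add_start_token` path computes, outside the metric class. It is not the repository's code: the model id, example texts, and helper variables below are illustrative, and it scores plain per-text perplexity rather than the sentence-level local coherence that `local_coh_ppl.py` builds on top of it. It mirrors the diff's logic: reserve one position for BOS, prepend `tokenizer.bos_token_id`, run the model, then shift logits and labels and apply `CrossEntropyLoss(reduction="none")`.

```python
# Sketch only: standalone per-text perplexity with the add_start_token behavior
# from this commit. Model id and texts are illustrative; the real metric
# additionally tokenizes sentence by sentence for local coherence.
import torch
from torch.nn import CrossEntropyLoss
from transformers import AutoModelForCausalLM, AutoTokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"
model_id = "gpt2"  # assumption: any causal LM with a BOS token behaves the same way
texts = ["The model reads this text.", "And this one."]

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id).to(device)
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 ships without a pad token

add_start_token = True
# leave room for the <BOS> token, exactly as the diff does
max_tokenized_len = model.config.max_length - 1 if add_start_token else model.config.max_length

encodings = tokenizer(
    texts,
    add_special_tokens=False,
    padding=True,
    truncation=True,
    max_length=max_tokenized_len,
    return_tensors="pt",
).to(device)
encoded_texts = encodings["input_ids"]
attn_masks = encodings["attention_mask"]

if add_start_token:
    # prepend BOS so the first real token of each text also gets a prediction
    bos_tokens_tensor = torch.tensor([[tokenizer.bos_token_id]] * encoded_texts.size(dim=0)).to(device)
    encoded_texts = torch.cat([bos_tokens_tensor, encoded_texts], dim=1)
    attn_masks = torch.cat(
        [torch.ones(bos_tokens_tensor.size(), dtype=torch.int64).to(device), attn_masks], dim=1
    )

labels = encoded_texts
with torch.no_grad():
    out_logits = model(encoded_texts, attention_mask=attn_masks).logits

# shift so position i predicts token i+1, mask out padding, average the NLL
shift_logits = out_logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
shift_masks = attn_masks[..., 1:].contiguous()
loss_fct = CrossEntropyLoss(reduction="none")
nll = loss_fct(shift_logits.transpose(1, 2), shift_labels) * shift_masks
perplexities = torch.exp(nll.sum(1) / shift_masks.sum(1))
print(perplexities)  # one perplexity value per input text
```

With `add_start_token=True` the shifted labels cover every real token of each text (the BOS position supplies the first prediction), which is why the diff only requires inputs to be at least one token long in that branch and at least two tokens long otherwise.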