Spaces: Runtime error

ronald committed
Commit a958a99 · 1 Parent(s): b2af956

init

Browse files
- local_coh_ppl.py +24 -7
local_coh_ppl.py CHANGED

@@ -165,7 +165,14 @@ class LocalCohPPL(evaluate.Measurement):
             tokenizer.add_special_tokens({"pad_token": existing_special_tokens[0]})
 
         model.config.max_length = 512 if "scibert" in model_id else model.config.max_length
-
+        if add_start_token:
+            # leave room for <BOS> token to be added:
+            assert (
+                tokenizer.bos_token is not None
+            ), "Input model must already have a BOS token if using add_start_token=True. Please use a different model, or set add_start_token=False"
+            max_tokenized_len = model.config.max_length - 1
+        else:
+            max_tokenized_len = model.config.max_length
 
         loss_fct = CrossEntropyLoss(reduction="none")
 
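This hunk mirrors the BOS-token guard from the upstream `evaluate` perplexity metric: when `add_start_token=True`, one position of the model's context window is reserved so a <BOS> token can be prepended after tokenization. A minimal sketch of how a downstream tokenization step would typically consume `max_tokenized_len` — the model id, input string, and `max_length` value here are illustrative assumptions, not part of the commit:

from transformers import AutoTokenizer

model_id = "gpt2"  # assumption: any causal LM whose tokenizer defines a BOS token
tokenizer = AutoTokenizer.from_pretrained(model_id)

add_start_token = True
max_length = 1024  # stand-in for model.config.max_length
# Reserve one slot for <BOS> so prepending it later never exceeds max_length.
max_tokenized_len = max_length - 1 if add_start_token else max_length

encodings = tokenizer(
    ["The cat sat on the mat."],
    truncation=True,
    max_length=max_tokenized_len,
    return_tensors="pt",
)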
@@ -187,11 +194,21 @@ class LocalCohPPL(evaluate.Measurement):
 
         encoded_texts = encodings["input_ids"]
         attn_masks = encodings["attention_mask"]
-
-
-
-
-
+
+        # check that each input is long enough:
+        if add_start_token:
+            assert torch.all(torch.ge(attn_masks.sum(1), 1)), "Each input text must be at least one token long."
+        else:
+            assert torch.all(
+                torch.ge(attn_masks.sum(1), 2)
+            ), "When add_start_token=False, each input text must be at least two tokens long. Run with add_start_token=True if inputting strings of only one token, and remove all empty input strings."
+
+        if add_start_token:
+            bos_tokens_tensor = torch.tensor([[tokenizer.bos_token_id]] * encoded_texts.size(dim=0)).to(device)
+            encoded_texts = torch.cat([bos_tokens_tensor, encoded_texts], dim=1)
+            attn_masks = torch.cat(
+                [torch.ones(bos_tokens_tensor.size(), dtype=torch.int64).to(device), attn_masks], dim=1
+            )
 
         # tokenize by sentence
         for pred in batch_sents:
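This hunk adds two safeguards before scoring: an assertion that every row of the attention mask covers enough real tokens (one when a BOS will be prepended, two otherwise, since next-token loss needs at least one shifted pair), and the BOS prepend itself, which must extend the attention mask in lockstep. A toy demonstration of that concatenation pattern — the token ids below are made up for illustration, not from the commit:

import torch

bos_token_id = 50256  # assumption: GPT-2 style vocabulary where bos == eos
encoded_texts = torch.tensor([[464, 3797, 3332], [464, 3290, 0]])  # batch of 2, second row padded
attn_masks = torch.tensor([[1, 1, 1], [1, 1, 0]])                  # 0 marks padding

# Prepend one BOS id per row and grow the mask by one real (=1) position.
bos = torch.full((encoded_texts.size(0), 1), bos_token_id)
encoded_texts = torch.cat([bos, encoded_texts], dim=1)
attn_masks = torch.cat([torch.ones_like(bos), attn_masks], dim=1)

print(encoded_texts.shape)  # torch.Size([2, 4]): every row gained one position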
@@ -210,7 +227,7 @@ class LocalCohPPL(evaluate.Measurement):
             labels = encoded_texts
 
             with torch.no_grad():
-                out_logits = model(
+                out_logits = model(encoded_texts, attention_mask=attn_masks).logits
 
             shift_logits = out_logits[..., :-1, :].contiguous()
             shift_labels = labels[..., 1:].contiguous()
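This hunk completes the forward pass so the logits are computed under the (now BOS-extended) attention mask, then shifts logits and labels by one position so the logit at position t is scored against the token at t+1. The diff stops at the shift; a sketch of the masked perplexity computation these tensors conventionally feed into, following the upstream `evaluate` perplexity metric — the function name and packaging here are assumptions, not from this commit:

import torch
from torch.nn import CrossEntropyLoss

def masked_perplexity(out_logits, labels, attn_masks):
    # Align predictions with targets: logit t predicts token t + 1.
    shift_logits = out_logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    shift_masks = attn_masks[..., 1:].contiguous()

    # Per-token NLL; transpose puts the vocab axis where CrossEntropyLoss expects it.
    loss_fct = CrossEntropyLoss(reduction="none")
    nll = loss_fct(shift_logits.transpose(1, 2), shift_labels)

    # Zero out padding, average over real tokens, exponentiate -> per-sequence PPL.
    return torch.exp((nll * shift_masks).sum(1) / shift_masks.sum(1))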