import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from accelerate import Accelerator
from transformers import default_data_collator
from collections import defaultdict
from tqdm import tqdm
import numpy as np
def is_not_number(s):
try:
float(s) # Try converting the string to a float
return False # If conversion is successful, it's a number
except ValueError:
return True # If conversion fails, it's not a number
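# Quick illustration of the helper above: is_not_number("3.14") returns False, is_not_number("banana") returns True.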
def get_contexts_ending_with_word(word, dataset):
result_contexts = []
word_len = len(word)
# Iterate over the dataset
for example in dataset:
text = example["text"]
# Find all occurrences of the word in the text
start = 0
while True:
idx = text.find(word, start)
if idx == -1:
break
# Ensure that the word is isolated (not a substring of another word)
if (idx == 0 or not text[idx - 1].isalnum()) and (
idx + word_len == len(text) or not text[idx + word_len].isalnum()):
# Text ends with the word
result_contexts.append(text[:idx + word_len].strip())
start = idx + word_len
return result_contexts
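# Example usage (illustrative sketch; the toy dataset below is an assumption — any iterable of
# {"text": ...} records works, since the function only reads example["text"]):
#     toy_dataset = [{"text": "The cat sat on the mat. A cat appeared."}]
#     get_contexts_ending_with_word("cat", toy_dataset)
#     # -> ["The cat", "The cat sat on the mat. A cat"]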
def get_texts_containing_word(words, dataset):
result_texts = []
words_set = set(words)
# Iterate over the dataset
for example in dataset:
if words_set.intersection(set(example["text"].split())):
result_texts.append(example["text"])
return result_texts
def compute_topk_token_rank(logits, labels, k=1000):
# Get the top-k predicted logits and their indices
topk_logits, topk_indices = torch.topk(logits, k, dim=-1)
# Expand the labels for comparison
labels_expanded = labels.unsqueeze(-1).expand_as(topk_indices)
# Check if the label token is within the top-k predictions
rank_in_topk = (topk_indices == labels_expanded).nonzero(as_tuple=False)
# Create a rank tensor initialized with k (max rank is k)
ranks = torch.full(labels.shape, k, dtype=torch.long, device=logits.device)
# For labels in top-k, set the rank accordingly
ranks[rank_in_topk[:, 0], rank_in_topk[:, 1]] = rank_in_topk[:, 2] + 1
return ranks
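# Example (illustrative; the shapes below are assumptions based on how the function indexes its inputs):
#     logits = torch.randn(2, 8, 50257)           # [batch, seq, vocab]
#     labels = torch.randint(0, 50257, (2, 8))    # [batch, seq]
#     ranks = compute_topk_token_rank(logits, labels, k=1000)
#     # ranks has shape [batch, seq]; values are 1..k, with k used for labels outside the top-k.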
def count_tokens_in_dataset(dataset, tokenizer, text_column='text'):
def tokenize_and_count(examples):
return {'num_tokens': [len(tokenizer(ex).input_ids) for ex in examples[text_column]]}
tokenized_dataset = dataset.map(tokenize_and_count, batched=True, remove_columns=dataset.column_names)
total_tokens = sum(tokenized_dataset['num_tokens'])
return total_tokens
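# Example (sketch; assumes a Hugging Face `datasets.Dataset` with a "text" column and a
# `transformers` tokenizer, which is what the .map call above expects; the variable names are placeholders):
#     total = count_tokens_in_dataset(raw_dataset, tokenizer, text_column="text")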
def filter_single_token_words(array, tokenizer, add_space_prefix_for_lower=True):
    # Keep only the entries of `array` that tokenize into more than one token; also return the token counts
    def _token_count(word):
        if add_space_prefix_for_lower and word[0].islower():
            word = " " + word
        return len(tokenizer.encode(word, add_special_tokens=False))
    token_counts = array.apply(_token_count)
    mask = token_counts > 1
    return array[mask], token_counts
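# Example (sketch; assumes `array` is a pandas Series of words, matching the .apply usage above,
# and that `tokenizer` is an already-loaded transformers tokenizer):
#     import pandas as pd
#     words = pd.Series(["hello", "tokenization", "a"])
#     multi_token_words, counts = filter_single_token_words(words, tokenizer)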
# Used in compute_metrics to locate the tokens around newly added words: given a 0/1 attention mask
# where 0 marks the positions of new-word tokens, this returns a mask that is 1 only at the last 0 of
# every consecutive run of 0s (runs that reach the end of the sequence are not marked).
def get_last_zero_in_every_seq_mask(tensor):
    # Find where consecutive runs of zeros end
    zero_mask = (tensor == 0)
    # A value of -1 in the diff marks a zero followed by a non-zero, i.e., the last zero of a run
    diff = torch.diff(zero_mask.int(), dim=1)
    last_zero_mask = torch.cat([diff, torch.ones(tensor.size(0), 1, dtype=diff.dtype).to(tensor.device)], dim=1) == -1
    # Start from the complement of the input and keep only the last zero of each run
    output = 1 - tensor
    output[zero_mask & ~last_zero_mask] = 0
    return output
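# Example (illustrative): for a 0/1 attention mask where 0 marks new-word positions,
#     get_last_zero_in_every_seq_mask(torch.tensor([[1, 0, 0, 1, 1]]))
#     # -> tensor([[0, 0, 1, 0, 0]])  (only the last 0 of the run is kept as 1)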
def get_first_zero_in_every_seq_mask(tensor):
    # Counterpart of get_last_zero_in_every_seq_mask: returns a mask that is 1 only at the first 0
    # of every consecutive run of 0s and 0 elsewhere.
    zero_mask = (tensor == 0)
    # A value of 1 in the diff marks a non-zero followed by a zero, i.e., the first zero of a run
    diff = torch.diff(zero_mask.int(), dim=1, prepend=torch.zeros(tensor.size(0), 1, dtype=torch.int).to(tensor.device))
    first_zero_mask = diff == 1  # Marks the beginning of each run of zeros
    # Start from the complement of the input and keep only the first zero of each run
    output = 1 - tensor
    output[zero_mask & ~first_zero_mask] = 0
    return output
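# Example (illustrative):
#     get_first_zero_in_every_seq_mask(torch.tensor([[1, 0, 0, 1, 1]]))
#     # -> tensor([[0, 1, 0, 0, 0]])  (only the first 0 of the run is kept as 1)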
def _add_start_token(batch, tokenizer):
bos_tokens_tensor = torch.tensor([[tokenizer.bos_token_id]] * batch["input_ids"].size(dim=0)).to(batch["input_ids"].device)
batch["input_ids"] = torch.cat([bos_tokens_tensor, batch["input_ids"]], dim=1)
batch["attention_mask"] = torch.cat(
[torch.ones(bos_tokens_tensor.size(), dtype=torch.int64).to(batch["attention_mask"].device), batch["attention_mask"]], dim=1)
return batch
def _ignore_new_words_in_attention_mask(shift_attention_mask_batch, shift_labels, new_token_ids=None, replaced_token_seqs_by_len=None):
    # Zero out positions of new vocabulary words in the attention mask, so they are excluded from the metrics
    ignore_mask = shift_attention_mask_batch  # default, in case no new tokens or replaced sequences are given
    if new_token_ids is not None:
        ignore_mask = torch.isin(shift_labels, new_token_ids)
        shift_attention_mask_batch = shift_attention_mask_batch * (~ignore_mask).long()
    # Ignore multi-token sequences that were replaced with a single token
    if replaced_token_seqs_by_len is not None:
        # Create a mask that will be updated where sequences match
        ignore_mask = shift_attention_mask_batch.clone()  # Clone the attention mask to modify it
        # Loop over the replaced token sequences, grouped by their length
        for seq_len, seqs in replaced_token_seqs_by_len.items():
            # Slide a window of size seq_len over the labels and check for matches at each position
            for i in range(shift_labels.size(1) - seq_len + 1):
# Check if the sequence matches at position i
window = shift_labels[:, i:i + seq_len]
curr_mask = torch.all(window.unsqueeze(1) == seqs.unsqueeze(0), dim=-1)
if curr_mask.any():
# Zero out the ignore mask for the length of the sequence
ignore_mask[curr_mask.any(dim=-1), i:i + seq_len] = 0
# Apply the ignore mask to the attention mask
shift_attention_mask_batch *= ignore_mask
return shift_attention_mask_batch, ignore_mask
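# Example (sketch; the toy token ids below are assumptions): positions whose label is a new token id
# are zeroed out in the returned attention mask:
#     mask = torch.tensor([[1, 1, 1, 1]])
#     labels = torch.tensor([[5, 7, 9, 11]])
#     _ignore_new_words_in_attention_mask(mask, labels, new_token_ids=torch.tensor([7]))
#     # -> the attention mask becomes tensor([[1, 0, 1, 1]])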
# TODO consider not aggregating results here, to enable metrics for specific words
def compute_metrics(
logits, labels, attention_mask,
compute_target_metrics=True, compute_subsequent_metrics=True, compute_perplexity=False,
return_successful_targets=False,
original_labels=None, original_logits=None,
debug=False):
target_results = dict() # will hold metrics for all the new words we add or their original tokenization
background_results = dict() # will hold metrics for all background tokens, i.e., not the ones we add or replace
overall_results = dict() # will hold metrics for all tokens
successful_targets = None # will hold list of target tokens successfully predicted
if compute_subsequent_metrics:
        # prepare labels and attention masks for computing metrics only for the first tokens following the new words
subsequent_labels = labels[:, 1:]
subsequent_attention_mask = get_last_zero_in_every_seq_mask(attention_mask[..., :-1].contiguous())
subsequent_attention_mask_bool = subsequent_attention_mask == 1
attention_mask_bool = attention_mask == 1
overall_mask_bool = attention_mask_bool
if compute_target_metrics:
target_mask = get_first_zero_in_every_seq_mask(attention_mask)
target_mask_bool = target_mask == 1
overall_mask_bool = attention_mask_bool | target_mask_bool
if compute_perplexity:
background_results["perplexity"] = torch.exp(
(F.cross_entropy(logits.transpose(1, 2), labels, reduction="none") * attention_mask).sum(1)
/ attention_mask.sum(1)
).mean().detach().cpu().numpy()
top1 = logits.argmax(dim=-1)
if original_logits is not None:
orig_top1 = original_logits.argmax(dim=-1)
if compute_target_metrics:
target_results["top1_acc"] = ((labels == top1)[target_mask_bool]).detach().cpu().numpy()
if original_labels is not None:
target_results["sum_top1_acc"] = (
((original_labels == top1) | (labels == top1))[target_mask_bool]).detach().cpu().numpy()
if original_logits is not None:
target_results["orig_top1_acc"] = (
(original_labels == orig_top1)[target_mask_bool]).detach().cpu().numpy()
if return_successful_targets:
successful_targets = (labels[(labels == top1) & target_mask_bool]).detach().cpu().numpy()
background_results["top1_acc"] = ((
labels == top1)[attention_mask_bool]).detach().cpu().numpy()
if compute_subsequent_metrics:
background_results["subsequent_top1_acc"] = ((subsequent_labels == top1[:, 1:])[subsequent_attention_mask_bool]).detach().cpu().numpy()
if original_logits is not None:
background_results["orig_top1_acc"] = (
(original_labels == orig_top1)[attention_mask_bool]).detach().cpu().numpy()
if compute_subsequent_metrics:
background_results["orig_subsequent_top1_acc"] = (
(subsequent_labels == orig_top1[:, 1:])[subsequent_attention_mask_bool]).detach().cpu().numpy()
overall_results["top1_acc"] = ((labels == top1))[overall_mask_bool].detach().cpu().numpy()
if original_labels is not None:
overall_results["sum_top1_acc"] = (
((original_labels == top1) | (labels == top1)))[overall_mask_bool].detach().cpu().numpy()
if original_logits is not None:
overall_results["orig_top1_acc"] = (
(original_labels == orig_top1)[overall_mask_bool]).detach().cpu().numpy()
if debug:
import pdb; pdb.set_trace()
return background_results, target_results, overall_results, successful_targets
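# Example (sketch; shapes are assumptions: logits is [batch, seq, vocab], labels and attention_mask are
# [batch, seq], with 0s in the attention mask marking the target (new-word) positions, as produced by
# eval_next_word_prediction below):
#     bg, tgt, overall, _ = compute_metrics(shift_logits, shift_labels, shift_attention_mask_batch,
#                                           compute_perplexity=True)
#     # e.g., bg["perplexity"], bg["top1_acc"], tgt["top1_acc"], overall["top1_acc"]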
def eval_next_word_prediction(
model, tokenizer, lm_dataset, accelerator=None,
batch_size: int = 4,
new_token_ids=None, replaced_token_seqs_by_len=None,
new_token_to_original_first_token=None,
max_length: int = 256,
drop_last: bool = True,
eval_max_samples: int = None,
eval_shuffle_samples: bool = False,
reduction="none",
):
if accelerator is None:
accelerator = Accelerator()
model.eval()
if tokenizer.bos_token is not None and max_length:
add_start_token = True
else:
add_start_token = False
data_collator = default_data_collator
    if eval_max_samples:
        # Evaluate on at most eval_max_samples examples (the first ones, or a random subset if requested)
        num_eval_samples = min(eval_max_samples, len(lm_dataset))
        eval_idx = range(num_eval_samples)
        if eval_shuffle_samples:
            eval_idx = np.random.choice(len(lm_dataset), num_eval_samples, replace=False)
        lm_dataset = lm_dataset.select(eval_idx)
# Create data loaders
eval_dataloader = DataLoader(
lm_dataset, collate_fn=data_collator, batch_size=batch_size, drop_last=drop_last, shuffle=False,
)
eval_dataloader = accelerator.prepare(eval_dataloader)
model.eval()
if new_token_ids is not None:
new_token_ids = torch.tensor(new_token_ids).to(model.device)
if replaced_token_seqs_by_len is not None:
replaced_token_seqs_by_len = {token_length: torch.tensor(skip_token_seqs).to(model.device) for token_length, skip_token_seqs in replaced_token_seqs_by_len.items() if len(skip_token_seqs) > 0}
if new_token_to_original_first_token is not None:
        # Convert the mapping into a lookup tensor for efficient indexing, defaulting to the identity mapping
new_token_to_orig_first_mapping_tensor = torch.arange(len(tokenizer), device=model.device)
new_token_to_orig_first_mapping_tensor[torch.tensor(list(new_token_to_original_first_token.keys()), device=model.device)] = \
torch.tensor(list(new_token_to_original_first_token.values()), device=model.device)
target_metrics = defaultdict(list)
background_metrics = defaultdict(list)
overall_metrics = defaultdict(list)
# run eval and compute metrics
for batch_i, batch in tqdm(enumerate(eval_dataloader), total=len(eval_dataloader), miniters=10, desc="Evaluating vocabulary..."):
if add_start_token:
batch = _add_start_token(batch, tokenizer)
labels = batch["input_ids"]
attn_mask = batch["attention_mask"]
        batch.pop("labels", None)  # drop labels if present; shifted labels are computed from input_ids below
with torch.no_grad():
outputs = model(**batch)
out_logits = outputs.logits
shift_logits = out_logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
shift_attention_mask_batch = attn_mask[..., 1:].contiguous()
shift_attention_mask_batch, ignore_mask = \
_ignore_new_words_in_attention_mask(
shift_attention_mask_batch, shift_labels, new_token_ids, replaced_token_seqs_by_len)
original_labels = None if new_token_to_original_first_token is None \
else new_token_to_orig_first_mapping_tensor[shift_labels]
original_logits = None if new_token_ids is None else torch.cat([shift_logits[:, :, :min(new_token_ids)], shift_logits[:, :, max(new_token_ids)+1:]], dim=-1)
background_results, target_results, overall_results, successful_targets = \
compute_metrics(
shift_logits, shift_labels, shift_attention_mask_batch,
original_labels=original_labels, original_logits=original_logits, compute_perplexity=True)
for metric_name, metric_value in target_results.items():
target_metrics[metric_name].append(np.array(metric_value))
for metric_name, metric_value in background_results.items():
background_metrics[metric_name].append(metric_value)
for metric_name, metric_value in overall_results.items():
overall_metrics[metric_name].append(metric_value)
eval_dataloader = accelerator.free_memory(eval_dataloader)
    def _concat_func(x):
        # Concatenate the per-batch metric arrays into a single flat array (scalar entries are stacked instead)
        if isinstance(x, np.ndarray) and len(x.shape) > 1:
            x = np.concatenate(x)
        elif isinstance(x, (list, tuple)) and len(x) > 1:
            if isinstance(x[0], np.ndarray) and len(x[0].shape) == 0:
                x = np.array(x)
            else:
                x = np.concatenate(x)
        return x
# apply reduction
reduce_func = _concat_func
if reduction == 'mean':
reduce_func = lambda x: np.mean(_concat_func(x)).item()
for metric_name, metric_value in target_metrics.items():
target_metrics[metric_name] = reduce_func(metric_value)
for metric_name, metric_value in background_metrics.items():
background_metrics[metric_name] = reduce_func(metric_value)
for metric_name, metric_value in overall_metrics.items():
overall_metrics[metric_name] = reduce_func(metric_value)
return background_metrics, target_metrics, overall_metrics
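# Example usage (sketch; `model`, `tokenizer` and `lm_dataset` are assumed to be a Hugging Face causal LM,
# its tokenizer, and a tokenized dataset with input_ids/attention_mask columns):
#     background_metrics, target_metrics, overall_metrics = eval_next_word_prediction(
#         model, tokenizer, lm_dataset,
#         batch_size=4, reduction="mean",
#     )
#     # with reduction="mean", each returned dict maps metric names to scalar averages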