File size: 6,009 Bytes
2f044c1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 |
import json
import os
from collections import defaultdict
from typing import Any, Dict, Iterable, List, Optional, Union
import numpy as np
import transformers as tr
from tqdm import tqdm
class HardNegativesManager:
    """Store hard-negative passages per sample and cache their tokenization.

    The database maps a sample index to the list of passages used as hard
    negatives for that sample. Every unique passage is tokenized at most once:
    either eagerly (in batches) at construction time, or lazily the first time
    it is requested through :meth:`get`.
    """

    def __init__(
        self,
        tokenizer: "tr.PreTrainedTokenizer",
        data: Optional[Union[str, os.PathLike, Dict[int, List]]] = None,
        max_length: int = 64,
        batch_size: int = 1000,
        lazy: bool = False,
    ) -> None:
        """
        Args:
            tokenizer: callable used to tokenize passages; it is invoked as
                ``tokenizer(text_or_batch, max_length=..., truncation=True)``.
            data: ``None`` for an empty database, a ``dict`` mapping sample
                index -> list of passages, or a path (``str`` or
                ``os.PathLike``) to a JSON file holding such a mapping.
            max_length: maximum length forwarded to the tokenizer (both the
                eager and the lazy path use it).
            batch_size: number of passages tokenized per tokenizer call in
                eager mode.
            lazy: if True, skip eager tokenization; passages are tokenized on
                demand in :meth:`get`.

        Raises:
            ValueError: if ``data`` is of an unsupported type.
        """
        self.tokenizer = tokenizer
        # kept on the instance so lazy tokenization in `_tokenize` uses the
        # same truncation length as the eager path
        self.max_length = max_length

        if data is None:
            self._db: dict = {}
        elif isinstance(data, dict):
            self._db = data
        elif isinstance(data, (str, os.PathLike)):
            # NOTE: JSON object keys are always strings, so sample indices
            # loaded from a file are `str`, not `int`.
            with open(data, encoding="utf-8") as f:
                self._db = json.load(f)
        else:
            raise ValueError(
                f"Data type {type(data)} not supported, only Dict, str and os.PathLike are supported."
            )

        # invert the db to have a passage -> {sample_idx} mapping
        self._passage_db = defaultdict(set)
        for sample_idx, passages in self._db.items():
            for passage in passages:
                self._passage_db[passage].add(sample_idx)

        # cache: passage -> tokenizer output for that single passage
        self._passage_hard_negatives = {}
        if not lazy:
            self._tokenize_all(batch_size)

    def _tokenize_all(self, batch_size: int) -> None:
        """Eagerly tokenize every unique passage, `batch_size` passages at a time."""
        unique_passages = list(self._passage_db.keys())
        if not unique_passages:
            # nothing to do; also avoids a zero step in `range` below
            return
        batch_size = max(1, min(batch_size, len(unique_passages)))
        for start in tqdm(
            range(0, len(unique_passages), batch_size),
            desc="Tokenizing Hard Negatives",
        ):
            batch = unique_passages[start : start + batch_size]
            tokenized = self.tokenizer(
                batch,
                max_length=self.max_length,
                truncation=True,
            )
            # split the batched encoding back into one entry per passage
            for offset, passage in enumerate(batch):
                self._passage_hard_negatives[passage] = {
                    k: tokenized[k][offset] for k in tokenized.keys()
                }

    def __len__(self) -> int:
        """Number of samples in the database."""
        return len(self._db)

    def __getitem__(self, idx: int) -> List:
        """Raw (untokenized) passages stored for sample `idx`."""
        return self._db[idx]

    def __iter__(self):
        """Iterate over the sample indices in the database."""
        yield from self._db

    def __contains__(self, idx: int) -> bool:
        return idx in self._db

    def get(self, idx: int) -> List[Dict]:
        """Get the tokenized hard negatives for a given sample index.

        Passages not tokenized yet (lazy mode) are tokenized and cached.

        Raises:
            ValueError: if `idx` is not in the database.
        """
        if idx not in self._db:
            raise ValueError(f"Sample index {idx} not in the database.")
        output = []
        for passage in self._db[idx]:
            if passage not in self._passage_hard_negatives:
                self._passage_hard_negatives[passage] = self._tokenize(passage)
            output.append(self._passage_hard_negatives[passage])
        return output

    def _tokenize(self, passage: str) -> Dict:
        """Tokenize a single passage with the configured truncation length."""
        return self.tokenizer(passage, max_length=self.max_length, truncation=True)
class NegativeSampler:
    """Draw random subsets of element indices, biased by per-element probabilities."""

    def __init__(
        self, num_elements: int, probabilities: Optional[Union[List, np.ndarray]] = None
    ):
        """
        Args:
            num_elements (`int`):
                number of elements; used to size the random distribution when
                `probabilities` is None.
            probabilities (`List` or `np.ndarray`, optional):
                per-element probabilities. When None, a random distribution over
                `num_elements` elements, normalized to sum to 1, is generated.
        """
        if probabilities is None:
            # must be checked *before* array conversion: np.array(None) is a
            # 0-d object array, which would never be `None`
            probabilities = np.random.random(num_elements)
            # probabilities should sum to 1
            probabilities /= np.sum(probabilities)
        else:
            probabilities = np.asarray(probabilities)
        self.probabilities = probabilities

    def __call__(
        self,
        sample_size: int,
        num_samples: int = 1,
        probabilities: Optional[np.ndarray] = None,
        exclude: Optional[List[int]] = None,
    ) -> np.ndarray:
        """
        Fast sampling of `sample_size` elements from `num_elements` elements.
        The sampling is done by randomly shifting the probabilities and then
        taking the `sample_size` smallest values of (shift - probability), i.e.
        the largest (probability - shift). This is much faster than sampling
        from a multinomial distribution.

        Args:
            sample_size (`int`):
                number of elements to sample; may be at most the number of
                elements.
            num_samples (`int`, optional):
                number of samples to draw. Defaults to 1.
            probabilities (`np.ndarray`, optional):
                probabilities of each element. Defaults to `self.probabilities`.
                Never modified.
            exclude (`List[int]`, optional):
                indices of elements whose probability is set to 0 for this call
                only. Defaults to None.

        Returns:
            `np.ndarray`: array of shape `(num_samples, sample_size)` with the
            sampled indices (order within a row is unspecified).
        """
        if probabilities is None:
            probabilities = self.probabilities
        probabilities = np.asarray(probabilities, dtype=float)
        if exclude is not None:
            # work on a copy: zeroing in place would leak into
            # `self.probabilities` / the caller's array and accumulate across calls
            probabilities = probabilities.copy()
            probabilities[exclude] = 0
            # NOTE(review): remaining probabilities are not re-normalized after
            # exclusion — confirm this is intended.
        # replicate probabilities as many times as `num_samples`
        replicated = np.tile(probabilities, (num_samples, 1))
        # get random shifting numbers & scale each row to sum to 1
        random_shifts = np.random.random(replicated.shape)
        random_shifts /= random_shifts.sum(axis=1)[:, np.newaxis]
        # shift by numbers & find largest (by finding the smallest of the negative)
        shifted = random_shifts - replicated
        # partition around `sample_size - 1` so the first `sample_size` columns
        # are the smallest; unlike kth=sample_size this is valid when
        # `sample_size` equals the number of elements
        return np.argpartition(shifted, sample_size - 1, axis=1)[:, :sample_size]
def batch_generator(samples: Iterable[Any], batch_size: int) -> Iterable[Any]:
    """
    Lazily group `samples` into consecutive lists of `batch_size` items.

    The final list may be shorter when the number of samples is not a multiple
    of `batch_size`; no empty list is ever produced.

    Args:
        samples (`Iterable[Any]`): Iterable of samples.
        batch_size (`int`): Batch size.

    Returns:
        `Iterable[Any]`: Iterable of batches.
    """
    pending: List[Any] = []
    for item in samples:
        pending.append(item)
        if len(pending) != batch_size:
            continue
        yield pending
        pending = []
    # flush whatever did not fill a complete batch
    if pending:
        yield pending
|