import os
import hashlib
from collections.abc import Iterator, Sequence
from multiprocessing import Pool
import nltk
import torch
from torch import nn
from transformers import PreTrainedModel
from .config import DSIRConfig
def _hash_buckets(text: str, num_buckets: int) -> int:
return int(hashlib.sha256(text.encode("utf-8")).hexdigest(), 16) % num_buckets
def get_ngram_count(tokens: Sequence[str], n: int, num_buckets: int) -> torch.Tensor:
    """Count hashed n-gram occurrences of *tokens* into a bucket histogram.

    Every n-gram of order 1..n is rendered as a space-joined string, hashed
    into one of *num_buckets* buckets, and counted.

    Args:
        tokens: a pre-tokenized document (sequence of token strings).
        n: maximum n-gram order (inclusive).
        num_buckets: size of the returned histogram.

    Returns:
        A float32 tensor of shape ``(num_buckets,)`` holding raw counts.
    """
    counts = torch.zeros(num_buckets, dtype=torch.float32)
    # One unified loop over all orders 1..n; the original special-cased
    # unigrams (" ".join of a 1-tuple is the token itself, so behavior
    # is identical).
    for order in range(1, n + 1):
        # Consecutive n-grams via shifted-slice zip — equivalent to
        # nltk.ngrams(tokens, order) without the third-party call or the
        # needless list() materialization the original built per order.
        for gram in zip(*(tokens[i:] for i in range(order))):
            counts[_hash_buckets(" ".join(gram), num_buckets)] += 1
    return counts
# NOTE: unused KL-divergence helper kept for reference; the model scores
# documents with a log-ratio dot product rather than an explicit KL term.
# def kl_divergence(p: torch.Tensor, q: torch.Tensor) -> torch.Tensor:
#     # To avoid division by zero
#     p = p + 1e-8
#     q = q + 1e-8
#     return (p * (p / q).log()).sum()
class DSIRModel(PreTrainedModel):
    """Data Selection with Importance Resampling (DSIR) scoring model.

    Holds a bucketed hashed-n-gram distribution ``prob_dist`` (presumably the
    target-domain feature distribution loaded from a checkpoint — TODO confirm)
    and, once :meth:`fit_raw_dataset` has run, the element-wise log-ratio
    between the raw dataset's distribution and ``prob_dist``. A document's
    importance score is the dot product of its normalized n-gram histogram
    with that log-ratio vector.
    """

    config_class = DSIRConfig

    def __init__(self, config: DSIRConfig):
        super().__init__(config)
        # Bucket distribution stored as a Parameter so it round-trips through
        # save_pretrained / from_pretrained (float32, one entry per bucket).
        self.prob_dist = nn.Parameter(torch.zeros(config.num_buckets, dtype=torch.float32))
        # Populated lazily; all three stay None until fit_raw_dataset() runs.
        self.proportions = None
        self.raw_dist = None
        self.log_diff_dist = None

    def fit_raw_dataset(self, dataset: Iterator[Sequence[str]], num_proc: None | int = None):
        """Estimate the raw dataset's bucket distribution and cache the log-ratio.

        Args:
            dataset: iterable of pre-tokenized documents (token sequences).
            num_proc: worker processes for n-gram counting; defaults to the
                machine's CPU count (falling back to 1).
        """
        num_proc = num_proc or os.cpu_count() or 1
        with Pool(num_proc) as pool:
            ngram_counts = pool.starmap(
                get_ngram_count,
                [(tokens, self.config.n, self.config.num_buckets) for tokens in dataset],
            )
        raw_dist = torch.stack(ngram_counts).sum(dim=0)
        self.raw_dist = raw_dist / raw_dist.sum()
        # 1e-8 guards log(0) for buckets that are empty on either side.
        self.log_diff_dist = torch.log(self.raw_dist + 1e-8) - torch.log(self.prob_dist + 1e-8)

    def compute_single_prob_dist(self, tokens: Sequence[str]) -> torch.Tensor:
        """Return the raw (unnormalized) hashed n-gram counts for one document."""
        return get_ngram_count(tokens, self.config.n, self.config.num_buckets)

    def _normalize_prob_dist(self, tokens: Sequence[str]) -> torch.Tensor:
        """Return the document's n-gram histogram normalized to sum to 1.

        Fix: an empty token sequence used to yield 0/0 = NaN; clamping the
        denominator keeps the result finite (an all-zero vector) while leaving
        non-empty documents unchanged (their integer counts always sum >= 1).
        """
        ngram_count = self.compute_single_prob_dist(tokens)
        return ngram_count / ngram_count.sum().clamp_min(1e-8)

    def _require_fit(self) -> None:
        """Fail fast if scoring is attempted before fit_raw_dataset().

        Fix: previously this surfaced as an opaque ``TypeError`` from a
        matmul against ``None``.
        """
        if self.log_diff_dist is None:
            raise RuntimeError("fit_raw_dataset() must be called before computing importance scores")

    def compute_importance_score(self, tokens: Sequence[str]) -> torch.Tensor:
        """Importance score for one document: histogram @ log-ratio vector."""
        self._require_fit()
        return self._normalize_prob_dist(tokens) @ self.log_diff_dist

    def forward(self, tokens: Sequence[Sequence[str]]) -> dict[str, torch.Tensor]:
        """Score a batch of documents.

        Returns:
            ``{"weight": (batch,), "prob_dists": (batch, num_buckets)}``.
        """
        self._require_fit()
        prob_dists = torch.stack([self._normalize_prob_dist(t) for t in tokens])
        weight = prob_dists @ self.log_diff_dist
        return {"weight": weight, "prob_dists": prob_dists}