import os
import hashlib
from collections.abc import Iterator, Sequence
from multiprocessing import Pool

import nltk
import torch
from torch import nn
from transformers import PreTrainedModel

from .config import DSIRConfig


def _hash_buckets(text: str, num_buckets: int) -> int:
    """Hash a string feature into one of ``num_buckets`` buckets via SHA-256."""
    return int(hashlib.sha256(text.encode("utf-8")).hexdigest(), 16) % num_buckets


def get_ngram_count(tokens: Sequence[str], n: int, num_buckets: int) -> torch.Tensor:
    """Count the hashed 1- to n-grams of a tokenized text into a bucketed vector."""
    counts = torch.zeros(num_buckets, dtype=torch.float32)

    # Unigrams.
    for w in tokens:
        counts[_hash_buckets(w, num_buckets)] += 1

    # Higher-order n-grams (2..n), joined with spaces before hashing.
    for i in range(2, n + 1):
        for ngram in nltk.ngrams(tokens, i):
            counts[_hash_buckets(" ".join(ngram), num_buckets)] += 1

    return counts
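
# Illustrative example (comment only, not executed): with n=2 and 8 buckets,
# get_ngram_count(["the", "quick", "fox"], n=2, num_buckets=8) hashes the three
# unigrams plus the two bigrams "the quick" and "quick fox", so the returned
# 8-dimensional tensor sums to 5; which buckets receive the counts depends on
# the SHA-256 hash, and distinct n-grams may collide in the same bucket.

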
class DSIRModel(PreTrainedModel):
    """DSIR-style importance scoring over hashed n-gram bucket distributions."""

    config_class = DSIRConfig

    def __init__(self, config: DSIRConfig):
        super().__init__(config)

        # Bucketed n-gram distribution stored with the checkpoint; `fit_raw_dataset`
        # compares the raw data distribution against it.
        self.prob_dist = nn.Parameter(torch.zeros(config.num_buckets, dtype=torch.float32))
        self.proportions = None
        self.raw_dist = None
        self.log_diff_dist = None

    def fit_raw_dataset(self, dataset: Iterator[Sequence[str]], num_proc: int | None = None) -> None:
        """Estimate the hashed n-gram distribution of the raw dataset in parallel."""
        num_proc = num_proc or os.cpu_count() or 1
        with Pool(num_proc) as pool:
            ngram_counts = pool.starmap(
                get_ngram_count,
                [(tokens, self.config.n, self.config.num_buckets) for tokens in dataset],
            )
        raw_dist = torch.stack(ngram_counts).sum(dim=0)

        self.raw_dist = raw_dist / raw_dist.sum()
        # Smoothed log-ratio between the raw distribution and `prob_dist`.
        self.log_diff_dist = torch.log(self.raw_dist + 1e-8) - torch.log(self.prob_dist + 1e-8)

    def compute_single_prob_dist(self, tokens: Sequence[str]) -> torch.Tensor:
        """Unnormalized hashed n-gram counts for a single tokenized example."""
        return get_ngram_count(tokens, self.config.n, self.config.num_buckets)

    def _normalize_prob_dist(self, tokens: Sequence[str]) -> torch.Tensor:
        ngram_count = self.compute_single_prob_dist(tokens)
        return ngram_count / ngram_count.sum()

    def compute_importance_score(self, tokens: Sequence[str]) -> torch.Tensor:
        """Importance score for one example; requires `fit_raw_dataset` to have been called."""
        prob_dist = self._normalize_prob_dist(tokens)
        return prob_dist @ self.log_diff_dist

    def forward(self, tokens: Sequence[Sequence[str]]) -> dict[str, torch.Tensor]:
        prob_dists = torch.stack([self._normalize_prob_dist(t) for t in tokens])
        weight = prob_dists @ self.log_diff_dist

        return {"weight": weight, "prob_dists": prob_dists}