import hashlib
import os
from collections.abc import Iterator, Sequence
from multiprocessing import Pool

import nltk
import torch
from torch import nn
from transformers import PreTrainedModel

from .config import DSIRConfig


def _hash_buckets(text: str, num_buckets: int) -> int:
    """Map a string to a bucket index via SHA-256 (hashed feature trick)."""
    return int(hashlib.sha256(text.encode("utf-8")).hexdigest(), 16) % num_buckets


def get_ngram_count(tokens: Sequence[str], n: int, num_buckets: int) -> torch.Tensor:
    """Count hashed unigram-through-n-gram features of one document into `num_buckets` bins."""
    counts = torch.zeros(num_buckets, dtype=torch.float32)
    for w in tokens:
        counts[_hash_buckets(w, num_buckets)] += 1
    for i in range(2, n + 1):
        for gram in nltk.ngrams(tokens, i):
            counts[_hash_buckets(" ".join(gram), num_buckets)] += 1
    return counts


# def kl_divergence(p: torch.Tensor, q: torch.Tensor) -> torch.Tensor:
#     # To avoid division by zero
#     p = p + 1e-8
#     q = q + 1e-8
#     return (p * (p / q).log()).sum()


class DSIRModel(PreTrainedModel):
    """Data Selection via Importance Resampling (DSIR) over hashed n-gram features.

    `prob_dist` is expected to hold the target feature distribution (fit
    elsewhere in the package); it is an `nn.Parameter` so that it is
    serialized with the checkpoint. The raw-dataset distribution is fit at
    scoring time via `fit_raw_dataset`.
    """

    config_class = DSIRConfig

    def __init__(self, config: DSIRConfig):
        super().__init__(config)
        self.prob_dist = nn.Parameter(torch.zeros(config.num_buckets, dtype=torch.float32))
        self.proportions = None
        self.raw_dist = None
        self.log_diff_dist = None

    def fit_raw_dataset(self, dataset: Iterator[Sequence[str]], num_proc: int | None = None):
        """Estimate the raw-dataset feature distribution and cache the per-bucket log-ratio."""
        num_proc = num_proc or os.cpu_count() or 1
        with Pool(num_proc) as pool:
            ngram_counts = pool.starmap(
                get_ngram_count,
                [(tokens, self.config.n, self.config.num_buckets) for tokens in dataset],
            )
        raw_dist = torch.stack(ngram_counts).sum(dim=0)
        self.raw_dist = raw_dist / raw_dist.sum()
        # Per-bucket log importance weights, smoothed to avoid log(0). DSIR
        # defines the log importance weight as log p_target - log q_raw, with
        # `prob_dist` taken to be the target distribution here.
        self.log_diff_dist = torch.log(self.prob_dist + 1e-8) - torch.log(self.raw_dist + 1e-8)

    def compute_single_prob_dist(self, tokens: Sequence[str]) -> torch.Tensor:
        """Unnormalized hashed n-gram counts for a single document."""
        return get_ngram_count(tokens, self.config.n, self.config.num_buckets)

    def _normalize_prob_dist(self, tokens: Sequence[str]) -> torch.Tensor:
        ngram_count = self.compute_single_prob_dist(tokens)
        # clamp_min guards against division by zero for empty documents.
        return ngram_count / ngram_count.sum().clamp_min(1e-8)

    def compute_importance_score(self, tokens: Sequence[str]) -> torch.Tensor:
        """Log importance weight of one document under the fitted distributions."""
        if self.log_diff_dist is None:
            raise RuntimeError("Call fit_raw_dataset() before computing importance scores.")
        prob_dist = self._normalize_prob_dist(tokens)
        return prob_dist @ self.log_diff_dist

    def forward(self, tokens: Sequence[Sequence[str]]) -> dict[str, torch.Tensor]:
        if self.log_diff_dist is None:
            raise RuntimeError("Call fit_raw_dataset() before computing importance scores.")
        prob_dists = torch.stack([self._normalize_prob_dist(t) for t in tokens])
        weight = prob_dists @ self.log_diff_dist
        return {"weight": weight, "prob_dists": prob_dists}
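

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the library API). It shows the
# intended end-to-end flow: write a target distribution into `prob_dist`, fit
# the raw distribution, then score documents. The `DSIRConfig(n=...,
# num_buckets=...)` constructor call is an assumption based on how the config
# attributes are read above, and the toy corpora are made up. Run with
# `python -m <package>.<this_module>` so the relative import resolves.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    config = DSIRConfig(n=2, num_buckets=8192)  # assumed constructor signature
    model = DSIRModel(config)

    # Hypothetical pre-tokenized corpora.
    target_docs = [["deep", "learning", "models"], ["training", "deep", "networks"]]
    raw_docs = [["the", "cat", "sat"], ["training", "deep", "models"], ["stocks", "fell"]]

    # This module only fits the raw side; for demonstration, the target
    # distribution is written into `prob_dist` by hand.
    target_counts = torch.stack(
        [get_ngram_count(doc, config.n, config.num_buckets) for doc in target_docs]
    ).sum(dim=0)
    with torch.no_grad():
        model.prob_dist.copy_(target_counts / target_counts.sum())

    model.fit_raw_dataset(raw_docs, num_proc=1)

    # Higher scores mean a document's hashed n-gram profile looks more like
    # the target corpus than the raw corpus.
    print(model(raw_docs)["weight"])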