# dsir-wiki-ja-5k / model.py
import hashlib
import os
from collections.abc import Iterator, Sequence
from multiprocessing import Pool

import nltk
import torch
from torch import nn
from transformers import PreTrainedModel

from .config import DSIRConfig


def _hash_buckets(text: str, num_buckets: int) -> int:
    """Map a string to a bucket index in [0, num_buckets) via SHA-256."""
    return int(hashlib.sha256(text.encode("utf-8")).hexdigest(), 16) % num_buckets
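
# Unlike Python's built-in hash(), which is salted per process, SHA-256 is deterministic:
# for example, _hash_buckets("the cat", 10_000) yields the same bucket index in every run.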


def get_ngram_count(tokens: Sequence[str], n: int, num_buckets: int) -> torch.Tensor:
    """Count the hashed n-grams (orders 1..n) of `tokens` into a fixed-size bucket vector."""
    counts = torch.zeros(num_buckets, dtype=torch.float32)
    # Unigrams: hash each token into a bucket.
    for w in tokens:
        counts[_hash_buckets(w, num_buckets)] += 1
    # Higher-order n-grams share the same hashed bucket space as the unigrams.
    for i in range(2, n + 1):
        for ngram in nltk.ngrams(tokens, i):
            counts[_hash_buckets(" ".join(ngram), num_buckets)] += 1
    return counts
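
# Illustrative example: get_ngram_count(["a", "b", "a"], n=2, num_buckets=4) returns a
# length-4 float tensor whose entries sum to 5 (three unigrams plus two bigrams).
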
# def kl_divergence(p: torch.Tensor, q: torch.Tensor) -> torch.Tensor:
# # To avoid division by zero
# p = p + 1e-8
# q = q + 1e-8
# return (p * (p / q).log()).sum()


class DSIRModel(PreTrainedModel):
    """DSIR-style importance scoring over hashed n-gram features."""

    config_class = DSIRConfig

    def __init__(self, config: DSIRConfig):
        super().__init__(config)
        # Target hashed n-gram distribution; its values come from the checkpoint weights.
        self.prob_dist = nn.Parameter(torch.zeros(config.num_buckets, dtype=torch.float32))
        self.proportions = None
        # Raw-pool distribution and its log-difference from prob_dist; set by fit_raw_dataset().
        self.raw_dist = None
        self.log_diff_dist = None

    def fit_raw_dataset(self, dataset: Iterator[Sequence[str]], num_proc: int | None = None):
        """Estimate the raw pool's hashed n-gram distribution and cache its log-difference from prob_dist."""
        num_proc = num_proc or os.cpu_count() or 1
        with Pool(num_proc) as pool:
            ngram_counts = pool.starmap(
                get_ngram_count,
                [(tokens, self.config.n, self.config.num_buckets) for tokens in dataset],
            )
        raw_dist = torch.stack(ngram_counts).sum(dim=0)
        self.raw_dist = raw_dist / raw_dist.sum()
        # Smooth both distributions before taking logs to avoid log(0).
        self.log_diff_dist = torch.log(self.raw_dist + 1e-8) - torch.log(self.prob_dist + 1e-8)
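
    # Illustrative numbers only: with num_buckets = 2, a target prob_dist of [0.9, 0.1] and
    # pooled raw counts [1., 3.], raw_dist becomes [0.25, 0.75] and log_diff_dist is roughly
    # [-1.28, 2.01], i.e. bucket 1 is over-represented in the raw pool relative to the target.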

    def compute_single_prob_dist(self, tokens: Sequence[str]) -> torch.Tensor:
        """Return the (unnormalized) hashed n-gram count vector for a single example."""
        return get_ngram_count(tokens, self.config.n, self.config.num_buckets)

    def _normalize_prob_dist(self, tokens: Sequence[str]) -> torch.Tensor:
        """Normalize an example's bucket counts into a probability distribution."""
        ngram_count = self.compute_single_prob_dist(tokens)
        return ngram_count / ngram_count.sum()

    def compute_importance_score(self, tokens: Sequence[str]) -> torch.Tensor:
        """Score one example: its bucket distribution dotted with the cached log-difference."""
        prob_dist = self._normalize_prob_dist(tokens)
        return prob_dist @ self.log_diff_dist

    def forward(self, tokens: Sequence[Sequence[str]]) -> dict[str, torch.Tensor]:
        """Batched scoring: per-example importance weights plus the stacked distributions."""
        prob_dists = torch.stack([self._normalize_prob_dist(t) for t in tokens])
        weight = prob_dists @ self.log_diff_dist
        return {"weight": weight, "prob_dists": prob_dists}