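"""Hashed n-gram importance weighting in the style of DSIR (Data Selection
via Importance Resampling; Xie et al., 2023)."""
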
import os
import hashlib
from collections.abc import Iterable, Sequence
from multiprocessing import Pool

import nltk
import torch
from torch import nn
from transformers import PreTrainedModel

from .config import DSIRConfig


def _hash_buckets(text: str, num_buckets: int) -> int:
    """Map a string to a bucket index via its SHA-256 digest."""
    return int(hashlib.sha256(text.encode("utf-8")).hexdigest(), 16) % num_buckets


def get_ngram_count(tokens: Sequence[str], n: int, num_buckets: int) -> torch.Tensor:
    """Count the hashed n-grams of `tokens` (orders 1 through `n`) into buckets."""
    counts = torch.zeros(num_buckets, dtype=torch.float32)

    # Unigrams are hashed directly.
    for w in tokens:
        counts[_hash_buckets(w, num_buckets)] += 1

    # Higher-order n-grams are hashed on their space-joined form.
    for i in range(2, n + 1):
        for ngram in nltk.ngrams(tokens, i):
            counts[_hash_buckets(" ".join(ngram), num_buckets)] += 1

    return counts
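
# For example, with n=2 and tokens ["a", "b", "c"], get_ngram_count adds one
# count each to the buckets for "a", "b", "c", "a b", and "b c"; distinct
# n-grams may share a bucket when their hashes collide.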


# def kl_divergence(p: torch.Tensor, q: torch.Tensor) -> torch.Tensor:
#     # To avoid division by zero
#     p = p + 1e-8
#     q = q + 1e-8
#     return (p * (p / q).log()).sum()


class DSIRModel(PreTrainedModel):
    """Scores examples with hashed n-gram importance weights: each example's
    normalized bucket counts are dotted with the per-bucket log-ratio
    `log(raw_dist) - log(prob_dist)` computed by `fit_raw_dataset`."""

    config_class = DSIRConfig

    def __init__(self, config: DSIRConfig):
        super().__init__(config)

        # The reference bucket distribution; stored as a parameter so it is
        # saved and loaded with the pretrained weights.
        self.prob_dist = nn.Parameter(torch.zeros(config.num_buckets, dtype=torch.float32))
        self.proportions = None
        # Populated by `fit_raw_dataset`.
        self.raw_dist = None
        self.log_diff_dist = None

    def fit_raw_dataset(self, dataset: Iterable[Sequence[str]], num_proc: None | int = None):
        """Fit the bucket distribution of the raw dataset in parallel and
        cache the per-bucket log-ratio against the stored `prob_dist`."""
        num_proc = num_proc or os.cpu_count() or 1
        with Pool(num_proc) as pool:
            ngram_counts = pool.starmap(
                get_ngram_count,
                [(tokens, self.config.n, self.config.num_buckets) for tokens in dataset],
            )
        raw_dist = torch.stack(ngram_counts).sum(dim=0)

        self.raw_dist = raw_dist / raw_dist.sum()
        # The epsilon guards against log(0) for empty buckets.
        self.log_diff_dist = torch.log(self.raw_dist + 1e-8) - torch.log(self.prob_dist + 1e-8)

    def compute_single_prob_dist(self, tokens: Sequence[str]) -> torch.Tensor:
        """Return the (unnormalized) hashed n-gram counts of one example."""
        return get_ngram_count(tokens, self.config.n, self.config.num_buckets)

    def _normalize_prob_dist(self, tokens: Sequence[str]) -> torch.Tensor:
        """Return the normalized hashed n-gram distribution of one example."""
        ngram_count = self.compute_single_prob_dist(tokens)
        return ngram_count / ngram_count.sum()

    def compute_importance_score(self, tokens: Sequence[str]) -> torch.Tensor:
        """Score one example against the cached log-ratio distribution."""
        prob_dist = self._normalize_prob_dist(tokens)
        return prob_dist @ self.log_diff_dist

    def forward(self, tokens: Sequence[Sequence[str]]) -> dict[str, torch.Tensor]:
        """Batched scoring: returns one importance weight per tokenized example."""
        prob_dists = torch.stack([self._normalize_prob_dist(t) for t in tokens])
        weight = prob_dists @ self.log_diff_dist

        return {"weight": weight, "prob_dists": prob_dists}