minato-ryan committed
Commit 25d0bb0 · verified · 1 Parent(s): 6aa2fb1

Upload model

Files changed (3):
  1. config.json +4 -0
  2. config.py +19 -0
  3. model.py +79 -0
config.json CHANGED
@@ -2,6 +2,10 @@
   "architectures": [
     "DSIRModel"
   ],
+  "auto_map": {
+    "AutoConfig": "config.DSIRConfig",
+    "AutoModel": "model.DSIRModel"
+  },
  "laplace_smoothing": 0.0,
  "model_type": "dsir",
  "n": 2,
config.py ADDED
@@ -0,0 +1,19 @@
+from transformers import PretrainedConfig
+
+
+class DSIRConfig(PretrainedConfig):
+    model_type = "dsir"
+    is_composition = False
+
+    def __init__(
+        self,
+        n: int = 2,
+        num_buckets: int = 10_000,
+        laplace_smoothing: float = 1e-4,
+        # **kwargs forwards standard keys (architectures, auto_map, ...) to the base class.
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.n = n
+        self.num_buckets = num_buckets
+        self.laplace_smoothing = laplace_smoothing
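
For reference, a quick round-trip sketch of how this config serializes (the directory name is illustrative):

from config import DSIRConfig

config = DSIRConfig(n=2, num_buckets=10_000, laplace_smoothing=0.0)
config.save_pretrained("dsir-checkpoint")  # writes config.json with model_type "dsir"
reloaded = DSIRConfig.from_pretrained("dsir-checkpoint")
assert reloaded.n == 2 and reloaded.num_buckets == 10_000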
model.py ADDED
@@ -0,0 +1,79 @@
+import os
+import hashlib
+from collections.abc import Iterator, Sequence
+from multiprocessing import Pool
+
+import nltk
+import torch
+from torch import nn
+from transformers import PreTrainedModel
+
+from .config import DSIRConfig
+
+
+def _hash_buckets(text: str, num_buckets: int) -> int:
+    return int(hashlib.sha256(text.encode("utf-8")).hexdigest(), 16) % num_buckets
+
+
+def get_ngram_count(tokens: Sequence[str], n: int, num_buckets: int) -> torch.Tensor:
+    counts = torch.zeros(num_buckets, dtype=torch.float32)
+    # Hash unigrams, then higher-order n-grams, into the same bucket space.
+    for w in tokens:
+        counts[_hash_buckets(w, num_buckets)] += 1
+
+    for i in range(2, n + 1):
+        for ngram in nltk.ngrams(tokens, i):
+            key = " ".join(ngram)
+            counts[_hash_buckets(key, num_buckets)] += 1
+
+    return counts
+
+
+# def kl_divergence(p: torch.Tensor, q: torch.Tensor) -> torch.Tensor:
+#     # To avoid division by zero
+#     p = p + 1e-8
+#     q = q + 1e-8
+#     return (p * (p / q).log()).sum()
+
+
+class DSIRModel(PreTrainedModel):
+    config_class = DSIRConfig
+
+    def __init__(self, config: DSIRConfig):
+        super().__init__(config)
+        # Hashed n-gram reference distribution, stored in (and loaded from) the checkpoint.
+        self.prob_dist = nn.Parameter(torch.zeros(config.num_buckets, dtype=torch.float32))
+        self.proportions = None
+        self.raw_dist = None
+        self.log_diff_dist = None  # set by fit_raw_dataset; required before scoring
+
+    def fit_raw_dataset(self, dataset: Iterator[Sequence[str]], num_proc: None | int = None):
+        num_proc = num_proc or os.cpu_count() or 1
+        with Pool(num_proc) as pool:
+            ngram_counts = pool.starmap(
+                get_ngram_count,
+                [(tokens, self.config.n, self.config.num_buckets) for tokens in dataset],
+            )
+        raw_dist = torch.stack(ngram_counts).sum(dim=0)
+        # Normalize, then cache the elementwise log-ratio; 1e-8 avoids log(0).
+        self.raw_dist = raw_dist / raw_dist.sum()
+        self.log_diff_dist = torch.log(self.raw_dist + 1e-8) - torch.log(self.prob_dist + 1e-8)
+
+    def compute_single_prob_dist(self, tokens: Sequence[str]) -> torch.Tensor:
+        # Raw hashed n-gram bucket counts for one example (not yet normalized).
+        return get_ngram_count(tokens, self.config.n, self.config.num_buckets)
+
+    def _normalize_prob_dist(self, tokens: Sequence[str]) -> torch.Tensor:
+        ngram_count = self.compute_single_prob_dist(tokens)
+        return ngram_count / ngram_count.sum()
+
+    def compute_importance_score(self, tokens: Sequence[str]) -> torch.Tensor:
+        prob_dists = self._normalize_prob_dist(tokens)
+        return prob_dists @ self.log_diff_dist
+
+    def forward(self, tokens: Sequence[Sequence[str]]) -> dict[str, torch.Tensor]:
+        prob_dists = [self._normalize_prob_dist(t) for t in tokens]
+        prob_dists = torch.stack(prob_dists)
+        weight = prob_dists @ self.log_diff_dist
+
+        return {"weight": weight, "prob_dists": prob_dists}
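
Putting it together: prob_dist holds the hashed n-gram distribution shipped with the checkpoint, fit_raw_dataset estimates the same kind of distribution over a raw corpus, and each example is scored by dotting its normalized bucket counts with the cached log-ratio of the two distributions (DSIR-style hashed n-gram importance weights; which corpus plays which role follows from what each distribution was fitted on). A minimal end-to-end sketch, assuming pre-tokenized inputs and a placeholder repo id:

import nltk
from transformers import AutoModel

# Placeholder repo id; nltk.download("punkt") may be needed before word_tokenize.
model = AutoModel.from_pretrained("user/dsir-checkpoint", trust_remote_code=True)

raw_docs = ["first raw document", "second raw document"]
model.fit_raw_dataset([nltk.word_tokenize(d) for d in raw_docs])  # caches log_diff_dist

candidates = [nltk.word_tokenize(d) for d in ["one candidate", "another candidate"]]
outputs = model(candidates)
print(outputs["weight"])  # one importance score per candidate

Since fit_raw_dataset spins up a multiprocessing Pool, call it from under an if __name__ == "__main__": guard on platforms that use the spawn start method.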