Pijush2023 commited on
Commit
ff05a77
·
verified ·
1 Parent(s): 042fb34

Delete CHATTS/utils/infer_utils.py

Browse files
Files changed (1) hide show
  1. CHATTS/utils/infer_utils.py +0 -45
CHATTS/utils/infer_utils.py DELETED
@@ -1,45 +0,0 @@
1
-
2
- import torch
3
- import torch.nn.functional as F
4
-
5
-
6
- class CustomRepetitionPenaltyLogitsProcessorRepeat():
7
-
8
- def __init__(self, penalty: float, max_input_ids, past_window):
9
- if not isinstance(penalty, float) or not (penalty > 0):
10
- raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}")
11
-
12
- self.penalty = penalty
13
- self.max_input_ids = max_input_ids
14
- self.past_window = past_window
15
-
16
- def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
17
-
18
- input_ids = input_ids[:, -self.past_window:]
19
- freq = F.one_hot(input_ids, scores.size(1)).sum(1)
20
- freq[self.max_input_ids:] = 0
21
- alpha = self.penalty**freq
22
- scores = torch.where(scores < 0, scores*alpha, scores/alpha)
23
-
24
- return scores
25
-
26
- class CustomRepetitionPenaltyLogitsProcessor():
27
-
28
- def __init__(self, penalty: float, max_input_ids, past_window):
29
- if not isinstance(penalty, float) or not (penalty > 0):
30
- raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}")
31
-
32
- self.penalty = penalty
33
- self.max_input_ids = max_input_ids
34
- self.past_window = past_window
35
-
36
- def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
37
-
38
- input_ids = input_ids[:, -self.past_window:]
39
- score = torch.gather(scores, 1, input_ids)
40
- _score = score.detach().clone()
41
- score = torch.where(score < 0, score * self.penalty, score / self.penalty)
42
- score[input_ids>=self.max_input_ids] = _score[input_ids>=self.max_input_ids]
43
- scores.scatter_(1, input_ids, score)
44
-
45
- return scores