Spaces:
Runtime error
Runtime error
Delete CHATTS/utils/infer_utils.py
Browse files- CHATTS/utils/infer_utils.py +0 -45
CHATTS/utils/infer_utils.py
DELETED
@@ -1,45 +0,0 @@
|
|
1 |
-
|
2 |
-
import torch
|
3 |
-
import torch.nn.functional as F
|
4 |
-
|
5 |
-
|
6 |
-
class CustomRepetitionPenaltyLogitsProcessorRepeat():
|
7 |
-
|
8 |
-
def __init__(self, penalty: float, max_input_ids, past_window):
|
9 |
-
if not isinstance(penalty, float) or not (penalty > 0):
|
10 |
-
raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}")
|
11 |
-
|
12 |
-
self.penalty = penalty
|
13 |
-
self.max_input_ids = max_input_ids
|
14 |
-
self.past_window = past_window
|
15 |
-
|
16 |
-
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
17 |
-
|
18 |
-
input_ids = input_ids[:, -self.past_window:]
|
19 |
-
freq = F.one_hot(input_ids, scores.size(1)).sum(1)
|
20 |
-
freq[self.max_input_ids:] = 0
|
21 |
-
alpha = self.penalty**freq
|
22 |
-
scores = torch.where(scores < 0, scores*alpha, scores/alpha)
|
23 |
-
|
24 |
-
return scores
|
25 |
-
|
26 |
-
class CustomRepetitionPenaltyLogitsProcessor():
|
27 |
-
|
28 |
-
def __init__(self, penalty: float, max_input_ids, past_window):
|
29 |
-
if not isinstance(penalty, float) or not (penalty > 0):
|
30 |
-
raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}")
|
31 |
-
|
32 |
-
self.penalty = penalty
|
33 |
-
self.max_input_ids = max_input_ids
|
34 |
-
self.past_window = past_window
|
35 |
-
|
36 |
-
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
37 |
-
|
38 |
-
input_ids = input_ids[:, -self.past_window:]
|
39 |
-
score = torch.gather(scores, 1, input_ids)
|
40 |
-
_score = score.detach().clone()
|
41 |
-
score = torch.where(score < 0, score * self.penalty, score / self.penalty)
|
42 |
-
score[input_ids>=self.max_input_ids] = _score[input_ids>=self.max_input_ids]
|
43 |
-
scores.scatter_(1, input_ids, score)
|
44 |
-
|
45 |
-
return scores
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|