Delete CHATTS/infer/api.py
CHATTS/infer/api.py  DELETED  (+0 -125)
@@ -1,125 +0,0 @@
import torch
import torch.nn.functional as F
from transformers.generation import TopKLogitsWarper, TopPLogitsWarper
from ..utils.infer_utils import CustomRepetitionPenaltyLogitsProcessorRepeat


def infer_code(
    models,
    text,
    spk_emb = None,
    top_P = 0.7,
    top_K = 20,
    temperature = 0.3,
    repetition_penalty = 1.05,
    max_new_token = 2048,
    **kwargs
):

    device = next(models['gpt'].parameters()).device

    if not isinstance(text, list):
        text = [text]

    if not isinstance(temperature, list):
        temperature = [temperature] * models['gpt'].num_vq

    if spk_emb is not None:
        text = [f'[Stts][spk_emb]{i}[uv_break][Ptts]' for i in text]
    else:
        text = [f'[Stts][empty_spk]{i}[uv_break][Ptts]' for i in text]

    text_token = models['tokenizer'](text, return_tensors='pt', add_special_tokens=False, padding=True).to(device)
    input_ids = text_token['input_ids'][..., None].expand(-1, -1, models['gpt'].num_vq)
    text_mask = torch.ones(text_token['input_ids'].shape, dtype=bool, device=device)

    inputs = {
        'input_ids': input_ids,
        'text_mask': text_mask,
        'attention_mask': text_token['attention_mask'],
    }

    emb = models['gpt'].get_emb(**inputs)
    if spk_emb is not None:
        emb[inputs['input_ids'][..., 0] == models['tokenizer'].convert_tokens_to_ids('[spk_emb]')] = \
            F.normalize(spk_emb.to(device).to(emb.dtype)[None].expand(len(text), -1), p=2.0, dim=1, eps=1e-12)

    num_code = models['gpt'].emb_code[0].num_embeddings - 1

    LogitsWarpers = []
    if top_P is not None:
        LogitsWarpers.append(TopPLogitsWarper(top_P, min_tokens_to_keep=3))
    if top_K is not None:
        LogitsWarpers.append(TopKLogitsWarper(top_K, min_tokens_to_keep=3))

    LogitsProcessors = []
    if repetition_penalty is not None and repetition_penalty != 1:
        LogitsProcessors.append(CustomRepetitionPenaltyLogitsProcessorRepeat(
            repetition_penalty, num_code, 16))

    result = models['gpt'].generate(
        emb, inputs['input_ids'],
        temperature = torch.tensor(temperature, device=device),
        attention_mask = inputs['attention_mask'],
        LogitsWarpers = LogitsWarpers,
        LogitsProcessors = LogitsProcessors,
        eos_token = num_code,
        max_new_token = max_new_token,
        infer_text = False,
        **kwargs
    )

    return result


def refine_text(
    models,
    text,
    top_P = 0.7,
    top_K = 20,
    temperature = 0.7,
    repetition_penalty = 1.0,
    max_new_token = 384,
    prompt = '',
    **kwargs
):

    device = next(models['gpt'].parameters()).device

    if not isinstance(text, list):
        text = [text]

    assert len(text), 'text should not be empty'

    text = [f"[Sbreak]{i}[Pbreak]{prompt}" for i in text]
    text_token = models['tokenizer'](text, return_tensors='pt', add_special_tokens=False, padding=True).to(device)
    text_mask = torch.ones(text_token['input_ids'].shape, dtype=bool, device=device)

    inputs = {
        'input_ids': text_token['input_ids'][..., None].expand(-1, -1, models['gpt'].num_vq),
        'text_mask': text_mask,
        'attention_mask': text_token['attention_mask'],
    }

    LogitsWarpers = []
    if top_P is not None:
        LogitsWarpers.append(TopPLogitsWarper(top_P, min_tokens_to_keep=3))
    if top_K is not None:
        LogitsWarpers.append(TopKLogitsWarper(top_K, min_tokens_to_keep=3))

    LogitsProcessors = []
    if repetition_penalty is not None and repetition_penalty != 1:
        LogitsProcessors.append(CustomRepetitionPenaltyLogitsProcessorRepeat(repetition_penalty, len(models['tokenizer']), 16))

    result = models['gpt'].generate(
        models['gpt'].get_emb(**inputs), inputs['input_ids'],
        temperature = torch.tensor([temperature,], device=device),
        attention_mask = inputs['attention_mask'],
        LogitsWarpers = LogitsWarpers,
        LogitsProcessors = LogitsProcessors,
        eos_token = torch.tensor(models['tokenizer'].convert_tokens_to_ids('[Ebreak]'), device=device)[None],
        max_new_token = max_new_token,
        infer_text = True,
        **kwargs
    )
    return result
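
For reference, a minimal sketch of how these two entry points are typically driven. It assumes a `models` dict carrying 'gpt' and 'tokenizer' entries, as the functions above expect; `load_chattts_models` is a hypothetical placeholder for however the surrounding ChatTTS code builds that dict, and decoding the values returned by `generate` depends on the GPT wrapper, which is not part of this file.

# Hypothetical driver -- load_chattts_models is a placeholder, not defined in this file.
models = load_chattts_models(device='cuda')   # must provide models['gpt'] and models['tokenizer']

# Text-refinement pass: wraps the prompt in [Sbreak]/[Pbreak] tokens and samples
# text tokens (infer_text=True) until the [Ebreak] end token.
refined = refine_text(models, 'Hello, world!', temperature=0.7)

# Audio-code pass: wraps the text in [Stts]...[Ptts], optionally injects a speaker
# embedding, and samples discrete codes (infer_text=False) until the EOS code.
codes = infer_code(models, 'Hello, world!', spk_emb=None, temperature=0.3)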