Pijush2023 committed on
Commit 3d83510 · verified · 1 Parent(s): e92b7f0

Delete CHATTS/infer/api.py

Files changed (1)
  1. CHATTS/infer/api.py +0 -125
CHATTS/infer/api.py DELETED
@@ -1,125 +0,0 @@
-
-import torch
-import torch.nn.functional as F
-from transformers.generation import TopKLogitsWarper, TopPLogitsWarper
-from ..utils.infer_utils import CustomRepetitionPenaltyLogitsProcessorRepeat
-
-def infer_code(
-    models,
-    text,
-    spk_emb = None,
-    top_P = 0.7,
-    top_K = 20,
-    temperature = 0.3,
-    repetition_penalty = 1.05,
-    max_new_token = 2048,
-    **kwargs
-):
-
-    device = next(models['gpt'].parameters()).device
-
-    if not isinstance(text, list):
-        text = [text]
-
-    if not isinstance(temperature, list):
-        temperature = [temperature] * models['gpt'].num_vq
-
-    if spk_emb is not None:
-        text = [f'[Stts][spk_emb]{i}[uv_break][Ptts]' for i in text]
-    else:
-        text = [f'[Stts][empty_spk]{i}[uv_break][Ptts]' for i in text]
-
-    text_token = models['tokenizer'](text, return_tensors='pt', add_special_tokens=False, padding=True).to(device)
-    input_ids = text_token['input_ids'][...,None].expand(-1, -1, models['gpt'].num_vq)
-    text_mask = torch.ones(text_token['input_ids'].shape, dtype=bool, device=device)
-
-    inputs = {
-        'input_ids': input_ids,
-        'text_mask': text_mask,
-        'attention_mask': text_token['attention_mask'],
-    }
-
-    emb = models['gpt'].get_emb(**inputs)
-    if spk_emb is not None:
-        emb[inputs['input_ids'][..., 0] == models['tokenizer'].convert_tokens_to_ids('[spk_emb]')] = \
-            F.normalize(spk_emb.to(device).to(emb.dtype)[None].expand(len(text), -1), p=2.0, dim=1, eps=1e-12)
-
-    num_code = models['gpt'].emb_code[0].num_embeddings - 1
-
-    LogitsWarpers = []
-    if top_P is not None:
-        LogitsWarpers.append(TopPLogitsWarper(top_P, min_tokens_to_keep=3))
-    if top_K is not None:
-        LogitsWarpers.append(TopKLogitsWarper(top_K, min_tokens_to_keep=3))
-
-    LogitsProcessors = []
-    if repetition_penalty is not None and repetition_penalty != 1:
-        LogitsProcessors.append(CustomRepetitionPenaltyLogitsProcessorRepeat(\
-            repetition_penalty, num_code, 16))
-
-    result = models['gpt'].generate(
-        emb, inputs['input_ids'],
-        temperature = torch.tensor(temperature, device=device),
-        attention_mask = inputs['attention_mask'],
-        LogitsWarpers = LogitsWarpers,
-        LogitsProcessors = LogitsProcessors,
-        eos_token = num_code,
-        max_new_token = max_new_token,
-        infer_text = False,
-        **kwargs
-    )
-
-    return result
-
-
-def refine_text(
-    models,
-    text,
-    top_P = 0.7,
-    top_K = 20,
-    temperature = 0.7,
-    repetition_penalty = 1.0,
-    max_new_token = 384,
-    prompt = '',
-    **kwargs
-):
-
-    device = next(models['gpt'].parameters()).device
-
-    if not isinstance(text, list):
-        text = [text]
-
-    assert len(text), 'text should not be empty'
-
-    text = [f"[Sbreak]{i}[Pbreak]{prompt}" for i in text]
-    text_token = models['tokenizer'](text, return_tensors='pt', add_special_tokens=False, padding=True).to(device)
-    text_mask = torch.ones(text_token['input_ids'].shape, dtype=bool, device=device)
-
-    inputs = {
-        'input_ids': text_token['input_ids'][...,None].expand(-1, -1, models['gpt'].num_vq),
-        'text_mask': text_mask,
-        'attention_mask': text_token['attention_mask'],
-    }
-
-    LogitsWarpers = []
-    if top_P is not None:
-        LogitsWarpers.append(TopPLogitsWarper(top_P, min_tokens_to_keep=3))
-    if top_K is not None:
-        LogitsWarpers.append(TopKLogitsWarper(top_K, min_tokens_to_keep=3))
-
-    LogitsProcessors = []
-    if repetition_penalty is not None and repetition_penalty != 1:
-        LogitsProcessors.append(CustomRepetitionPenaltyLogitsProcessorRepeat(repetition_penalty, len(models['tokenizer']), 16))
-
-    result = models['gpt'].generate(
-        models['gpt'].get_emb(**inputs), inputs['input_ids'],
-        temperature = torch.tensor([temperature,], device=device),
-        attention_mask = inputs['attention_mask'],
-        LogitsWarpers = LogitsWarpers,
-        LogitsProcessors = LogitsProcessors,
-        eos_token = torch.tensor(models['tokenizer'].convert_tokens_to_ids('[Ebreak]'), device=device)[None],
-        max_new_token = max_new_token,
-        infer_text = True,
-        **kwargs
-    )
-    return result
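For reference, the two deleted helpers are the low-level sampling entry points of the ChatTTS-style pipeline: refine_text runs a text-refinement pass over the prompt, and infer_code samples the discrete audio codes that a later vocoder stage turns into a waveform. Below is a minimal sketch of how they are typically chained, mirroring the upstream ChatTTS Chat.infer flow; the models dict (with 'gpt' and 'tokenizer' entries) and the 'ids' key on the returned result are assumptions carried over from that upstream code, not a documented API of this repository.

# Minimal sketch, assuming `models` holds 'gpt' and 'tokenizer' entries as in
# upstream ChatTTS, and that generate() returns a dict with an 'ids' key.
def text_to_audio_codes(models, text, spk_emb=None):
    # First pass: refine/normalize the prompt text.
    refined = refine_text(models, text, temperature=0.7)
    refined_text = models['tokenizer'].batch_decode(refined['ids'])  # assumed result layout

    # Second pass: sample discrete audio codes for the refined text.
    result = infer_code(models, refined_text, spk_emb=spk_emb, temperature=0.3)
    return result['ids']  # assumed result layout

In the upstream pipeline, the returned code ids are then decoded by the DVAE/vocoder stage into audio.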