Ailyth commited on
Commit
32524cc
·
1 Parent(s): 2f92c5d

0316-183902

Browse files
Files changed (49) hide show
  1. AR/__pycache__/__init__.cpython-310.pyc +0 -0
  2. AR/models/__pycache__/__init__.cpython-310.pyc +0 -0
  3. AR/models/__pycache__/t2s_lightning_module.cpython-310.pyc +0 -0
  4. AR/models/__pycache__/t2s_model.cpython-310.pyc +0 -0
  5. AR/models/__pycache__/utils.cpython-310.pyc +0 -0
  6. AR/models/t2s_lightning_module.py +2 -2
  7. AR/models/t2s_model.py +380 -9
  8. AR/models/t2s_model_batch_only.py +483 -0
  9. AR/models/utils.py +7 -7
  10. AR/modules/__pycache__/__init__.cpython-310.pyc +0 -0
  11. AR/modules/__pycache__/activation.cpython-310.pyc +0 -0
  12. AR/modules/__pycache__/embedding.cpython-310.pyc +0 -0
  13. AR/modules/__pycache__/lr_schedulers.cpython-310.pyc +0 -0
  14. AR/modules/__pycache__/optim.cpython-310.pyc +0 -0
  15. AR/modules/__pycache__/patched_mha_with_cache.cpython-310.pyc +0 -0
  16. AR/modules/__pycache__/scaling.cpython-310.pyc +0 -0
  17. AR/modules/__pycache__/transformer.cpython-310.pyc +0 -0
  18. {configs → GPT_SoVITS/configs}/s1.yaml +0 -0
  19. {configs → GPT_SoVITS/configs}/s1big.yaml +0 -0
  20. {configs → GPT_SoVITS/configs}/s1big2.yaml +0 -0
  21. {configs → GPT_SoVITS/configs}/s1longer.yaml +0 -0
  22. {configs → GPT_SoVITS/configs}/s1mq.yaml +0 -0
  23. {configs → GPT_SoVITS/configs}/s2.json +0 -0
  24. {configs → GPT_SoVITS/configs}/train.yaml +0 -0
  25. GPT_SoVITS/configs/tts_infer.yaml +16 -0
  26. TTS_infer_pack/TTS.py +848 -0
  27. TTS_infer_pack/TextPreprocessor.py +209 -0
  28. TTS_infer_pack/__init__.py +1 -0
  29. TTS_infer_pack/__pycache__/TTS.cpython-310.pyc +0 -0
  30. TTS_infer_pack/__pycache__/TextPreprocessor.cpython-310.pyc +0 -0
  31. TTS_infer_pack/__pycache__/__init__.cpython-310.pyc +0 -0
  32. TTS_infer_pack/__pycache__/text_segmentation_method.cpython-310.pyc +0 -0
  33. TTS_infer_pack/text_segmentation_method.py +152 -0
  34. app.py +134 -541
  35. feature_extractor/__pycache__/__init__.cpython-310.pyc +0 -0
  36. feature_extractor/__pycache__/cnhubert.cpython-310.pyc +0 -0
  37. feature_extractor/__pycache__/whisper_enc.cpython-310.pyc +0 -0
  38. feature_extractor/cnhubert.py +8 -5
  39. module/__pycache__/__init__.cpython-310.pyc +0 -0
  40. module/__pycache__/attentions.cpython-310.pyc +0 -0
  41. module/__pycache__/commons.cpython-310.pyc +0 -0
  42. module/__pycache__/core_vq.cpython-310.pyc +0 -0
  43. module/__pycache__/mel_processing.cpython-310.pyc +0 -0
  44. module/__pycache__/models.cpython-310.pyc +0 -0
  45. module/__pycache__/modules.cpython-310.pyc +0 -0
  46. module/__pycache__/mrte_model.cpython-310.pyc +0 -0
  47. module/__pycache__/quantize.cpython-310.pyc +0 -0
  48. module/__pycache__/transforms.cpython-310.pyc +0 -0
  49. module/models.py +58 -5
AR/__pycache__/__init__.cpython-310.pyc CHANGED
Binary files a/AR/__pycache__/__init__.cpython-310.pyc and b/AR/__pycache__/__init__.cpython-310.pyc differ
 
AR/models/__pycache__/__init__.cpython-310.pyc CHANGED
Binary files a/AR/models/__pycache__/__init__.cpython-310.pyc and b/AR/models/__pycache__/__init__.cpython-310.pyc differ
 
AR/models/__pycache__/t2s_lightning_module.cpython-310.pyc CHANGED
Binary files a/AR/models/__pycache__/t2s_lightning_module.cpython-310.pyc and b/AR/models/__pycache__/t2s_lightning_module.cpython-310.pyc differ
 
AR/models/__pycache__/t2s_model.cpython-310.pyc CHANGED
Binary files a/AR/models/__pycache__/t2s_model.cpython-310.pyc and b/AR/models/__pycache__/t2s_model.cpython-310.pyc differ
 
AR/models/__pycache__/utils.cpython-310.pyc CHANGED
Binary files a/AR/models/__pycache__/utils.cpython-310.pyc and b/AR/models/__pycache__/utils.cpython-310.pyc differ
 
AR/models/t2s_lightning_module.py CHANGED
@@ -13,11 +13,11 @@ from AR.modules.lr_schedulers import WarmupCosineLRSchedule
13
  from AR.modules.optim import ScaledAdam
14
 
15
  class Text2SemanticLightningModule(LightningModule):
16
- def __init__(self, config, output_dir, is_train=True):
17
  super().__init__()
18
  self.config = config
19
  self.top_k = 3
20
- self.model = Text2SemanticDecoder(config=config, top_k=self.top_k)
21
  pretrained_s1 = config.get("pretrained_s1")
22
  if pretrained_s1 and is_train:
23
  # print(self.load_state_dict(torch.load(pretrained_s1,map_location="cpu")["state_dict"]))
 
13
  from AR.modules.optim import ScaledAdam
14
 
15
  class Text2SemanticLightningModule(LightningModule):
16
+ def __init__(self, config, output_dir, is_train=True, flash_attn_enabled:bool = False):
17
  super().__init__()
18
  self.config = config
19
  self.top_k = 3
20
+ self.model = Text2SemanticDecoder(config=config, top_k=self.top_k,flash_attn_enabled=flash_attn_enabled)
21
  pretrained_s1 = config.get("pretrained_s1")
22
  if pretrained_s1 and is_train:
23
  # print(self.load_state_dict(torch.load(pretrained_s1,map_location="cpu")["state_dict"]))
AR/models/t2s_model.py CHANGED
@@ -1,5 +1,9 @@
1
  # modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_model.py
2
  # reference: https://github.com/lifeiteng/vall-e
 
 
 
 
3
  import torch
4
  from tqdm import tqdm
5
 
@@ -35,8 +39,144 @@ default_config = {
35
  }
36
 
37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  class Text2SemanticDecoder(nn.Module):
39
- def __init__(self, config, norm_first=False, top_k=3):
40
  super(Text2SemanticDecoder, self).__init__()
41
  self.model_dim = config["model"]["hidden_dim"]
42
  self.embedding_dim = config["model"]["embedding_dim"]
@@ -88,6 +228,47 @@ class Text2SemanticDecoder(nn.Module):
88
  multidim_average="global",
89
  ignore_index=self.EOS,
90
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
 
92
  def make_input_data(self, x, x_lens, y, y_lens, bert_feature):
93
  x = self.ar_text_embedding(x)
@@ -321,7 +502,161 @@ class Text2SemanticDecoder(nn.Module):
321
  # 错位
322
  return targets[:, :-1], targets[:, 1:]
323
 
324
- def infer_panel(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
325
  self,
326
  x, #####全部文本token
327
  x_lens,
@@ -386,7 +721,9 @@ class Text2SemanticDecoder(nn.Module):
386
  x.device
387
  )
388
 
389
-
 
 
390
  for idx in tqdm(range(1500)):
391
 
392
  xy_dec, _ = self.h((xy_pos, None), mask=xy_attn_mask, cache=cache)
@@ -397,17 +734,45 @@ class Text2SemanticDecoder(nn.Module):
397
  if(idx==0):###第一次跑不能EOS否则没有了
398
  logits = logits[:, :-1] ###刨除1024终止符号的概率
399
  samples = sample(
400
- logits[0], y, top_k=top_k, top_p=top_p, repetition_penalty=1.35, temperature=temperature
401
- )[0].unsqueeze(0)
402
  # 本次生成的 semantic_ids 和之前的 y 构成新的 y
403
  # print(samples.shape)#[1,1]#第一个1是bs
404
  y = torch.concat([y, samples], dim=1)
405
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
406
  if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
407
  print("use early stop num:", early_stop_num)
408
  stop = True
409
-
410
- if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
411
  # print(torch.argmax(logits, dim=-1)[0] == self.EOS, samples[0, 0] == self.EOS)
412
  stop = True
413
  if stop:
@@ -443,6 +808,12 @@ class Text2SemanticDecoder(nn.Module):
443
  xy_attn_mask = torch.zeros(
444
  (1, x_len + y_len), dtype=torch.bool, device=xy_pos.device
445
  )
 
 
 
 
 
 
446
  if ref_free:
447
- return y[:, :-1], 0
448
- return y[:, :-1], idx-1
 
1
  # modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_model.py
2
  # reference: https://github.com/lifeiteng/vall-e
3
+ import os, sys
4
+ now_dir = os.getcwd()
5
+ sys.path.append(now_dir)
6
+ from typing import List
7
  import torch
8
  from tqdm import tqdm
9
 
 
39
  }
40
 
41
 
42
+ @torch.jit.script
43
+ class T2SMLP:
44
+ def __init__(self, w1, b1, w2, b2):
45
+ self.w1 = w1
46
+ self.b1 = b1
47
+ self.w2 = w2
48
+ self.b2 = b2
49
+
50
+ def forward(self, x):
51
+ x = F.relu(F.linear(x, self.w1, self.b1))
52
+ x = F.linear(x, self.w2, self.b2)
53
+ return x
54
+
55
+
56
+ @torch.jit.script
57
+ class T2SBlock:
58
+ def __init__(
59
+ self,
60
+ num_heads,
61
+ hidden_dim: int,
62
+ mlp: T2SMLP,
63
+ qkv_w,
64
+ qkv_b,
65
+ out_w,
66
+ out_b,
67
+ norm_w1,
68
+ norm_b1,
69
+ norm_eps1,
70
+ norm_w2,
71
+ norm_b2,
72
+ norm_eps2,
73
+ ):
74
+ self.num_heads = num_heads
75
+ self.mlp = mlp
76
+ self.hidden_dim: int = hidden_dim
77
+ self.qkv_w = qkv_w
78
+ self.qkv_b = qkv_b
79
+ self.out_w = out_w
80
+ self.out_b = out_b
81
+ self.norm_w1 = norm_w1
82
+ self.norm_b1 = norm_b1
83
+ self.norm_eps1 = norm_eps1
84
+ self.norm_w2 = norm_w2
85
+ self.norm_b2 = norm_b2
86
+ self.norm_eps2 = norm_eps2
87
+
88
+ def process_prompt(self, x, attn_mask : torch.Tensor):
89
+ q, k, v = F.linear(x, self.qkv_w, self.qkv_b).chunk(3, dim=-1)
90
+
91
+ batch_size = q.shape[0]
92
+ q_len = q.shape[1]
93
+ kv_len = k.shape[1]
94
+
95
+ k_cache = k
96
+ v_cache = v
97
+
98
+ q = q.view(batch_size, q_len, self.num_heads, -1).transpose(1, 2)
99
+ k = k_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
100
+ v = v_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
101
+
102
+ attn = F.scaled_dot_product_attention(q, k, v, attn_mask)
103
+
104
+ attn = attn.permute(2, 0, 1, 3).reshape(batch_size*q_len, self.hidden_dim)
105
+ attn = attn.view(q_len, batch_size, self.hidden_dim).transpose(1, 0)
106
+ attn = F.linear(attn, self.out_w, self.out_b)
107
+
108
+ x = F.layer_norm(
109
+ x + attn, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1
110
+ )
111
+ x = F.layer_norm(
112
+ x + self.mlp.forward(x),
113
+ [self.hidden_dim],
114
+ self.norm_w2,
115
+ self.norm_b2,
116
+ self.norm_eps2,
117
+ )
118
+ return x, k_cache, v_cache
119
+
120
+ def decode_next_token(self, x, k_cache, v_cache):
121
+ q, k, v = F.linear(x, self.qkv_w, self.qkv_b).chunk(3, dim=-1)
122
+
123
+ k_cache = torch.cat([k_cache, k], dim=1)
124
+ v_cache = torch.cat([v_cache, v], dim=1)
125
+
126
+ batch_size = q.shape[0]
127
+ q_len = q.shape[1]
128
+ kv_len = k_cache.shape[1]
129
+
130
+ q = q.view(batch_size, q_len, self.num_heads, -1).transpose(1, 2)
131
+ k = k_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
132
+ v = v_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
133
+
134
+
135
+ attn = F.scaled_dot_product_attention(q, k, v)
136
+
137
+ attn = attn.permute(2, 0, 1, 3).reshape(batch_size*q_len, self.hidden_dim)
138
+ attn = attn.view(q_len, batch_size, self.hidden_dim).transpose(1, 0)
139
+ attn = F.linear(attn, self.out_w, self.out_b)
140
+
141
+ x = F.layer_norm(
142
+ x + attn, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1
143
+ )
144
+ x = F.layer_norm(
145
+ x + self.mlp.forward(x),
146
+ [self.hidden_dim],
147
+ self.norm_w2,
148
+ self.norm_b2,
149
+ self.norm_eps2,
150
+ )
151
+ return x, k_cache, v_cache
152
+
153
+
154
+ @torch.jit.script
155
+ class T2STransformer:
156
+ def __init__(self, num_blocks : int, blocks: List[T2SBlock]):
157
+ self.num_blocks : int = num_blocks
158
+ self.blocks = blocks
159
+
160
+ def process_prompt(
161
+ self, x, attn_mask : torch.Tensor):
162
+ k_cache : List[torch.Tensor] = []
163
+ v_cache : List[torch.Tensor] = []
164
+ for i in range(self.num_blocks):
165
+ x, k_cache_, v_cache_ = self.blocks[i].process_prompt(x, attn_mask)
166
+ k_cache.append(k_cache_)
167
+ v_cache.append(v_cache_)
168
+ return x, k_cache, v_cache
169
+
170
+ def decode_next_token(
171
+ self, x, k_cache: List[torch.Tensor], v_cache: List[torch.Tensor]
172
+ ):
173
+ for i in range(self.num_blocks):
174
+ x, k_cache[i], v_cache[i] = self.blocks[i].decode_next_token(x, k_cache[i], v_cache[i])
175
+ return x, k_cache, v_cache
176
+
177
+
178
  class Text2SemanticDecoder(nn.Module):
179
+ def __init__(self, config, norm_first=False, top_k=3, flash_attn_enabled:bool=False):
180
  super(Text2SemanticDecoder, self).__init__()
181
  self.model_dim = config["model"]["hidden_dim"]
182
  self.embedding_dim = config["model"]["embedding_dim"]
 
228
  multidim_average="global",
229
  ignore_index=self.EOS,
230
  )
231
+
232
+ self.enable_flash_attn(flash_attn_enabled)
233
+
234
+ def enable_flash_attn(self, enable:bool=True):
235
+
236
+ if not enable:
237
+ print("Not Using Flash Attention")
238
+ self.infer_panel = self.infer_panel_batch_only
239
+ else:
240
+ self.infer_panel = self.infer_panel_batch_infer_with_flash_attn
241
+ print("Using Flash Attention")
242
+ blocks = []
243
+
244
+ for i in range(self.num_layers):
245
+ layer = self.h.layers[i]
246
+ t2smlp = T2SMLP(
247
+ layer.linear1.weight,
248
+ layer.linear1.bias,
249
+ layer.linear2.weight,
250
+ layer.linear2.bias
251
+ )
252
+
253
+ block = T2SBlock(
254
+ self.num_head,
255
+ self.model_dim,
256
+ t2smlp,
257
+ layer.self_attn.in_proj_weight,
258
+ layer.self_attn.in_proj_bias,
259
+ layer.self_attn.out_proj.weight,
260
+ layer.self_attn.out_proj.bias,
261
+ layer.norm1.weight,
262
+ layer.norm1.bias,
263
+ layer.norm1.eps,
264
+ layer.norm2.weight,
265
+ layer.norm2.bias,
266
+ layer.norm2.eps
267
+ )
268
+
269
+ blocks.append(block)
270
+
271
+ self.t2s_transformer = T2STransformer(self.num_layers, blocks)
272
 
273
  def make_input_data(self, x, x_lens, y, y_lens, bert_feature):
274
  x = self.ar_text_embedding(x)
 
502
  # 错位
503
  return targets[:, :-1], targets[:, 1:]
504
 
505
+ def infer_panel_batch_infer_with_flash_attn(
506
+ self,
507
+ x, #####全部文本token
508
+ x_lens,
509
+ prompts, ####参考音频token
510
+ bert_feature,
511
+ top_k: int = -100,
512
+ top_p: int = 100,
513
+ early_stop_num: int = -1,
514
+ temperature: float = 1.0,
515
+ ):
516
+
517
+ bert_feature = self.bert_proj(bert_feature.transpose(1, 2))
518
+ x = self.ar_text_embedding(x)
519
+ x = x + bert_feature
520
+ x = self.ar_text_position(x)
521
+
522
+ # AR Decoder
523
+ y = prompts
524
+
525
+ x_len = x.shape[1]
526
+ x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)
527
+ stop = False
528
+ # print(1111111,self.num_layers)
529
+
530
+ k_cache = None
531
+ v_cache = None
532
+ ################### first step ##########################
533
+ if y is not None:
534
+ y_emb = self.ar_audio_embedding(y)
535
+ y_len = y_emb.shape[1]
536
+ prefix_len = y.shape[1]
537
+ y_pos = self.ar_audio_position(y_emb)
538
+ xy_pos = torch.concat([x, y_pos], dim=1)
539
+ ref_free = False
540
+ else:
541
+ y_emb = None
542
+ y_len = 0
543
+ prefix_len = 0
544
+ y_pos = None
545
+ xy_pos = x
546
+ y = torch.zeros(x.shape[0], 0, dtype=torch.int, device=x.device)
547
+ ref_free = True
548
+
549
+
550
+ ##### create mask #####
551
+ bsz = x.shape[0]
552
+ src_len = x_len + y_len
553
+ y_lens = torch.LongTensor([y_len]*bsz).to(x.device)
554
+ y_mask = make_pad_mask(y_lens)
555
+ x_mask = make_pad_mask(x_lens)
556
+
557
+ # (bsz, x_len + y_len)
558
+ xy_padding_mask = torch.concat([x_mask, y_mask], dim=1)
559
+
560
+ x_mask = F.pad(
561
+ x_attn_mask,
562
+ (0, y_len), ###xx的纯0扩展到xx纯0+xy纯1,(x,x+y)
563
+ value=True,
564
+ )
565
+ y_mask = F.pad( ###yy的右上1扩展到左边xy的0,(y,x+y)
566
+ torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
567
+ (x_len, 0),
568
+ value=False,
569
+ )
570
+
571
+ xy_mask = torch.concat([x_mask, y_mask], dim=0).view(1 , src_len, src_len).expand(bsz, -1, -1).to(x.device)
572
+ # xy_mask = torch.triu(torch.ones(src_len, src_len, dtype=torch.bool, device=x.device), diagonal=1)
573
+ xy_padding_mask = xy_padding_mask.view(bsz, 1, src_len).expand(-1, src_len, src_len)
574
+ xy_attn_mask = xy_mask.logical_or(xy_padding_mask)
575
+ xy_attn_mask = xy_attn_mask.unsqueeze(1).expand(-1, self.num_head, -1, -1)
576
+ new_attn_mask = torch.zeros_like(xy_attn_mask, dtype=x.dtype)
577
+ xy_attn_mask = new_attn_mask.masked_fill(xy_attn_mask, float("-inf"))
578
+
579
+ ###### decode #####
580
+ y_list = [None]*y.shape[0]
581
+ batch_idx_map = list(range(y.shape[0]))
582
+ idx_list = [None]*y.shape[0]
583
+ for idx in tqdm(range(1500)):
584
+ if idx == 0:
585
+ xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, xy_attn_mask)
586
+ else:
587
+ xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache)
588
+
589
+ logits = self.ar_predict_layer(
590
+ xy_dec[:, -1]
591
+ )
592
+
593
+ if idx == 0:
594
+ xy_attn_mask = None
595
+ logits = logits[:, :-1]
596
+
597
+ samples = sample(
598
+ logits, y, top_k=top_k, top_p=top_p, repetition_penalty=1.35, temperature=temperature
599
+ )[0]
600
+
601
+ y = torch.concat([y, samples], dim=1)
602
+
603
+ ####### 移除batch中已经生成完毕的序列,进一步优化计算量
604
+ reserved_idx_of_batch_for_y = None
605
+ if (self.EOS in samples[:, 0]) or \
606
+ (self.EOS in torch.argmax(logits, dim=-1)): ###如果生成到EOS,则停止
607
+ l = samples[:, 0]==self.EOS
608
+ removed_idx_of_batch_for_y = torch.where(l==True)[0].tolist()
609
+ reserved_idx_of_batch_for_y = torch.where(l==False)[0]
610
+ # batch_indexs = torch.tensor(batch_idx_map, device=y.device)[removed_idx_of_batch_for_y]
611
+ for i in removed_idx_of_batch_for_y:
612
+ batch_index = batch_idx_map[i]
613
+ idx_list[batch_index] = idx - 1
614
+ y_list[batch_index] = y[i, :-1]
615
+
616
+ batch_idx_map = [batch_idx_map[i] for i in reserved_idx_of_batch_for_y.tolist()]
617
+
618
+ # 只保留batch中未生成完毕的序列
619
+ if reserved_idx_of_batch_for_y is not None:
620
+ # index = torch.LongTensor(batch_idx_map).to(y.device)
621
+ y = torch.index_select(y, dim=0, index=reserved_idx_of_batch_for_y)
622
+ if k_cache is not None :
623
+ for i in range(len(k_cache)):
624
+ k_cache[i] = torch.index_select(k_cache[i], dim=0, index=reserved_idx_of_batch_for_y)
625
+ v_cache[i] = torch.index_select(v_cache[i], dim=0, index=reserved_idx_of_batch_for_y)
626
+
627
+
628
+ if (early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num) or idx==1499:
629
+ print("use early stop num:", early_stop_num)
630
+ stop = True
631
+ for i, batch_index in enumerate(batch_idx_map):
632
+ batch_index = batch_idx_map[i]
633
+ idx_list[batch_index] = idx
634
+ y_list[batch_index] = y[i, :-1]
635
+
636
+ if not (None in idx_list):
637
+ stop = True
638
+
639
+ if stop:
640
+ if y.shape[1]==0:
641
+ y = torch.concat([y, torch.zeros_like(samples)], dim=1)
642
+ print("bad zero prediction")
643
+ print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]")
644
+ break
645
+
646
+ ####################### update next step ###################################
647
+ y_emb = self.ar_audio_embedding(y[:, -1:])
648
+ xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[:, y_len + idx].to( dtype= y_emb.dtype,device=y_emb.device)
649
+
650
+ if (None in idx_list):
651
+ for i in range(x.shape[0]):
652
+ if idx_list[i] is None:
653
+ idx_list[i] = 1500-1 ###如果没有生成到EOS,就用最大长度代替
654
+
655
+ if ref_free:
656
+ return y_list, [0]*x.shape[0]
657
+ return y_list, idx_list
658
+
659
+ def infer_panel_batch_only(
660
  self,
661
  x, #####全部文本token
662
  x_lens,
 
721
  x.device
722
  )
723
 
724
+ y_list = [None]*y.shape[0]
725
+ batch_idx_map = list(range(y.shape[0]))
726
+ idx_list = [None]*y.shape[0]
727
  for idx in tqdm(range(1500)):
728
 
729
  xy_dec, _ = self.h((xy_pos, None), mask=xy_attn_mask, cache=cache)
 
734
  if(idx==0):###第一次跑不能EOS否则没有了
735
  logits = logits[:, :-1] ###刨除1024终止符号的概率
736
  samples = sample(
737
+ logits, y, top_k=top_k, top_p=top_p, repetition_penalty=1.35, temperature=temperature
738
+ )[0]
739
  # 本次生成的 semantic_ids 和之前的 y 构成新的 y
740
  # print(samples.shape)#[1,1]#第一个1是bs
741
  y = torch.concat([y, samples], dim=1)
742
 
743
+ # 移除已经生成完毕的序列
744
+ reserved_idx_of_batch_for_y = None
745
+ if (self.EOS in torch.argmax(logits, dim=-1)) or \
746
+ (self.EOS in samples[:, 0]): ###如果生成到EOS,则停止
747
+ l = samples[:, 0]==self.EOS
748
+ removed_idx_of_batch_for_y = torch.where(l==True)[0].tolist()
749
+ reserved_idx_of_batch_for_y = torch.where(l==False)[0]
750
+ # batch_indexs = torch.tensor(batch_idx_map, device=y.device)[removed_idx_of_batch_for_y]
751
+ for i in removed_idx_of_batch_for_y:
752
+ batch_index = batch_idx_map[i]
753
+ idx_list[batch_index] = idx - 1
754
+ y_list[batch_index] = y[i, :-1]
755
+
756
+ batch_idx_map = [batch_idx_map[i] for i in reserved_idx_of_batch_for_y.tolist()]
757
+
758
+ # 只保留未生成完毕的序列
759
+ if reserved_idx_of_batch_for_y is not None:
760
+ # index = torch.LongTensor(batch_idx_map).to(y.device)
761
+ y = torch.index_select(y, dim=0, index=reserved_idx_of_batch_for_y)
762
+ if cache["y_emb"] is not None:
763
+ cache["y_emb"] = torch.index_select(cache["y_emb"], dim=0, index=reserved_idx_of_batch_for_y)
764
+ if cache["k"] is not None:
765
+ for i in range(self.num_layers):
766
+ # 因为kv转置了,所以batch dim是1
767
+ cache["k"][i] = torch.index_select(cache["k"][i], dim=1, index=reserved_idx_of_batch_for_y)
768
+ cache["v"][i] = torch.index_select(cache["v"][i], dim=1, index=reserved_idx_of_batch_for_y)
769
+
770
+
771
  if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
772
  print("use early stop num:", early_stop_num)
773
  stop = True
774
+
775
+ if not (None in idx_list):
776
  # print(torch.argmax(logits, dim=-1)[0] == self.EOS, samples[0, 0] == self.EOS)
777
  stop = True
778
  if stop:
 
808
  xy_attn_mask = torch.zeros(
809
  (1, x_len + y_len), dtype=torch.bool, device=xy_pos.device
810
  )
811
+
812
+ if (None in idx_list):
813
+ for i in range(x.shape[0]):
814
+ if idx_list[i] is None:
815
+ idx_list[i] = 1500-1 ###如果没有生成到EOS,就用最大长度代替
816
+
817
  if ref_free:
818
+ return y_list, [0]*x.shape[0]
819
+ return y_list, idx_list
AR/models/t2s_model_batch_only.py ADDED
@@ -0,0 +1,483 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/model/t2s_model.py
2
+ import torch
3
+ from tqdm import tqdm
4
+
5
+ from AR.models.utils import make_pad_mask
6
+ from AR.models.utils import (
7
+ topk_sampling,
8
+ sample,
9
+ logits_to_probs,
10
+ multinomial_sample_one_no_sync,
11
+ dpo_loss,
12
+ make_reject_y,
13
+ get_batch_logps
14
+ )
15
+ from AR.modules.embedding import SinePositionalEmbedding
16
+ from AR.modules.embedding import TokenEmbedding
17
+ from AR.modules.transformer import LayerNorm
18
+ from AR.modules.transformer import TransformerEncoder
19
+ from AR.modules.transformer import TransformerEncoderLayer
20
+ from torch import nn
21
+ from torch.nn import functional as F
22
+ from torchmetrics.classification import MulticlassAccuracy
23
+
24
+ default_config = {
25
+ "embedding_dim": 512,
26
+ "hidden_dim": 512,
27
+ "num_head": 8,
28
+ "num_layers": 12,
29
+ "num_codebook": 8,
30
+ "p_dropout": 0.0,
31
+ "vocab_size": 1024 + 1,
32
+ "phoneme_vocab_size": 512,
33
+ "EOS": 1024,
34
+ }
35
+
36
+
37
+ class Text2SemanticDecoder(nn.Module):
38
+ def __init__(self, config, norm_first=False, top_k=3):
39
+ super(Text2SemanticDecoder, self).__init__()
40
+ self.model_dim = config["model"]["hidden_dim"]
41
+ self.embedding_dim = config["model"]["embedding_dim"]
42
+ self.num_head = config["model"]["head"]
43
+ self.num_layers = config["model"]["n_layer"]
44
+ self.norm_first = norm_first
45
+ self.vocab_size = config["model"]["vocab_size"]
46
+ self.phoneme_vocab_size = config["model"]["phoneme_vocab_size"]
47
+ self.p_dropout = config["model"]["dropout"]
48
+ self.EOS = config["model"]["EOS"]
49
+ self.norm_first = norm_first
50
+ assert self.EOS == self.vocab_size - 1
51
+ # should be same as num of kmeans bin
52
+ # assert self.EOS == 1024
53
+ self.bert_proj = nn.Linear(1024, self.embedding_dim)
54
+ self.ar_text_embedding = TokenEmbedding(
55
+ self.embedding_dim, self.phoneme_vocab_size, self.p_dropout
56
+ )
57
+ self.ar_text_position = SinePositionalEmbedding(
58
+ self.embedding_dim, dropout=0.1, scale=False, alpha=True
59
+ )
60
+ self.ar_audio_embedding = TokenEmbedding(
61
+ self.embedding_dim, self.vocab_size, self.p_dropout
62
+ )
63
+ self.ar_audio_position = SinePositionalEmbedding(
64
+ self.embedding_dim, dropout=0.1, scale=False, alpha=True
65
+ )
66
+
67
+ self.h = TransformerEncoder(
68
+ TransformerEncoderLayer(
69
+ d_model=self.model_dim,
70
+ nhead=self.num_head,
71
+ dim_feedforward=self.model_dim * 4,
72
+ dropout=0.1,
73
+ batch_first=True,
74
+ norm_first=norm_first,
75
+ ),
76
+ num_layers=self.num_layers,
77
+ norm=LayerNorm(self.model_dim) if norm_first else None,
78
+ )
79
+
80
+ self.ar_predict_layer = nn.Linear(self.model_dim, self.vocab_size, bias=False)
81
+ self.loss_fct = nn.CrossEntropyLoss(reduction="sum")
82
+
83
+ self.ar_accuracy_metric = MulticlassAccuracy(
84
+ self.vocab_size,
85
+ top_k=top_k,
86
+ average="micro",
87
+ multidim_average="global",
88
+ ignore_index=self.EOS,
89
+ )
90
+
91
+ def make_input_data(self, x, x_lens, y, y_lens, bert_feature):
92
+ x = self.ar_text_embedding(x)
93
+ x = x + self.bert_proj(bert_feature.transpose(1, 2))
94
+ x = self.ar_text_position(x)
95
+ x_mask = make_pad_mask(x_lens)
96
+
97
+ y_mask = make_pad_mask(y_lens)
98
+ y_mask_int = y_mask.type(torch.int64)
99
+ codes = y.type(torch.int64) * (1 - y_mask_int)
100
+
101
+ # Training
102
+ # AR Decoder
103
+ y, targets = self.pad_y_eos(codes, y_mask_int, eos_id=self.EOS)
104
+ x_len = x_lens.max()
105
+ y_len = y_lens.max()
106
+ y_emb = self.ar_audio_embedding(y)
107
+ y_pos = self.ar_audio_position(y_emb)
108
+
109
+ xy_padding_mask = torch.concat([x_mask, y_mask], dim=1)
110
+
111
+ ar_xy_padding_mask = xy_padding_mask
112
+
113
+ x_attn_mask = F.pad(
114
+ torch.zeros((x_len, x_len), dtype=torch.bool, device=x.device),
115
+ (0, y_len),
116
+ value=True,
117
+ )
118
+
119
+ y_attn_mask = F.pad(
120
+ torch.triu(
121
+ torch.ones(y_len, y_len, dtype=torch.bool, device=x.device),
122
+ diagonal=1,
123
+ ),
124
+ (x_len, 0),
125
+ value=False,
126
+ )
127
+
128
+ xy_attn_mask = torch.concat([x_attn_mask, y_attn_mask], dim=0)
129
+ bsz, src_len = x.shape[0], x_len + y_len
130
+ _xy_padding_mask = (
131
+ ar_xy_padding_mask.view(bsz, 1, 1, src_len)
132
+ .expand(-1, self.num_head, -1, -1)
133
+ .reshape(bsz * self.num_head, 1, src_len)
134
+ )
135
+ xy_attn_mask = xy_attn_mask.logical_or(_xy_padding_mask)
136
+ new_attn_mask = torch.zeros_like(xy_attn_mask, dtype=x.dtype)
137
+ new_attn_mask.masked_fill_(xy_attn_mask, float("-inf"))
138
+ xy_attn_mask = new_attn_mask
139
+ # x 和完整的 y 一次性输入模型
140
+ xy_pos = torch.concat([x, y_pos], dim=1)
141
+
142
+ return xy_pos, xy_attn_mask, targets
143
+
144
+ def forward(self, x, x_lens, y, y_lens, bert_feature):
145
+ """
146
+ x: phoneme_ids
147
+ y: semantic_ids
148
+ """
149
+
150
+ reject_y, reject_y_lens = make_reject_y(y, y_lens)
151
+
152
+ xy_pos, xy_attn_mask, targets = self.make_input_data(x, x_lens, y, y_lens, bert_feature)
153
+
154
+ xy_dec, _ = self.h(
155
+ (xy_pos, None),
156
+ mask=xy_attn_mask,
157
+ )
158
+ x_len = x_lens.max()
159
+ logits = self.ar_predict_layer(xy_dec[:, x_len:])
160
+
161
+ ###### DPO #############
162
+ reject_xy_pos, reject_xy_attn_mask, reject_targets = self.make_input_data(x, x_lens, reject_y, reject_y_lens, bert_feature)
163
+
164
+ reject_xy_dec, _ = self.h(
165
+ (reject_xy_pos, None),
166
+ mask=reject_xy_attn_mask,
167
+ )
168
+ x_len = x_lens.max()
169
+ reject_logits = self.ar_predict_layer(reject_xy_dec[:, x_len:])
170
+
171
+ # loss
172
+ # from feiteng: 每次 duration 越多, 梯度更新也应该更多, 所以用 sum
173
+
174
+ loss_1 = F.cross_entropy(logits.permute(0, 2, 1), targets, reduction="sum")
175
+ acc = self.ar_accuracy_metric(logits.permute(0, 2, 1).detach(), targets).item()
176
+
177
+ A_logits, R_logits = get_batch_logps(logits, reject_logits, targets, reject_targets)
178
+ loss_2, _, _ = dpo_loss(A_logits, R_logits, 0, 0, 0.2, reference_free=True)
179
+
180
+ loss = loss_1 + loss_2
181
+
182
+ return loss, acc
183
+
184
+ def forward_old(self, x, x_lens, y, y_lens, bert_feature):
185
+ """
186
+ x: phoneme_ids
187
+ y: semantic_ids
188
+ """
189
+ x = self.ar_text_embedding(x)
190
+ x = x + self.bert_proj(bert_feature.transpose(1, 2))
191
+ x = self.ar_text_position(x)
192
+ x_mask = make_pad_mask(x_lens)
193
+
194
+ y_mask = make_pad_mask(y_lens)
195
+ y_mask_int = y_mask.type(torch.int64)
196
+ codes = y.type(torch.int64) * (1 - y_mask_int)
197
+
198
+ # Training
199
+ # AR Decoder
200
+ y, targets = self.pad_y_eos(codes, y_mask_int, eos_id=self.EOS)
201
+ x_len = x_lens.max()
202
+ y_len = y_lens.max()
203
+ y_emb = self.ar_audio_embedding(y)
204
+ y_pos = self.ar_audio_position(y_emb)
205
+
206
+ xy_padding_mask = torch.concat([x_mask, y_mask], dim=1)
207
+ ar_xy_padding_mask = xy_padding_mask
208
+
209
+ x_attn_mask = F.pad(
210
+ torch.zeros((x_len, x_len), dtype=torch.bool, device=x.device),
211
+ (0, y_len),
212
+ value=True,
213
+ )
214
+ y_attn_mask = F.pad(
215
+ torch.triu(
216
+ torch.ones(y_len, y_len, dtype=torch.bool, device=x.device),
217
+ diagonal=1,
218
+ ),
219
+ (x_len, 0),
220
+ value=False,
221
+ )
222
+ xy_attn_mask = torch.concat([x_attn_mask, y_attn_mask], dim=0)
223
+ bsz, src_len = x.shape[0], x_len + y_len
224
+ _xy_padding_mask = (
225
+ ar_xy_padding_mask.view(bsz, 1, 1, src_len)
226
+ .expand(-1, self.num_head, -1, -1)
227
+ .reshape(bsz * self.num_head, 1, src_len)
228
+ )
229
+ xy_attn_mask = xy_attn_mask.logical_or(_xy_padding_mask)
230
+ new_attn_mask = torch.zeros_like(xy_attn_mask, dtype=x.dtype)
231
+ new_attn_mask.masked_fill_(xy_attn_mask, float("-inf"))
232
+ xy_attn_mask = new_attn_mask
233
+ # x 和完整的 y 一次性输入模型
234
+ xy_pos = torch.concat([x, y_pos], dim=1)
235
+ xy_dec, _ = self.h(
236
+ (xy_pos, None),
237
+ mask=xy_attn_mask,
238
+ )
239
+ logits = self.ar_predict_layer(xy_dec[:, x_len:]).permute(0, 2, 1)
240
+ # loss
241
+ # from feiteng: 每次 duration 越多, 梯度更新也应该更多, 所以用 sum
242
+ loss = F.cross_entropy(logits, targets, reduction="sum")
243
+ acc = self.ar_accuracy_metric(logits.detach(), targets).item()
244
+ return loss, acc
245
+
246
+ # 需要看下这个函数和 forward 的区别以及没有 semantic 的时候 prompts 输入什么
247
+ def infer(
248
+ self,
249
+ x,
250
+ x_lens,
251
+ prompts,
252
+ bert_feature,
253
+ top_k: int = -100,
254
+ early_stop_num: int = -1,
255
+ temperature: float = 1.0,
256
+ ):
257
+ x = self.ar_text_embedding(x)
258
+ x = x + self.bert_proj(bert_feature.transpose(1, 2))
259
+ x = self.ar_text_position(x)
260
+
261
+ # AR Decoder
262
+ y = prompts
263
+ prefix_len = y.shape[1]
264
+ x_len = x.shape[1]
265
+ x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)
266
+ stop = False
267
+ for _ in tqdm(range(1500)):
268
+ y_emb = self.ar_audio_embedding(y)
269
+ y_pos = self.ar_audio_position(y_emb)
270
+ # x 和逐渐增长的 y 一起输入给模型
271
+ xy_pos = torch.concat([x, y_pos], dim=1)
272
+ y_len = y.shape[1]
273
+ x_attn_mask_pad = F.pad(
274
+ x_attn_mask,
275
+ (0, y_len),
276
+ value=True,
277
+ )
278
+ y_attn_mask = F.pad(
279
+ torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
280
+ (x_len, 0),
281
+ value=False,
282
+ )
283
+ xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).to(
284
+ y.device
285
+ )
286
+
287
+ xy_dec, _ = self.h(
288
+ (xy_pos, None),
289
+ mask=xy_attn_mask,
290
+ )
291
+ logits = self.ar_predict_layer(xy_dec[:, -1])
292
+ samples = topk_sampling(
293
+ logits, top_k=top_k, top_p=1.0, temperature=temperature
294
+ )
295
+
296
+ if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
297
+ print("use early stop num:", early_stop_num)
298
+ stop = True
299
+
300
+ if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
301
+ # print(torch.argmax(logits, dim=-1)[0] == self.EOS, samples[0, 0] == self.EOS)
302
+ stop = True
303
+ if stop:
304
+ if prompts.shape[1] == y.shape[1]:
305
+ y = torch.concat([y, torch.zeros_like(samples)], dim=1)
306
+ print("bad zero prediction")
307
+ print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]")
308
+ break
309
+ # 本次生成的 semantic_ids 和之前的 y 构成新的 y
310
+ # print(samples.shape)#[1,1]#第一个1是bs
311
+ # import os
312
+ # os._exit(2333)
313
+ y = torch.concat([y, samples], dim=1)
314
+ return y
315
+
316
+ def pad_y_eos(self, y, y_mask_int, eos_id):
317
+ targets = F.pad(y, (0, 1), value=0) + eos_id * F.pad(
318
+ y_mask_int, (0, 1), value=1
319
+ )
320
+ # 错位
321
+ return targets[:, :-1], targets[:, 1:]
322
+
323
+ def infer_panel(
324
+ self,
325
+ x, #####全部文本token
326
+ x_lens,
327
+ prompts, ####参考音频token
328
+ bert_feature,
329
+ top_k: int = -100,
330
+ top_p: int = 100,
331
+ early_stop_num: int = -1,
332
+ temperature: float = 1.0,
333
+ ):
334
+ x = self.ar_text_embedding(x)
335
+ x = x + self.bert_proj(bert_feature.transpose(1, 2))
336
+ x = self.ar_text_position(x)
337
+
338
+ # AR Decoder
339
+ y = prompts
340
+
341
+ x_len = x.shape[1]
342
+ x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)
343
+ stop = False
344
+ # print(1111111,self.num_layers)
345
+ cache = {
346
+ "all_stage": self.num_layers,
347
+ "k": [None] * self.num_layers, ###根据配置自己手写
348
+ "v": [None] * self.num_layers,
349
+ # "xy_pos":None,##y_pos位置编码每次都不一样的没法缓存,每次都要重新拼xy_pos.主要还是写法原因,其实是可以历史统一一样的,但也没啥计算量就不管了
350
+ "y_emb": None, ##只需要对最新的samples求emb,再拼历史的就行
351
+ # "logits":None,###原版就已经只对结尾求再拼接了,不用管
352
+ # "xy_dec":None,###不需要,本来只需要最后一个做logits
353
+ "first_infer": 1,
354
+ "stage": 0,
355
+ }
356
+ ################### first step ##########################
357
+ if y is not None:
358
+ y_emb = self.ar_audio_embedding(y)
359
+ y_len = y_emb.shape[1]
360
+ prefix_len = y.shape[1]
361
+ y_pos = self.ar_audio_position(y_emb)
362
+ xy_pos = torch.concat([x, y_pos], dim=1)
363
+ cache["y_emb"] = y_emb
364
+ ref_free = False
365
+ else:
366
+ y_emb = None
367
+ y_len = 0
368
+ prefix_len = 0
369
+ y_pos = None
370
+ xy_pos = x
371
+ y = torch.zeros(x.shape[0], 0, dtype=torch.int, device=x.device)
372
+ ref_free = True
373
+
374
+ x_attn_mask_pad = F.pad(
375
+ x_attn_mask,
376
+ (0, y_len), ###xx的纯0扩展到xx纯0+xy纯1,(x,x+y)
377
+ value=True,
378
+ )
379
+ y_attn_mask = F.pad( ###yy的右上1扩展到左边xy的0,(y,x+y)
380
+ torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
381
+ (x_len, 0),
382
+ value=False,
383
+ )
384
+ xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).to(
385
+ x.device
386
+ )
387
+
388
+ y_list = [None]*y.shape[0]
389
+ batch_idx_map = list(range(y.shape[0]))
390
+ idx_list = [None]*y.shape[0]
391
+ for idx in tqdm(range(1500)):
392
+
393
+ xy_dec, _ = self.h((xy_pos, None), mask=xy_attn_mask, cache=cache)
394
+ logits = self.ar_predict_layer(
395
+ xy_dec[:, -1]
396
+ ) ##不用改,如果用了cache的默认就是只有一帧,取最后一帧一样的
397
+ # samples = topk_sampling(logits, top_k=top_k, top_p=1.0, temperature=temperature)
398
+ if(idx==0):###第一次跑不能EOS否则没有了
399
+ logits = logits[:, :-1] ###刨除1024终止符号的概率
400
+ samples = sample(
401
+ logits, y, top_k=top_k, top_p=top_p, repetition_penalty=1.35, temperature=temperature
402
+ )[0]
403
+ # 本次生成的 semantic_ids 和之前的 y 构成新的 y
404
+ # print(samples.shape)#[1,1]#第一个1是bs
405
+ y = torch.concat([y, samples], dim=1)
406
+
407
+ # 移除已经生成完毕的序列
408
+ reserved_idx_of_batch_for_y = None
409
+ if (self.EOS in torch.argmax(logits, dim=-1)) or \
410
+ (self.EOS in samples[:, 0]): ###如果生成到EOS,则停止
411
+ l = samples[:, 0]==self.EOS
412
+ removed_idx_of_batch_for_y = torch.where(l==True)[0].tolist()
413
+ reserved_idx_of_batch_for_y = torch.where(l==False)[0]
414
+ # batch_indexs = torch.tensor(batch_idx_map, device=y.device)[removed_idx_of_batch_for_y]
415
+ for i in removed_idx_of_batch_for_y:
416
+ batch_index = batch_idx_map[i]
417
+ idx_list[batch_index] = idx - 1
418
+ y_list[batch_index] = y[i, :-1]
419
+
420
+ batch_idx_map = [batch_idx_map[i] for i in reserved_idx_of_batch_for_y.tolist()]
421
+
422
+ # 只保留未生成完毕的序列
423
+ if reserved_idx_of_batch_for_y is not None:
424
+ # index = torch.LongTensor(batch_idx_map).to(y.device)
425
+ y = torch.index_select(y, dim=0, index=reserved_idx_of_batch_for_y)
426
+ if cache["y_emb"] is not None:
427
+ cache["y_emb"] = torch.index_select(cache["y_emb"], dim=0, index=reserved_idx_of_batch_for_y)
428
+ if cache["k"] is not None:
429
+ for i in range(self.num_layers):
430
+ # 因为kv转置了,所以batch dim是1
431
+ cache["k"][i] = torch.index_select(cache["k"][i], dim=1, index=reserved_idx_of_batch_for_y)
432
+ cache["v"][i] = torch.index_select(cache["v"][i], dim=1, index=reserved_idx_of_batch_for_y)
433
+
434
+
435
+ if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
436
+ print("use early stop num:", early_stop_num)
437
+ stop = True
438
+
439
+ if not (None in idx_list):
440
+ # print(torch.argmax(logits, dim=-1)[0] == self.EOS, samples[0, 0] == self.EOS)
441
+ stop = True
442
+ if stop:
443
+ # if prompts.shape[1] == y.shape[1]:
444
+ # y = torch.concat([y, torch.zeros_like(samples)], dim=1)
445
+ # print("bad zero prediction")
446
+ if y.shape[1]==0:
447
+ y = torch.concat([y, torch.zeros_like(samples)], dim=1)
448
+ print("bad zero prediction")
449
+ print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]")
450
+ break
451
+
452
+ ####################### update next step ###################################
453
+ cache["first_infer"] = 0
454
+ if cache["y_emb"] is not None:
455
+ y_emb = torch.cat(
456
+ [cache["y_emb"], self.ar_audio_embedding(y[:, -1:])], dim = 1
457
+ )
458
+ cache["y_emb"] = y_emb
459
+ y_pos = self.ar_audio_position(y_emb)
460
+ xy_pos = y_pos[:, -1:]
461
+ else:
462
+ y_emb = self.ar_audio_embedding(y[:, -1:])
463
+ cache["y_emb"] = y_emb
464
+ y_pos = self.ar_audio_position(y_emb)
465
+ xy_pos = y_pos
466
+ y_len = y_pos.shape[1]
467
+
468
+ ###最右边一列(是错的)
469
+ # xy_attn_mask=torch.ones((1, x_len+y_len), dtype=torch.bool,device=xy_pos.device)
470
+ # xy_attn_mask[:,-1]=False
471
+ ###最下面一行(是对的)
472
+ xy_attn_mask = torch.zeros(
473
+ (1, x_len + y_len), dtype=torch.bool, device=xy_pos.device
474
+ )
475
+
476
+ if (None in idx_list):
477
+ for i in range(x.shape[0]):
478
+ if idx_list[i] is None:
479
+ idx_list[i] = 1500-1 ###如果没有生成到EOS,就用最大长度代替
480
+
481
+ if ref_free:
482
+ return y_list, [0]*x.shape[0]
483
+ return y_list, idx_list
AR/models/utils.py CHANGED
@@ -115,17 +115,17 @@ def logits_to_probs(
115
  top_p: Optional[int] = None,
116
  repetition_penalty: float = 1.0,
117
  ):
118
- if previous_tokens is not None:
119
- previous_tokens = previous_tokens.squeeze()
120
  # print(logits.shape,previous_tokens.shape)
121
  # pdb.set_trace()
122
  if previous_tokens is not None and repetition_penalty != 1.0:
123
  previous_tokens = previous_tokens.long()
124
- score = torch.gather(logits, dim=0, index=previous_tokens)
125
  score = torch.where(
126
  score < 0, score * repetition_penalty, score / repetition_penalty
127
  )
128
- logits.scatter_(dim=0, index=previous_tokens, src=score)
129
 
130
  if top_p is not None and top_p < 1.0:
131
  sorted_logits, sorted_indices = torch.sort(logits, descending=True)
@@ -133,9 +133,9 @@ def logits_to_probs(
133
  torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
134
  )
135
  sorted_indices_to_remove = cum_probs > top_p
136
- sorted_indices_to_remove[0] = False # keep at least one option
137
  indices_to_remove = sorted_indices_to_remove.scatter(
138
- dim=0, index=sorted_indices, src=sorted_indices_to_remove
139
  )
140
  logits = logits.masked_fill(indices_to_remove, -float("Inf"))
141
 
@@ -143,7 +143,7 @@ def logits_to_probs(
143
 
144
  if top_k is not None:
145
  v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
146
- pivot = v.select(-1, -1).unsqueeze(-1)
147
  logits = torch.where(logits < pivot, -float("Inf"), logits)
148
 
149
  probs = torch.nn.functional.softmax(logits, dim=-1)
 
115
  top_p: Optional[int] = None,
116
  repetition_penalty: float = 1.0,
117
  ):
118
+ # if previous_tokens is not None:
119
+ # previous_tokens = previous_tokens.squeeze()
120
  # print(logits.shape,previous_tokens.shape)
121
  # pdb.set_trace()
122
  if previous_tokens is not None and repetition_penalty != 1.0:
123
  previous_tokens = previous_tokens.long()
124
+ score = torch.gather(logits, dim=1, index=previous_tokens)
125
  score = torch.where(
126
  score < 0, score * repetition_penalty, score / repetition_penalty
127
  )
128
+ logits.scatter_(dim=1, index=previous_tokens, src=score)
129
 
130
  if top_p is not None and top_p < 1.0:
131
  sorted_logits, sorted_indices = torch.sort(logits, descending=True)
 
133
  torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
134
  )
135
  sorted_indices_to_remove = cum_probs > top_p
136
+ sorted_indices_to_remove[:, 0] = False # keep at least one option
137
  indices_to_remove = sorted_indices_to_remove.scatter(
138
+ dim=1, index=sorted_indices, src=sorted_indices_to_remove
139
  )
140
  logits = logits.masked_fill(indices_to_remove, -float("Inf"))
141
 
 
143
 
144
  if top_k is not None:
145
  v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
146
+ pivot = v[: , -1].unsqueeze(-1)
147
  logits = torch.where(logits < pivot, -float("Inf"), logits)
148
 
149
  probs = torch.nn.functional.softmax(logits, dim=-1)
AR/modules/__pycache__/__init__.cpython-310.pyc CHANGED
Binary files a/AR/modules/__pycache__/__init__.cpython-310.pyc and b/AR/modules/__pycache__/__init__.cpython-310.pyc differ
 
AR/modules/__pycache__/activation.cpython-310.pyc CHANGED
Binary files a/AR/modules/__pycache__/activation.cpython-310.pyc and b/AR/modules/__pycache__/activation.cpython-310.pyc differ
 
AR/modules/__pycache__/embedding.cpython-310.pyc CHANGED
Binary files a/AR/modules/__pycache__/embedding.cpython-310.pyc and b/AR/modules/__pycache__/embedding.cpython-310.pyc differ
 
AR/modules/__pycache__/lr_schedulers.cpython-310.pyc CHANGED
Binary files a/AR/modules/__pycache__/lr_schedulers.cpython-310.pyc and b/AR/modules/__pycache__/lr_schedulers.cpython-310.pyc differ
 
AR/modules/__pycache__/optim.cpython-310.pyc CHANGED
Binary files a/AR/modules/__pycache__/optim.cpython-310.pyc and b/AR/modules/__pycache__/optim.cpython-310.pyc differ
 
AR/modules/__pycache__/patched_mha_with_cache.cpython-310.pyc CHANGED
Binary files a/AR/modules/__pycache__/patched_mha_with_cache.cpython-310.pyc and b/AR/modules/__pycache__/patched_mha_with_cache.cpython-310.pyc differ
 
AR/modules/__pycache__/scaling.cpython-310.pyc CHANGED
Binary files a/AR/modules/__pycache__/scaling.cpython-310.pyc and b/AR/modules/__pycache__/scaling.cpython-310.pyc differ
 
AR/modules/__pycache__/transformer.cpython-310.pyc CHANGED
Binary files a/AR/modules/__pycache__/transformer.cpython-310.pyc and b/AR/modules/__pycache__/transformer.cpython-310.pyc differ
 
{configs → GPT_SoVITS/configs}/s1.yaml RENAMED
File without changes
{configs → GPT_SoVITS/configs}/s1big.yaml RENAMED
File without changes
{configs → GPT_SoVITS/configs}/s1big2.yaml RENAMED
File without changes
{configs → GPT_SoVITS/configs}/s1longer.yaml RENAMED
File without changes
{configs → GPT_SoVITS/configs}/s1mq.yaml RENAMED
File without changes
{configs → GPT_SoVITS/configs}/s2.json RENAMED
File without changes
{configs → GPT_SoVITS/configs}/train.yaml RENAMED
File without changes
GPT_SoVITS/configs/tts_infer.yaml ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ custom:
2
+ bert_base_path: pretrained_models/chinese-roberta-wwm-ext-large
3
+ cnhuhbert_base_path: pretrained_models/chinese-hubert-base
4
+ device: cpu
5
+ flash_attn_enabled: true
6
+ is_half: false
7
+ t2s_weights_path: /content/TTS_OWN/MODELS/22/22.ckpt
8
+ vits_weights_path: /content/TTS_OWN/MODELS/22/22.pth
9
+ default:
10
+ bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
11
+ cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
12
+ device: cpu
13
+ flash_attn_enabled: true
14
+ is_half: false
15
+ t2s_weights_path: GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt
16
+ vits_weights_path: GPT_SoVITS/pretrained_models/s2G488k.pth
TTS_infer_pack/TTS.py ADDED
@@ -0,0 +1,848 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from copy import deepcopy
2
+ import math
3
+ import os, sys
4
+ import random
5
+ import traceback
6
+ now_dir = os.getcwd()
7
+ sys.path.append(now_dir)
8
+ import ffmpeg
9
+ import os
10
+ from typing import Generator, List, Union
11
+ import numpy as np
12
+ import torch
13
+ import torch.nn.functional as F
14
+ import yaml
15
+ from transformers import AutoModelForMaskedLM, AutoTokenizer
16
+ from timeit import default_timer as timer
17
+
18
+ from AR.models.t2s_lightning_module import Text2SemanticLightningModule
19
+ from feature_extractor.cnhubert import CNHubert
20
+ from module.models import SynthesizerTrn
21
+ import librosa
22
+ from time import time as ttime
23
+ #from tools.i18n.i18n import I18nAuto
24
+ from my_utils import load_audio
25
+ from module.mel_processing import spectrogram_torch
26
+ from TTS_infer_pack.text_segmentation_method import splits
27
+ from TTS_infer_pack.TextPreprocessor import TextPreprocessor
28
+ #i18n = I18nAuto()
29
+ c1=''
30
+
31
+ # configs/tts_infer.yaml
32
+ """
33
+ default:
34
+ device: cpu
35
+ is_half: false
36
+ bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
37
+ cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
38
+ t2s_weights_path: GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt
39
+ vits_weights_path: GPT_SoVITS/pretrained_models/s2G488k.pth
40
+ flash_attn_enabled: true
41
+
42
+ custom:
43
+ device: cuda
44
+ is_half: true
45
+ bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
46
+ cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
47
+ t2s_weights_path: GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt
48
+ vits_weights_path: GPT_SoVITS/pretrained_models/s2G488k.pth
49
+ flash_attn_enabled: true
50
+
51
+
52
+ """
53
+
54
+ # def set_seed(seed):
55
+ # random.seed(seed)
56
+ # os.environ['PYTHONHASHSEED'] = str(seed)
57
+ # np.random.seed(seed)
58
+ # torch.manual_seed(seed)
59
+ # torch.cuda.manual_seed(seed)
60
+ # torch.cuda.manual_seed_all(seed)
61
+ # torch.backends.cudnn.deterministic = True
62
+ # torch.backends.cudnn.benchmark = False
63
+ # torch.backends.cudnn.enabled = True
64
+ # set_seed(1234)
65
+
66
+ class TTS_Config:
67
+ default_configs={
68
+ "device": "cpu",
69
+ "is_half": False,
70
+ "t2s_weights_path": "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt",
71
+ "vits_weights_path": "GPT_SoVITS/pretrained_models/s2G488k.pth",
72
+ "cnhuhbert_base_path": "GPT_SoVITS/pretrained_models/chinese-hubert-base",
73
+ "bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
74
+ "flash_attn_enabled": True
75
+ }
76
+ configs:dict = None
77
+ def __init__(self, configs: Union[dict, str]=None):
78
+
79
+ # 设置默认配置文件路径
80
+ configs_base_path:str = "GPT_SoVITS/configs/"
81
+ os.makedirs(configs_base_path, exist_ok=True)
82
+ self.configs_path:str = os.path.join(configs_base_path, "tts_infer.yaml")
83
+
84
+ if configs in ["", None]:
85
+ if not os.path.exists(self.configs_path):
86
+ self.save_configs()
87
+ print(f"Create default config file at {self.configs_path}")
88
+ configs:dict = {"default": deepcopy(self.default_configs)}
89
+
90
+ if isinstance(configs, str):
91
+ self.configs_path = configs
92
+ configs:dict = self._load_configs(self.configs_path)
93
+
94
+ assert isinstance(configs, dict)
95
+ default_configs:dict = configs.get("default", None)
96
+ if default_configs is not None:
97
+ self.default_configs = default_configs
98
+
99
+ self.configs:dict = configs.get("custom", deepcopy(self.default_configs))
100
+
101
+
102
+ self.device = self.configs.get("device", torch.device("cpu"))
103
+ self.is_half = self.configs.get("is_half", False)
104
+ self.flash_attn_enabled = self.configs.get("flash_attn_enabled", True)
105
+ self.t2s_weights_path = self.configs.get("t2s_weights_path", None)
106
+ self.vits_weights_path = self.configs.get("vits_weights_path", None)
107
+ self.bert_base_path = self.configs.get("bert_base_path", None)
108
+ self.cnhuhbert_base_path = self.configs.get("cnhuhbert_base_path", None)
109
+
110
+
111
+ if (self.t2s_weights_path in [None, ""]) or (not os.path.exists(self.t2s_weights_path)):
112
+ self.t2s_weights_path = self.default_configs['t2s_weights_path']
113
+ print(f"fall back to default t2s_weights_path: {self.t2s_weights_path}")
114
+ if (self.vits_weights_path in [None, ""]) or (not os.path.exists(self.vits_weights_path)):
115
+ self.vits_weights_path = self.default_configs['vits_weights_path']
116
+ print(f"fall back to default vits_weights_path: {self.vits_weights_path}")
117
+ if (self.bert_base_path in [None, ""]) or (not os.path.exists(self.bert_base_path)):
118
+ self.bert_base_path = self.default_configs['bert_base_path']
119
+ print(f"fall back to default bert_base_path: {self.bert_base_path}")
120
+ if (self.cnhuhbert_base_path in [None, ""]) or (not os.path.exists(self.cnhuhbert_base_path)):
121
+ self.cnhuhbert_base_path = self.default_configs['cnhuhbert_base_path']
122
+ print(f"fall back to default cnhuhbert_base_path: {self.cnhuhbert_base_path}")
123
+ self.update_configs()
124
+
125
+
126
+ self.max_sec = None
127
+ self.hz:int = 50
128
+ self.semantic_frame_rate:str = "25hz"
129
+ self.segment_size:int = 20480
130
+ self.filter_length:int = 2048
131
+ self.sampling_rate:int = 32000
132
+ self.hop_length:int = 640
133
+ self.win_length:int = 2048
134
+ self.n_speakers:int = 300
135
+
136
+ self.langauges:list = ["auto", "en", "zh", "ja", "all_zh", "all_ja"]
137
+ # print(self)
138
+
139
+ def _load_configs(self, configs_path: str)->dict:
140
+ with open(configs_path, 'r') as f:
141
+ configs = yaml.load(f, Loader=yaml.FullLoader)
142
+
143
+ return configs
144
+
145
+ def save_configs(self, configs_path:str=None)->None:
146
+ configs={
147
+ "default":self.default_configs,
148
+ }
149
+ if self.configs is not None:
150
+ configs["custom"] = self.update_configs()
151
+
152
+ if configs_path is None:
153
+ configs_path = self.configs_path
154
+ with open(configs_path, 'w') as f:
155
+ yaml.dump(configs, f)
156
+
157
+ def update_configs(self):
158
+ self.config = {
159
+ "device" : str(self.device),
160
+ "is_half" : self.is_half,
161
+ "t2s_weights_path" : self.t2s_weights_path,
162
+ "vits_weights_path" : self.vits_weights_path,
163
+ "bert_base_path" : self.bert_base_path,
164
+ "cnhuhbert_base_path": self.cnhuhbert_base_path,
165
+ "flash_attn_enabled" : self.flash_attn_enabled
166
+ }
167
+ return self.config
168
+
169
+ def __str__(self):
170
+ self.configs = self.update_configs()
171
+ string = "TTS Config".center(100, '-') + '\n'
172
+ for k, v in self.configs.items():
173
+ string += f"{str(k).ljust(20)}: {str(v)}\n"
174
+ string += "-" * 100 + '\n'
175
+ return string
176
+
177
+ def __repr__(self):
178
+ return self.__str__()
179
+
180
+
181
+ class TTS:
182
+ def __init__(self, configs: Union[dict, str, TTS_Config]):
183
+ if isinstance(configs, TTS_Config):
184
+ self.configs = configs
185
+ else:
186
+ self.configs:TTS_Config = TTS_Config(configs)
187
+
188
+ self.t2s_model:Text2SemanticLightningModule = None
189
+ self.vits_model:SynthesizerTrn = None
190
+ self.bert_tokenizer:AutoTokenizer = None
191
+ self.bert_model:AutoModelForMaskedLM = None
192
+ self.cnhuhbert_model:CNHubert = None
193
+
194
+ self._init_models()
195
+
196
+ self.text_preprocessor:TextPreprocessor = \
197
+ TextPreprocessor(self.bert_model,
198
+ self.bert_tokenizer,
199
+ self.configs.device)
200
+
201
+
202
+ self.prompt_cache:dict = {
203
+ "ref_audio_path":None,
204
+ "prompt_semantic":None,
205
+ "refer_spepc":None,
206
+ "prompt_text":None,
207
+ "prompt_lang":None,
208
+ "phones":None,
209
+ "bert_features":None,
210
+ "norm_text":None,
211
+ }
212
+
213
+
214
+ self.stop_flag:bool = False
215
+ self.precison:torch.dtype = torch.float16 if self.configs.is_half else torch.float32
216
+
217
+ def _init_models(self,):
218
+ self.init_t2s_weights(self.configs.t2s_weights_path)
219
+ self.init_vits_weights(self.configs.vits_weights_path)
220
+ self.init_bert_weights(self.configs.bert_base_path)
221
+ self.init_cnhuhbert_weights(self.configs.cnhuhbert_base_path)
222
+ # self.enable_half_precision(self.configs.is_half)
223
+
224
+
225
+
226
+ def init_cnhuhbert_weights(self, base_path: str):
227
+ print(f"Loading CNHuBERT weights from {base_path}")
228
+ self.cnhuhbert_model = CNHubert(base_path)
229
+ self.cnhuhbert_model=self.cnhuhbert_model.eval()
230
+ self.cnhuhbert_model = self.cnhuhbert_model.to(self.configs.device)
231
+ if self.configs.is_half:
232
+ self.cnhuhbert_model = self.cnhuhbert_model.half()
233
+
234
+
235
+
236
+ def init_bert_weights(self, base_path: str):
237
+ print(f"Loading BERT weights from {base_path}")
238
+ self.bert_tokenizer = AutoTokenizer.from_pretrained(base_path)
239
+ self.bert_model = AutoModelForMaskedLM.from_pretrained(base_path)
240
+ self.bert_model=self.bert_model.eval()
241
+ self.bert_model = self.bert_model.to(self.configs.device)
242
+ if self.configs.is_half:
243
+ self.bert_model = self.bert_model.half()
244
+
245
+
246
+
247
+ def init_vits_weights(self, weights_path: str):
248
+
249
+ print(f"Loading VITS weights from {weights_path}")
250
+ self.configs.vits_weights_path = weights_path
251
+ self.configs.save_configs()
252
+ dict_s2 = torch.load(weights_path, map_location=self.configs.device)
253
+ hps = dict_s2["config"]
254
+ self.configs.filter_length = hps["data"]["filter_length"]
255
+ self.configs.segment_size = hps["train"]["segment_size"]
256
+ self.configs.sampling_rate = hps["data"]["sampling_rate"]
257
+ self.configs.hop_length = hps["data"]["hop_length"]
258
+ self.configs.win_length = hps["data"]["win_length"]
259
+ self.configs.n_speakers = hps["data"]["n_speakers"]
260
+ self.configs.semantic_frame_rate = "25hz"
261
+ kwargs = hps["model"]
262
+ vits_model = SynthesizerTrn(
263
+ self.configs.filter_length // 2 + 1,
264
+ self.configs.segment_size // self.configs.hop_length,
265
+ n_speakers=self.configs.n_speakers,
266
+ **kwargs
267
+ )
268
+ # if ("pretrained" not in weights_path):
269
+ if hasattr(vits_model, "enc_q"):
270
+ del vits_model.enc_q
271
+
272
+ vits_model = vits_model.to(self.configs.device)
273
+ vits_model = vits_model.eval()
274
+ vits_model.load_state_dict(dict_s2["weight"], strict=False)
275
+ self.vits_model = vits_model
276
+ if self.configs.is_half:
277
+ self.vits_model = self.vits_model.half()
278
+
279
+
280
+ def init_t2s_weights(self, weights_path: str):
281
+ print(f"Loading Text2Semantic weights from {weights_path}")
282
+ self.configs.t2s_weights_path = weights_path
283
+ self.configs.save_configs()
284
+ self.configs.hz = 50
285
+ dict_s1 = torch.load(weights_path, map_location=self.configs.device)
286
+ config = dict_s1["config"]
287
+ self.configs.max_sec = config["data"]["max_sec"]
288
+ t2s_model = Text2SemanticLightningModule(config, "****", is_train=False,
289
+ flash_attn_enabled=self.configs.flash_attn_enabled)
290
+ t2s_model.load_state_dict(dict_s1["weight"])
291
+ t2s_model = t2s_model.to(self.configs.device)
292
+ t2s_model = t2s_model.eval()
293
+ self.t2s_model = t2s_model
294
+ if self.configs.is_half:
295
+ self.t2s_model = self.t2s_model.half()
296
+
297
+ def enable_half_precision(self, enable: bool = True):
298
+ '''
299
+ To enable half precision for the TTS model.
300
+ Args:
301
+ enable: bool, whether to enable half precision.
302
+
303
+ '''
304
+ if self.configs.device == "cpu" and enable:
305
+ print("Half precision is not supported on CPU.")
306
+ return
307
+
308
+ self.configs.is_half = enable
309
+ self.precison = torch.float16 if enable else torch.float32
310
+ self.configs.save_configs()
311
+ if enable:
312
+ if self.t2s_model is not None:
313
+ self.t2s_model =self.t2s_model.half()
314
+ if self.vits_model is not None:
315
+ self.vits_model = self.vits_model.half()
316
+ if self.bert_model is not None:
317
+ self.bert_model =self.bert_model.half()
318
+ if self.cnhuhbert_model is not None:
319
+ self.cnhuhbert_model = self.cnhuhbert_model.half()
320
+ else:
321
+ if self.t2s_model is not None:
322
+ self.t2s_model = self.t2s_model.float()
323
+ if self.vits_model is not None:
324
+ self.vits_model = self.vits_model.float()
325
+ if self.bert_model is not None:
326
+ self.bert_model = self.bert_model.float()
327
+ if self.cnhuhbert_model is not None:
328
+ self.cnhuhbert_model = self.cnhuhbert_model.float()
329
+
330
+ def set_device(self, device: torch.device):
331
+ '''
332
+ To set the device for all models.
333
+ Args:
334
+ device: torch.device, the device to use for all models.
335
+ '''
336
+ self.configs.device = device
337
+ self.configs.save_configs()
338
+ if self.t2s_model is not None:
339
+ self.t2s_model = self.t2s_model.to(device)
340
+ if self.vits_model is not None:
341
+ self.vits_model = self.vits_model.to(device)
342
+ if self.bert_model is not None:
343
+ self.bert_model = self.bert_model.to(device)
344
+ if self.cnhuhbert_model is not None:
345
+ self.cnhuhbert_model = self.cnhuhbert_model.to(device)
346
+
347
+ def set_ref_audio(self, ref_audio_path:str):
348
+ '''
349
+ To set the reference audio for the TTS model,
350
+ including the prompt_semantic and refer_spepc.
351
+ Args:
352
+ ref_audio_path: str, the path of the reference audio.
353
+ '''
354
+ self._set_prompt_semantic(ref_audio_path)
355
+ self._set_ref_spepc(ref_audio_path)
356
+
357
+ def _set_ref_spepc(self, ref_audio_path):
358
+ audio = load_audio(ref_audio_path, int(self.configs.sampling_rate))
359
+ audio = torch.FloatTensor(audio)
360
+ audio_norm = audio
361
+ audio_norm = audio_norm.unsqueeze(0)
362
+ spec = spectrogram_torch(
363
+ audio_norm,
364
+ self.configs.filter_length,
365
+ self.configs.sampling_rate,
366
+ self.configs.hop_length,
367
+ self.configs.win_length,
368
+ center=False,
369
+ )
370
+ spec = spec.to(self.configs.device)
371
+ if self.configs.is_half:
372
+ spec = spec.half()
373
+ # self.refer_spepc = spec
374
+ self.prompt_cache["refer_spepc"] = spec
375
+
376
+
377
+ def _set_prompt_semantic(self, ref_wav_path:str):
378
+ zero_wav = np.zeros(
379
+ int(self.configs.sampling_rate * 0.3),
380
+ dtype=np.float16 if self.configs.is_half else np.float32,
381
+ )
382
+ with torch.no_grad():
383
+ wav16k, sr = librosa.load(ref_wav_path, sr=16000)
384
+ if (wav16k.shape[0] > 160000 or wav16k.shape[0] < 48000):
385
+ raise OSError("参考音频在3~10秒范围外,请更换!")
386
+ wav16k = torch.from_numpy(wav16k)
387
+ zero_wav_torch = torch.from_numpy(zero_wav)
388
+ wav16k = wav16k.to(self.configs.device)
389
+ zero_wav_torch = zero_wav_torch.to(self.configs.device)
390
+ if self.configs.is_half:
391
+ wav16k = wav16k.half()
392
+ zero_wav_torch = zero_wav_torch.half()
393
+
394
+ wav16k = torch.cat([wav16k, zero_wav_torch])
395
+ hubert_feature = self.cnhuhbert_model.model(wav16k.unsqueeze(0))[
396
+ "last_hidden_state"
397
+ ].transpose(
398
+ 1, 2
399
+ ) # .float()
400
+ codes = self.vits_model.extract_latent(hubert_feature)
401
+
402
+ prompt_semantic = codes[0, 0].to(self.configs.device)
403
+ self.prompt_cache["prompt_semantic"] = prompt_semantic
404
+
405
+ def batch_sequences(self, sequences: List[torch.Tensor], axis: int = 0, pad_value: int = 0, max_length:int=None):
406
+ seq = sequences[0]
407
+ ndim = seq.dim()
408
+ if axis < 0:
409
+ axis += ndim
410
+ dtype:torch.dtype = seq.dtype
411
+ pad_value = torch.tensor(pad_value, dtype=dtype)
412
+ seq_lengths = [seq.shape[axis] for seq in sequences]
413
+ if max_length is None:
414
+ max_length = max(seq_lengths)
415
+ else:
416
+ max_length = max(seq_lengths) if max_length < max(seq_lengths) else max_length
417
+
418
+ padded_sequences = []
419
+ for seq, length in zip(sequences, seq_lengths):
420
+ padding = [0] * axis + [0, max_length - length] + [0] * (ndim - axis - 1)
421
+ padded_seq = torch.nn.functional.pad(seq, padding, value=pad_value)
422
+ padded_sequences.append(padded_seq)
423
+ batch = torch.stack(padded_sequences)
424
+ return batch
425
+
426
+ def to_batch(self, data:list, prompt_data:dict=None, batch_size:int=5, threshold:float=0.75, split_bucket:bool=True):
427
+
428
+ _data:list = []
429
+ index_and_len_list = []
430
+ for idx, item in enumerate(data):
431
+ norm_text_len = len(item["norm_text"])
432
+ index_and_len_list.append([idx, norm_text_len])
433
+
434
+ batch_index_list = []
435
+ if split_bucket:
436
+ index_and_len_list.sort(key=lambda x: x[1])
437
+ index_and_len_list = np.array(index_and_len_list, dtype=np.int64)
438
+
439
+ batch_index_list_len = 0
440
+ pos = 0
441
+ while pos <index_and_len_list.shape[0]:
442
+ # batch_index_list.append(index_and_len_list[pos:min(pos+batch_size,len(index_and_len_list))])
443
+ pos_end = min(pos+batch_size,index_and_len_list.shape[0])
444
+ while pos < pos_end:
445
+ batch=index_and_len_list[pos:pos_end, 1].astype(np.float32)
446
+ score=batch[(pos_end-pos)//2]/(batch.mean()+1e-8)
447
+ if (score>=threshold) or (pos_end-pos==1):
448
+ batch_index=index_and_len_list[pos:pos_end, 0].tolist()
449
+ batch_index_list_len += len(batch_index)
450
+ batch_index_list.append(batch_index)
451
+ pos = pos_end
452
+ break
453
+ pos_end=pos_end-1
454
+
455
+ assert batch_index_list_len == len(data)
456
+
457
+ else:
458
+ for i in range(len(data)):
459
+ if i%batch_size == 0:
460
+ batch_index_list.append([])
461
+ batch_index_list[-1].append(i)
462
+
463
+
464
+ for batch_idx, index_list in enumerate(batch_index_list):
465
+ item_list = [data[idx] for idx in index_list]
466
+ phones_list = []
467
+ phones_len_list = []
468
+ # bert_features_list = []
469
+ all_phones_list = []
470
+ all_phones_len_list = []
471
+ all_bert_features_list = []
472
+ norm_text_batch = []
473
+ bert_max_len = 0
474
+ phones_max_len = 0
475
+ for item in item_list:
476
+ if prompt_data is not None:
477
+ all_bert_features = torch.cat([prompt_data["bert_features"], item["bert_features"]], 1)\
478
+ .to(dtype=self.precison)
479
+ all_phones = torch.LongTensor(prompt_data["phones"]+item["phones"])
480
+ phones = torch.LongTensor(item["phones"])
481
+ # norm_text = prompt_data["norm_text"]+item["norm_text"]
482
+ else:
483
+ all_bert_features = item["bert_features"]\
484
+ .to(dtype=self.precison)
485
+ phones = torch.LongTensor(item["phones"])
486
+ all_phones = phones
487
+ # norm_text = item["norm_text"]
488
+
489
+ bert_max_len = max(bert_max_len, all_bert_features.shape[-1])
490
+ phones_max_len = max(phones_max_len, phones.shape[-1])
491
+
492
+ phones_list.append(phones)
493
+ phones_len_list.append(phones.shape[-1])
494
+ all_phones_list.append(all_phones)
495
+ all_phones_len_list.append(all_phones.shape[-1])
496
+ all_bert_features_list.append(all_bert_features)
497
+ norm_text_batch.append(item["norm_text"])
498
+
499
+ phones_batch = phones_list
500
+ max_len = max(bert_max_len, phones_max_len)
501
+ # phones_batch = self.batch_sequences(phones_list, axis=0, pad_value=0, max_length=max_len)
502
+ all_phones_batch = self.batch_sequences(all_phones_list, axis=0, pad_value=0, max_length=max_len)
503
+ # all_bert_features_batch = all_bert_features_list
504
+ all_bert_features_batch = torch.zeros(len(item_list), 1024, max_len, dtype=self.precison)
505
+ for idx, item in enumerate(all_bert_features_list):
506
+ all_bert_features_batch[idx, :, : item.shape[-1]] = item
507
+
508
+ batch = {
509
+ "phones": phones_batch,
510
+ "phones_len": torch.LongTensor(phones_len_list),
511
+ "all_phones": all_phones_batch,
512
+ "all_phones_len": torch.LongTensor(all_phones_len_list),
513
+ "all_bert_features": all_bert_features_batch,
514
+ "norm_text": norm_text_batch
515
+ }
516
+ _data.append(batch)
517
+
518
+ return _data, batch_index_list
519
+
520
+ def recovery_order(self, data:list, batch_index_list:list)->list:
521
+ '''
522
+ Recovery the order of the audio according to the batch_index_list.
523
+
524
+ Args:
525
+ data (List[list(np.ndarray)]): the out of order audio .
526
+ batch_index_list (List[list[int]]): the batch index list.
527
+
528
+ Returns:
529
+ list (List[np.ndarray]): the data in the original order.
530
+ '''
531
+ lenght = len(sum(batch_index_list, []))
532
+ _data = [None]*lenght
533
+ for i, index_list in enumerate(batch_index_list):
534
+ for j, index in enumerate(index_list):
535
+ _data[index] = data[i][j]
536
+ return _data
537
+
538
+ def stop(self,):
539
+ '''
540
+ Stop the inference process.
541
+ '''
542
+ self.stop_flag = True
543
+
544
+
545
+ def run(self, inputs:dict):
546
+ """
547
+ Text to speech inference.
548
+
549
+ Args:
550
+ inputs (dict):
551
+ {
552
+ "text": "", # str. text to be synthesized
553
+ "text_lang: "", # str. language of the text to be synthesized
554
+ "ref_audio_path": "", # str. reference audio path
555
+ "prompt_text": "", # str. prompt text for the reference audio
556
+ "prompt_lang": "", # str. language of the prompt text for the reference audio
557
+ "top_k": 5, # int. top k sampling
558
+ "top_p": 1, # float. top p sampling
559
+ "temperature": 1, # float. temperature for sampling
560
+ "text_split_method": "", # str. text split method, see text_segmentaion_method.py for details.
561
+ "batch_size": 1, # int. batch size for inference
562
+ "batch_threshold": 0.75, # float. threshold for batch splitting.
563
+ "split_bucket: True, # bool. whether to split the batch into multiple buckets.
564
+ "return_fragment": False, # bool. step by step return the audio fragment.
565
+ "speed_factor":1.0, # float. control the speed of the synthesized audio.
566
+ }
567
+ returns:
568
+ tulpe[int, np.ndarray]: sampling rate and audio data.
569
+ """
570
+ global c1
571
+ c1=timer()
572
+ ########## variables initialization ###########
573
+ self.stop_flag:bool = False
574
+ text:str = inputs.get("text", "")
575
+ text_lang:str = inputs.get("text_lang", "")
576
+ ref_audio_path:str = inputs.get("ref_audio_path", "")
577
+ prompt_text:str = inputs.get("prompt_text", "")
578
+ prompt_lang:str = inputs.get("prompt_lang", "")
579
+ top_k:int = inputs.get("top_k", 5)
580
+ top_p:float = inputs.get("top_p", 1)
581
+ temperature:float = inputs.get("temperature", 1)
582
+ text_split_method:str = inputs.get("text_split_method", "")
583
+ batch_size = inputs.get("batch_size", 1)
584
+ batch_threshold = inputs.get("batch_threshold", 0.75)
585
+ speed_factor = inputs.get("speed_factor", 1.0)
586
+ split_bucket = inputs.get("split_bucket", True)
587
+ volume = inputs.get("volume", 1.0)
588
+ return_fragment = inputs.get("return_fragment", False)
589
+
590
+ if return_fragment:
591
+ split_bucket = False
592
+ print("分段返回模式已开启")
593
+ if split_bucket:
594
+ split_bucket = False
595
+ print("分段返回模式不支持分桶处理,已自动关闭分桶处理")
596
+
597
+ if split_bucket:
598
+ print("分桶处理模式已开启")
599
+
600
+
601
+ no_prompt_text = False
602
+ if prompt_text in [None, ""]:
603
+ no_prompt_text = True
604
+
605
+ assert text_lang in self.configs.langauges
606
+ if not no_prompt_text:
607
+ assert prompt_lang in self.configs.langauges
608
+
609
+ if ref_audio_path in [None, ""] and \
610
+ ((self.prompt_cache["prompt_semantic"] is None) or (self.prompt_cache["refer_spepc"] is None)):
611
+ raise ValueError("ref_audio_path cannot be empty, when the reference audio is not set using set_ref_audio()")
612
+
613
+
614
+ ###### setting reference audio and prompt text preprocessing ########
615
+ t0 = ttime()
616
+ if (ref_audio_path is not None) and (ref_audio_path != self.prompt_cache["ref_audio_path"]):
617
+ self.set_ref_audio(ref_audio_path)
618
+
619
+ if not no_prompt_text:
620
+ prompt_text = prompt_text.strip("\n")
621
+ if (prompt_text[-1] not in splits): prompt_text += "。" if prompt_lang != "en" else "."
622
+ print("实际输入的参考文本:", prompt_text)
623
+ if self.prompt_cache["prompt_text"] != prompt_text:
624
+ self.prompt_cache["prompt_text"] = prompt_text
625
+ self.prompt_cache["prompt_lang"] = prompt_lang
626
+ phones, bert_features, norm_text = \
627
+ self.text_preprocessor.segment_and_extract_feature_for_text(
628
+ prompt_text,
629
+ prompt_lang)
630
+ self.prompt_cache["phones"] = phones
631
+ self.prompt_cache["bert_features"] = bert_features
632
+ self.prompt_cache["norm_text"] = norm_text
633
+
634
+
635
+ ###### text preprocessing ########
636
+ data = self.text_preprocessor.preprocess(text, text_lang, text_split_method)
637
+ if len(data) == 0:
638
+ yield self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate * 0.3),
639
+ dtype=np.int16)
640
+ return
641
+
642
+ t1 = ttime()
643
+ data, batch_index_list = self.to_batch(data,
644
+ prompt_data=self.prompt_cache if not no_prompt_text else None,
645
+ batch_size=batch_size,
646
+ threshold=batch_threshold,
647
+ split_bucket=split_bucket
648
+ )
649
+ t2 = ttime()
650
+ try:
651
+ print("############ 推理 ############")
652
+ ###### inference ######
653
+ t_34 = 0.0
654
+ t_45 = 0.0
655
+ audio = []
656
+ for item in data:
657
+ t3 = ttime()
658
+ batch_phones = item["phones"]
659
+ batch_phones_len = item["phones_len"]
660
+ all_phoneme_ids = item["all_phones"]
661
+ all_phoneme_lens = item["all_phones_len"]
662
+ all_bert_features = item["all_bert_features"]
663
+ norm_text = item["norm_text"]
664
+
665
+ # batch_phones = batch_phones.to(self.configs.device)
666
+ batch_phones_len = batch_phones_len.to(self.configs.device)
667
+ all_phoneme_ids = all_phoneme_ids.to(self.configs.device)
668
+ all_phoneme_lens = all_phoneme_lens.to(self.configs.device)
669
+ all_bert_features = all_bert_features.to(self.configs.device)
670
+ if self.configs.is_half:
671
+ all_bert_features = all_bert_features.half()
672
+
673
+ print("前端处理后的文本(每句):", norm_text)
674
+ if no_prompt_text :
675
+ prompt = None
676
+ else:
677
+ prompt = self.prompt_cache["prompt_semantic"].expand(all_phoneme_ids.shape[0], -1).to(self.configs.device)
678
+
679
+ with torch.no_grad():
680
+ pred_semantic_list, idx_list = self.t2s_model.model.infer_panel(
681
+ all_phoneme_ids,
682
+ all_phoneme_lens,
683
+ prompt,
684
+ all_bert_features,
685
+ # prompt_phone_len=ph_offset,
686
+ top_k=top_k,
687
+ top_p=top_p,
688
+ temperature=temperature,
689
+ early_stop_num=self.configs.hz * self.configs.max_sec,
690
+ )
691
+ t4 = ttime()
692
+ t_34 += t4 - t3
693
+
694
+ refer_audio_spepc:torch.Tensor = self.prompt_cache["refer_spepc"]\
695
+ .to(dtype=self.precison, device=self.configs.device)
696
+
697
+ batch_audio_fragment = []
698
+
699
+ # ## vits并行推理 method 1
700
+ # pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)]
701
+ # pred_semantic_len = torch.LongTensor([item.shape[0] for item in pred_semantic_list]).to(self.configs.device)
702
+ # pred_semantic = self.batch_sequences(pred_semantic_list, axis=0, pad_value=0).unsqueeze(0)
703
+ # max_len = 0
704
+ # for i in range(0, len(batch_phones)):
705
+ # max_len = max(max_len, batch_phones[i].shape[-1])
706
+ # batch_phones = self.batch_sequences(batch_phones, axis=0, pad_value=0, max_length=max_len)
707
+ # batch_phones = batch_phones.to(self.configs.device)
708
+ # batch_audio_fragment = (self.vits_model.batched_decode(
709
+ # pred_semantic, pred_semantic_len, batch_phones, batch_phones_len,refer_audio_spepc
710
+ # ))
711
+
712
+ # ## vits并行推理 method 2
713
+ pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)]
714
+ upsample_rate = math.prod(self.vits_model.upsample_rates)
715
+ audio_frag_idx = [pred_semantic_list[i].shape[0]*2*upsample_rate for i in range(0, len(pred_semantic_list))]
716
+ audio_frag_end_idx = [ sum(audio_frag_idx[:i+1]) for i in range(0, len(audio_frag_idx))]
717
+ all_pred_semantic = torch.cat(pred_semantic_list).unsqueeze(0).unsqueeze(0).to(self.configs.device)
718
+ _batch_phones = torch.cat(batch_phones).unsqueeze(0).to(self.configs.device)
719
+ _batch_audio_fragment = (self.vits_model.decode(
720
+ all_pred_semantic, _batch_phones,refer_audio_spepc
721
+ ).detach()[0, 0, :])
722
+ audio_frag_end_idx.insert(0, 0)
723
+ batch_audio_fragment= [_batch_audio_fragment[audio_frag_end_idx[i-1]:audio_frag_end_idx[i]] for i in range(1, len(audio_frag_end_idx))]
724
+
725
+
726
+ # ## vits串行推理
727
+ # for i, idx in enumerate(idx_list):
728
+ # phones = batch_phones[i].unsqueeze(0).to(self.configs.device)
729
+ # _pred_semantic = (pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0)) # .unsqueeze(0)#mq要多unsqueeze一次
730
+ # audio_fragment =(self.vits_model.decode(
731
+ # _pred_semantic, phones, refer_audio_spepc
732
+ # ).detach()[0, 0, :])
733
+ # batch_audio_fragment.append(
734
+ # audio_fragment
735
+ # ) ###试试重建不带上prompt部分
736
+
737
+ t5 = ttime()
738
+ t_45 += t5 - t4
739
+ if return_fragment:
740
+ print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t4 - t3, t5 - t4))
741
+ yield self.audio_postprocess([batch_audio_fragment],
742
+ self.configs.sampling_rate,
743
+ batch_index_list,
744
+ speed_factor,
745
+ split_bucket,volume)
746
+ else:
747
+ audio.append(batch_audio_fragment)
748
+
749
+ if self.stop_flag:
750
+ yield self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate * 0.3),
751
+ dtype=np.int16)
752
+ return
753
+
754
+ if not return_fragment:
755
+ print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t_34, t_45))
756
+ yield self.audio_postprocess(audio,
757
+ self.configs.sampling_rate,
758
+ batch_index_list,
759
+ speed_factor,
760
+ split_bucket,volume)
761
+ except Exception as e:
762
+ traceback.print_exc()
763
+ # 必须返回一个空音频, 否则会导致显存不释放。
764
+ yield self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate),
765
+ dtype=np.int16)
766
+ # 重置模型, 否则会导致显存释放不完全。
767
+ del self.t2s_model
768
+ del self.vits_model
769
+ self.t2s_model = None
770
+ self.vits_model = None
771
+ self.init_t2s_weights(self.configs.t2s_weights_path)
772
+ self.init_vits_weights(self.configs.vits_weights_path)
773
+ finally:
774
+ self.empty_cache()
775
+
776
+ def empty_cache(self):
777
+ try:
778
+ if str(self.configs.device) == "cuda":
779
+ torch.cuda.empty_cache()
780
+ elif str(self.configs.device) == "mps":
781
+ torch.mps.empty_cache()
782
+ except:
783
+ pass
784
+
785
+ def audio_postprocess(self,
786
+ audio:List[torch.Tensor],
787
+ sr:int,
788
+ batch_index_list:list=None,
789
+ speed_factor:float=1.0,
790
+ split_bucket:bool=True,
791
+ volume: float = 1.0)->tuple[int, np.ndarray]:
792
+ zero_wav = torch.zeros(
793
+ int(self.configs.sampling_rate * 0.3),
794
+ dtype=self.precison,
795
+ device=self.configs.device
796
+ )
797
+
798
+ for i, batch in enumerate(audio):
799
+ for j, audio_fragment in enumerate(batch):
800
+ max_audio=torch.abs(audio_fragment).max()#简单防止16bit爆音
801
+ if max_audio>1: audio_fragment/=max_audio
802
+ audio_fragment:torch.Tensor = torch.cat([audio_fragment, zero_wav], dim=0)
803
+ audio_fragment = audio_fragment * volume
804
+ audio[i][j] = audio_fragment.cpu().numpy()
805
+
806
+
807
+ if split_bucket:
808
+ audio = self.recovery_order(audio, batch_index_list)
809
+ else:
810
+ # audio = [item for batch in audio for item in batch]
811
+ audio = sum(audio, [])
812
+
813
+
814
+ audio = np.concatenate(audio, 0)
815
+ audio = (audio * 32768).astype(np.int16)
816
+
817
+ try:
818
+ if speed_factor != 1.0:
819
+ audio = speed_change(audio, speed=speed_factor, sr=int(sr))
820
+ except Exception as e:
821
+ print(f"Failed to change speed of audio: \n{e}")
822
+ c2=timer()
823
+ print(f'🆗TTS COMPLETE,{round(c2-c1,4)}s')
824
+ return sr, audio
825
+
826
+
827
+
828
+
829
+ def speed_change(input_audio:np.ndarray, speed:float, sr:int):
830
+ # 将 NumPy 数组转换为原始 PCM 流
831
+ raw_audio = input_audio.astype(np.int16).tobytes()
832
+
833
+ # 设置 ffmpeg 输入流
834
+ input_stream = ffmpeg.input('pipe:', format='s16le', acodec='pcm_s16le', ar=str(sr), ac=1)
835
+
836
+ # 变速处理
837
+ output_stream = input_stream.filter('atempo', speed)
838
+
839
+ # 输出流到管道
840
+ out, _ = (
841
+ output_stream.output('pipe:', format='s16le', acodec='pcm_s16le')
842
+ .run(input=raw_audio, capture_stdout=True, capture_stderr=True)
843
+ )
844
+
845
+ # 将管道输出解码为 NumPy 数组
846
+ processed_audio = np.frombuffer(out, np.int16)
847
+
848
+ return processed_audio
TTS_infer_pack/TextPreprocessor.py ADDED
@@ -0,0 +1,209 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import os, sys
3
+
4
+ from tqdm import tqdm
5
+ now_dir = os.getcwd()
6
+ sys.path.append(now_dir)
7
+
8
+ import re
9
+ import torch
10
+ import LangSegment
11
+ from typing import Dict, List, Tuple
12
+ from text.cleaner import clean_text
13
+ from text import cleaned_text_to_sequence
14
+ from transformers import AutoModelForMaskedLM, AutoTokenizer
15
+ from TTS_infer_pack.text_segmentation_method import split_big_text, splits, get_method as get_seg_method
16
+
17
+ #from tools.i18n.i18n import I18nAuto
18
+ #i18n = I18nAuto()
19
+
20
+ def get_first(text:str) -> str:
21
+ pattern = "[" + "".join(re.escape(sep) for sep in splits) + "]"
22
+ text = re.split(pattern, text)[0].strip()
23
+ return text
24
+
25
+ def merge_short_text_in_array(texts:str, threshold:int) -> list:
26
+ if (len(texts)) < 2:
27
+ return texts
28
+ result = []
29
+ text = ""
30
+ for ele in texts:
31
+ text += ele
32
+ if len(text) >= threshold:
33
+ result.append(text)
34
+ text = ""
35
+ if (len(text) > 0):
36
+ if len(result) == 0:
37
+ result.append(text)
38
+ else:
39
+ result[len(result) - 1] += text
40
+ return result
41
+
42
+
43
+
44
+
45
+
46
+
47
+ class TextPreprocessor:
48
+ def __init__(self, bert_model:AutoModelForMaskedLM,
49
+ tokenizer:AutoTokenizer, device:torch.device):
50
+ self.bert_model = bert_model
51
+ self.tokenizer = tokenizer
52
+ self.device = device
53
+
54
+ def preprocess(self, text:str, lang:str, text_split_method:str)->List[Dict]:
55
+ print("############ 切分文本 ############")
56
+ texts = self.pre_seg_text(text, lang, text_split_method)
57
+ result = []
58
+ print("############ 提取文本Bert特征 ############")
59
+ for text in tqdm(texts):
60
+ phones, bert_features, norm_text = self.segment_and_extract_feature_for_text(text, lang)
61
+ if phones is None:
62
+ continue
63
+ res={
64
+ "phones": phones,
65
+ "bert_features": bert_features,
66
+ "norm_text": norm_text,
67
+ }
68
+ result.append(res)
69
+ return result
70
+
71
+ def pre_seg_text(self, text:str, lang:str, text_split_method:str):
72
+ text = text.strip("\n")
73
+ if (text[0] not in splits and len(get_first(text)) < 4):
74
+ text = "。" + text if lang != "en" else "." + text
75
+ print("实际输入的目标文本:")
76
+ print(text)
77
+
78
+ seg_method = get_seg_method(text_split_method)
79
+ text = seg_method(text)
80
+
81
+ while "\n\n" in text:
82
+ text = text.replace("\n\n", "\n")
83
+
84
+ _texts = text.split("\n")
85
+ _texts = merge_short_text_in_array(_texts, 5)
86
+ texts = []
87
+
88
+
89
+ for text in _texts:
90
+ # 解决输入目标文本的空行导致报错的问题
91
+ if (len(text.strip()) == 0):
92
+ continue
93
+ if (text[-1] not in splits): text += "。" if lang != "en" else "."
94
+
95
+ # 解决句子过长导致Bert报错的问题
96
+ if (len(text) > 510):
97
+ texts.extend(split_big_text(text))
98
+ else:
99
+ texts.append(text)
100
+
101
+ print("实际输入的目标文本(切句后):")
102
+ print(texts)
103
+ return texts
104
+
105
+ def segment_and_extract_feature_for_text(self, texts:list, language:str)->Tuple[list, torch.Tensor, str]:
106
+ textlist, langlist = self.seg_text(texts, language)
107
+ if len(textlist) == 0:
108
+ return None, None, None
109
+
110
+ phones, bert_features, norm_text = self.extract_bert_feature(textlist, langlist)
111
+ return phones, bert_features, norm_text
112
+
113
+
114
+ def seg_text(self, text:str, language:str)->Tuple[list, list]:
115
+
116
+ textlist=[]
117
+ langlist=[]
118
+ if language in ["auto", "zh", "ja"]:
119
+ LangSegment.setfilters(["zh","ja","en","ko"])
120
+ for tmp in LangSegment.getTexts(text):
121
+ if tmp["text"] == "":
122
+ continue
123
+ if tmp["lang"] == "ko":
124
+ langlist.append("zh")
125
+ elif tmp["lang"] == "en":
126
+ langlist.append("en")
127
+ else:
128
+ # 因无法区别中日文汉字,以用户输入为准
129
+ langlist.append(language if language!="auto" else tmp["lang"])
130
+ textlist.append(tmp["text"])
131
+ elif language == "en":
132
+ LangSegment.setfilters(["en"])
133
+ formattext = " ".join(tmp["text"] for tmp in LangSegment.getTexts(text))
134
+ while " " in formattext:
135
+ formattext = formattext.replace(" ", " ")
136
+ if formattext != "":
137
+ textlist.append(formattext)
138
+ langlist.append("en")
139
+
140
+ elif language in ["all_zh","all_ja"]:
141
+
142
+ formattext = text
143
+ while " " in formattext:
144
+ formattext = formattext.replace(" ", " ")
145
+ language = language.replace("all_","")
146
+ if text == "":
147
+ return [],[]
148
+ textlist.append(formattext)
149
+ langlist.append(language)
150
+
151
+ else:
152
+ raise ValueError(f"language {language} not supported")
153
+
154
+ return textlist, langlist
155
+
156
+
157
+ def extract_bert_feature(self, textlist:list, langlist:list):
158
+ phones_list = []
159
+ bert_feature_list = []
160
+ norm_text_list = []
161
+ for i in range(len(textlist)):
162
+ lang = langlist[i]
163
+ phones, word2ph, norm_text = self.clean_text_inf(textlist[i], lang)
164
+ _bert_feature = self.get_bert_inf(phones, word2ph, norm_text, lang)
165
+ # phones_list.append(phones)
166
+ phones_list.extend(phones)
167
+ norm_text_list.append(norm_text)
168
+ bert_feature_list.append(_bert_feature)
169
+ bert_feature = torch.cat(bert_feature_list, dim=1)
170
+ # phones = sum(phones_list, [])
171
+ norm_text = ''.join(norm_text_list)
172
+ return phones_list, bert_feature, norm_text
173
+
174
+
175
+ def get_bert_feature(self, text:str, word2ph:list)->torch.Tensor:
176
+ with torch.no_grad():
177
+ inputs = self.tokenizer(text, return_tensors="pt")
178
+ for i in inputs:
179
+ inputs[i] = inputs[i].to(self.device)
180
+ res = self.bert_model(**inputs, output_hidden_states=True)
181
+ res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
182
+ assert len(word2ph) == len(text)
183
+ phone_level_feature = []
184
+ for i in range(len(word2ph)):
185
+ repeat_feature = res[i].repeat(word2ph[i], 1)
186
+ phone_level_feature.append(repeat_feature)
187
+ phone_level_feature = torch.cat(phone_level_feature, dim=0)
188
+ return phone_level_feature.T
189
+
190
+ def clean_text_inf(self, text:str, language:str):
191
+ phones, word2ph, norm_text = clean_text(text, language)
192
+ phones = cleaned_text_to_sequence(phones)
193
+ return phones, word2ph, norm_text
194
+
195
+ def get_bert_inf(self, phones:list, word2ph:list, norm_text:str, language:str):
196
+ language=language.replace("all_","")
197
+ if language == "zh":
198
+ feature = self.get_bert_feature(norm_text, word2ph).to(self.device)
199
+ else:
200
+ feature = torch.zeros(
201
+ (1024, len(phones)),
202
+ dtype=torch.float32,
203
+ ).to(self.device)
204
+
205
+ return feature
206
+
207
+
208
+
209
+
TTS_infer_pack/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from . import TTS, text_segmentation_method
TTS_infer_pack/__pycache__/TTS.cpython-310.pyc ADDED
Binary file (21.7 kB). View file
 
TTS_infer_pack/__pycache__/TextPreprocessor.cpython-310.pyc ADDED
Binary file (6.15 kB). View file
 
TTS_infer_pack/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (199 Bytes). View file
 
TTS_infer_pack/__pycache__/text_segmentation_method.cpython-310.pyc ADDED
Binary file (3.67 kB). View file
 
TTS_infer_pack/text_segmentation_method.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+
4
+
5
+ import re
6
+ from typing import Callable
7
+ #from tools.i18n.i18n import I18nAuto
8
+
9
+ #i18n = I18nAuto()
10
+
11
+ METHODS = dict()
12
+
13
+ def get_method(name:str)->Callable:
14
+ method = METHODS.get(name, None)
15
+ if method is None:
16
+ raise ValueError(f"Method {name} not found")
17
+ return method
18
+
19
+ def register_method(name):
20
+ def decorator(func):
21
+ METHODS[name] = func
22
+ return func
23
+ return decorator
24
+
25
+ splits = {",", "。", "?", "!", ",", ".", "?", "!", "~", ":", ":", "—", "…", }
26
+
27
+ def split_big_text(text, max_len=510):
28
+ # 定义全角和半角标点符号
29
+ punctuation = "".join(splits)
30
+
31
+ # 切割文本
32
+ segments = re.split('([' + punctuation + '])', text)
33
+
34
+ # 初始化结果列表和当前片段
35
+ result = []
36
+ current_segment = ''
37
+
38
+ for segment in segments:
39
+ # 如果当前片段加上新的片段长度超过max_len,就将当前片段加入结果列表,并重置当前片段
40
+ if len(current_segment + segment) > max_len:
41
+ result.append(current_segment)
42
+ current_segment = segment
43
+ else:
44
+ current_segment += segment
45
+
46
+ # 将最后一个片段加入结果列表
47
+ if current_segment:
48
+ result.append(current_segment)
49
+
50
+ return result
51
+
52
+
53
+
54
+ def split(todo_text):
55
+ todo_text = todo_text.replace("……", "。").replace("——", ",")
56
+ if todo_text[-1] not in splits:
57
+ todo_text += "。"
58
+ i_split_head = i_split_tail = 0
59
+ len_text = len(todo_text)
60
+ todo_texts = []
61
+ while 1:
62
+ if i_split_head >= len_text:
63
+ break # 结尾一定有标点,所以直接跳出即可,最后一段在上次已加入
64
+ if todo_text[i_split_head] in splits:
65
+ i_split_head += 1
66
+ todo_texts.append(todo_text[i_split_tail:i_split_head])
67
+ i_split_tail = i_split_head
68
+ else:
69
+ i_split_head += 1
70
+ return todo_texts
71
+
72
+
73
+ # 不切
74
+ @register_method("cut0")
75
+ def cut0(inp):
76
+ return inp
77
+
78
+
79
+ # 凑四句一切
80
+ @register_method("cut1")
81
+ def cut1(inp):
82
+ inp = inp.strip("\n")
83
+ inps = split(inp)
84
+ split_idx = list(range(0, len(inps), 4))
85
+ split_idx[-1] = None
86
+ if len(split_idx) > 1:
87
+ opts = []
88
+ for idx in range(len(split_idx) - 1):
89
+ opts.append("".join(inps[split_idx[idx]: split_idx[idx + 1]]))
90
+ else:
91
+ opts = [inp]
92
+ return "\n".join(opts)
93
+
94
+ # 凑50字一切
95
+ @register_method("cut2")
96
+ def cut2(inp):
97
+ inp = inp.strip("\n")
98
+ inps = split(inp)
99
+ if len(inps) < 2:
100
+ return inp
101
+ opts = []
102
+ summ = 0
103
+ tmp_str = ""
104
+ for i in range(len(inps)):
105
+ summ += len(inps[i])
106
+ tmp_str += inps[i]
107
+ if summ > 50:
108
+ summ = 0
109
+ opts.append(tmp_str)
110
+ tmp_str = ""
111
+ if tmp_str != "":
112
+ opts.append(tmp_str)
113
+ # print(opts)
114
+ if len(opts) > 1 and len(opts[-1]) < 50: ##如果最后一个太短了,和前一个合一起
115
+ opts[-2] = opts[-2] + opts[-1]
116
+ opts = opts[:-1]
117
+ return "\n".join(opts)
118
+
119
+ # 按中文句号。切
120
+ @register_method("cut3")
121
+ def cut3(inp):
122
+ inp = inp.strip("\n")
123
+ return "\n".join(["%s" % item for item in inp.strip("。").split("。")])
124
+
125
+ #按英文句号.切
126
+ @register_method("cut4")
127
+ def cut4(inp):
128
+ inp = inp.strip("\n")
129
+ return "\n".join(["%s" % item for item in inp.strip(".").split(".")])
130
+
131
+ # 按标点符号切
132
+ # contributed by https://github.com/AI-Hobbyist/GPT-SoVITS/blob/main/GPT_SoVITS/inference_webui.py
133
+ @register_method("cut5")
134
+ def cut5(inp):
135
+ # if not re.search(r'[^\w\s]', inp[-1]):
136
+ # inp += '。'
137
+ inp = inp.strip("\n")
138
+ punds = r'[,.;?!、,。?!;:…]'
139
+ items = re.split(f'({punds})', inp)
140
+ mergeitems = ["".join(group) for group in zip(items[::2], items[1::2])]
141
+ # 在句子不存在符号或句尾无符号的时候保证文本完整
142
+ if len(items)%2 == 1:
143
+ mergeitems.append(items[-1])
144
+ opt = "\n".join(mergeitems)
145
+ return opt
146
+
147
+
148
+
149
+ if __name__ == '__main__':
150
+ method = get_method("cut5")
151
+ print(method("你好,我是小明。你好,我是小红。你好,我是小刚。你好,我是小张。"))
152
+
app.py CHANGED
@@ -29,6 +29,8 @@ logging.getLogger("torchaudio._extension").setLevel(logging.ERROR)
29
  logging.getLogger("multipart").setLevel(logging.WARNING)
30
  from download import *
31
  download()
 
 
32
 
33
  if "_CUDA_VISIBLE_DEVICES" in os.environ:
34
  os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
@@ -64,533 +66,87 @@ is_half = eval(
64
  os.environ.get("is_half", "True" if torch.cuda.is_available() else "False")
65
  )
66
 
67
- tokenizer = AutoTokenizer.from_pretrained(bert_path)
68
- bert_model = AutoModelForMaskedLM.from_pretrained(bert_path)
69
- if is_half == True:
70
- bert_model = bert_model.half().to(device)
71
- else:
72
- bert_model = bert_model.to(device)
73
-
74
-
75
- def get_bert_feature(text, word2ph):
76
- with torch.no_grad():
77
- inputs = tokenizer(text, return_tensors="pt")
78
- for i in inputs:
79
- inputs[i] = inputs[i].to(device)
80
- res = bert_model(**inputs, output_hidden_states=True)
81
- res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
82
- assert len(word2ph) == len(text)
83
- phone_level_feature = []
84
- for i in range(len(word2ph)):
85
- repeat_feature = res[i].repeat(word2ph[i], 1)
86
- phone_level_feature.append(repeat_feature)
87
- phone_level_feature = torch.cat(phone_level_feature, dim=0)
88
- return phone_level_feature.T
89
-
90
-
91
- class DictToAttrRecursive(dict):
92
- def __init__(self, input_dict):
93
- super().__init__(input_dict)
94
- for key, value in input_dict.items():
95
- if isinstance(value, dict):
96
- value = DictToAttrRecursive(value)
97
- self[key] = value
98
- setattr(self, key, value)
99
-
100
- def __getattr__(self, item):
101
- try:
102
- return self[item]
103
- except KeyError:
104
- raise AttributeError(f"Attribute {item} not found")
105
-
106
- def __setattr__(self, key, value):
107
- if isinstance(value, dict):
108
- value = DictToAttrRecursive(value)
109
- super(DictToAttrRecursive, self).__setitem__(key, value)
110
- super().__setattr__(key, value)
111
-
112
- def __delattr__(self, item):
113
- try:
114
- del self[item]
115
- except KeyError:
116
- raise AttributeError(f"Attribute {item} not found")
117
-
118
-
119
- ssl_model = cnhubert.get_model()
120
- if is_half == True:
121
- ssl_model = ssl_model.half().to(device)
122
- else:
123
- ssl_model = ssl_model.to(device)
124
-
125
-
126
- def change_sovits_weights(sovits_path):
127
- global vq_model, hps
128
- dict_s2 = torch.load(sovits_path, map_location="cpu")
129
- hps = dict_s2["config"]
130
- hps = DictToAttrRecursive(hps)
131
- hps.model.semantic_frame_rate = "25hz"
132
- vq_model = SynthesizerTrn(
133
- hps.data.filter_length // 2 + 1,
134
- hps.train.segment_size // hps.data.hop_length,
135
- n_speakers=hps.data.n_speakers,
136
- **hps.model
137
- )
138
- if ("pretrained" not in sovits_path):
139
- del vq_model.enc_q
140
- if is_half == True:
141
- vq_model = vq_model.half().to(device)
142
- else:
143
- vq_model = vq_model.to(device)
144
- vq_model.eval()
145
- print(vq_model.load_state_dict(dict_s2["weight"], strict=False))
146
- with open("./sweight.txt", "w", encoding="utf-8") as f:
147
- f.write(sovits_path)
148
-
149
-
150
- change_sovits_weights(sovits_path)
151
-
152
-
153
- def change_gpt_weights(gpt_path):
154
- global hz, max_sec, t2s_model, config
155
- hz = 50
156
- dict_s1 = torch.load(gpt_path, map_location="cpu")
157
- config = dict_s1["config"]
158
- max_sec = config["data"]["max_sec"]
159
- t2s_model = Text2SemanticLightningModule(config, "****", is_train=False)
160
- t2s_model.load_state_dict(dict_s1["weight"])
161
- if is_half == True:
162
- t2s_model = t2s_model.half()
163
- t2s_model = t2s_model.to(device)
164
- t2s_model.eval()
165
- total = sum([param.nelement() for param in t2s_model.parameters()])
166
- print("Number of parameter: %.2fM" % (total / 1e6))
167
- with open("./gweight.txt", "w", encoding="utf-8") as f: f.write(gpt_path)
168
-
169
-
170
- change_gpt_weights(gpt_path)
171
-
172
-
173
- def get_spepc(hps, filename):
174
- audio = load_audio(filename, int(hps.data.sampling_rate))
175
- audio = torch.FloatTensor(audio)
176
- audio_norm = audio
177
- audio_norm = audio_norm.unsqueeze(0)
178
- spec = spectrogram_torch(
179
- audio_norm,
180
- hps.data.filter_length,
181
- hps.data.sampling_rate,
182
- hps.data.hop_length,
183
- hps.data.win_length,
184
- center=False,
185
- )
186
- return spec
187
-
188
 
189
  dict_language = {
190
- ("中文1"): "all_zh",#全部按中文识别
191
- ("English"): "en",#全部按英文识别#######不变
192
- ("日文1"): "all_ja",#全部按日文识别
193
- ("中文"): "zh",#按中英混合识别####不变
194
- ("日本語"): "ja",#按日英混合识别####不变
195
- ("混合"): "auto",#多语种启动切分识别语种
196
  }
197
 
 
 
 
 
 
 
 
 
198
 
199
- def splite_en_inf(sentence, language):
200
- pattern = re.compile(r'[a-zA-Z ]+')
201
- textlist = []
202
- langlist = []
203
- pos = 0
204
- for match in pattern.finditer(sentence):
205
- start, end = match.span()
206
- if start > pos:
207
- textlist.append(sentence[pos:start])
208
- langlist.append(language)
209
- textlist.append(sentence[start:end])
210
- langlist.append("en")
211
- pos = end
212
- if pos < len(sentence):
213
- textlist.append(sentence[pos:])
214
- langlist.append(language)
215
- # Merge punctuation into previous word
216
- for i in range(len(textlist)-1, 0, -1):
217
- if re.match(r'^[\W_]+$', textlist[i]):
218
- textlist[i-1] += textlist[i]
219
- del textlist[i]
220
- del langlist[i]
221
- # Merge consecutive words with the same language tag
222
- i = 0
223
- while i < len(langlist) - 1:
224
- if langlist[i] == langlist[i+1]:
225
- textlist[i] += textlist[i+1]
226
- del textlist[i+1]
227
- del langlist[i+1]
228
- else:
229
- i += 1
230
-
231
- return textlist, langlist
232
-
233
-
234
- def clean_text_inf(text, language):
235
- formattext = ""
236
- language = language.replace("all_","")
237
- for tmp in LangSegment.getTexts(text):
238
- if language == "ja":
239
- if tmp["lang"] == language or tmp["lang"] == "zh":
240
- formattext += tmp["text"] + " "
241
- continue
242
- if tmp["lang"] == language:
243
- formattext += tmp["text"] + " "
244
- while " " in formattext:
245
- formattext = formattext.replace(" ", " ")
246
- phones, word2ph, norm_text = clean_text(formattext, language)
247
- phones = cleaned_text_to_sequence(phones)
248
- return phones, word2ph, norm_text
249
-
250
- dtype=torch.float16 if is_half == True else torch.float32
251
- def get_bert_inf(phones, word2ph, norm_text, language):
252
- language=language.replace("all_","")
253
- if language == "zh":
254
- bert = get_bert_feature(norm_text, word2ph).to(device)#.to(dtype)
255
- else:
256
- bert = torch.zeros(
257
- (1024, len(phones)),
258
- dtype=torch.float16 if is_half == True else torch.float32,
259
- ).to(device)
260
-
261
- return bert
262
-
263
-
264
- def nonen_clean_text_inf(text, language):
265
- if(language!="auto"):
266
- textlist, langlist = splite_en_inf(text, language)
267
- else:
268
- textlist=[]
269
- langlist=[]
270
- for tmp in LangSegment.getTexts(text):
271
- langlist.append(tmp["lang"])
272
- textlist.append(tmp["text"])
273
- print(textlist)
274
- print(langlist)
275
- phones_list = []
276
- word2ph_list = []
277
- norm_text_list = []
278
- for i in range(len(textlist)):
279
- lang = langlist[i]
280
- phones, word2ph, norm_text = clean_text_inf(textlist[i], lang)
281
- phones_list.append(phones)
282
- if lang == "zh":
283
- word2ph_list.append(word2ph)
284
- norm_text_list.append(norm_text)
285
- print(word2ph_list)
286
- phones = sum(phones_list, [])
287
- word2ph = sum(word2ph_list, [])
288
- norm_text = ' '.join(norm_text_list)
289
-
290
- return phones, word2ph, norm_text
291
-
292
-
293
- def nonen_get_bert_inf(text, language):
294
- if(language!="auto"):
295
- textlist, langlist = splite_en_inf(text, language)
296
- else:
297
- textlist=[]
298
- langlist=[]
299
- for tmp in LangSegment.getTexts(text):
300
- langlist.append(tmp["lang"])
301
- textlist.append(tmp["text"])
302
- print(textlist)
303
- print(langlist)
304
- bert_list = []
305
- for i in range(len(textlist)):
306
- lang = langlist[i]
307
- phones, word2ph, norm_text = clean_text_inf(textlist[i], lang)
308
- bert = get_bert_inf(phones, word2ph, norm_text, lang)
309
- bert_list.append(bert)
310
- bert = torch.cat(bert_list, dim=1)
311
-
312
- return bert
313
-
314
-
315
- splits = {",", "。", "?", "!", ",", ".", "?", "!", "~", ":", ":", "—", "…", }
316
-
317
-
318
- def get_first(text):
319
- pattern = "[" + "".join(re.escape(sep) for sep in splits) + "]"
320
- text = re.split(pattern, text)[0].strip()
321
- return text
322
-
323
-
324
- def get_cleaned_text_final(text,language):
325
- if language in {"en","all_zh","all_ja"}:
326
- phones, word2ph, norm_text = clean_text_inf(text, language)
327
- elif language in {"zh", "ja","auto"}:
328
- phones, word2ph, norm_text = nonen_clean_text_inf(text, language)
329
- return phones, word2ph, norm_text
330
-
331
- def get_bert_final(phones, word2ph, text,language,device):
332
- if language == "en":
333
- bert = get_bert_inf(phones, word2ph, text, language)
334
- elif language in {"zh", "ja","auto"}:
335
- bert = nonen_get_bert_inf(text, language)
336
- elif language == "all_zh":
337
- bert = get_bert_feature(text, word2ph).to(device)
338
- else:
339
- bert = torch.zeros((1024, len(phones))).to(device)
340
- return bert
341
-
342
- def merge_short_text_in_array(texts, threshold):
343
- if (len(texts)) < 2:
344
- return texts
345
- result = []
346
- text = ""
347
- for ele in texts:
348
- text += ele
349
- if len(text) >= threshold:
350
- result.append(text)
351
- text = ""
352
- if (len(text) > 0):
353
- if len(result) == 0:
354
- result.append(text)
355
- else:
356
- result[len(result) - 1] += text
357
- return result
358
 
 
 
 
 
 
 
 
 
 
 
 
359
 
360
- def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, how_to_cut=("Do not split"), volume_scale=1.0):
361
- if not duration(ref_wav_path):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
362
  return None
363
  if text == '':
364
- wprint("Please enter text to generate/请输入生成文字")
365
  return None
366
- t0 = ttime()
367
- startTime=timer()
368
- text=trim_text(text,text_language)
369
- change_sovits_weights(sovits_path)
370
- tprint(f'🏕️LOADED SoVITS Model: {sovits_path}')
371
- change_gpt_weights(gpt_path)
372
- tprint(f'🏕️LOADED GPT Model: {gpt_path}')
373
-
374
- prompt_language = dict_language[prompt_language]
375
  try:
376
- text_language = dict_language[text_language]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
377
  except KeyError as e:
378
- wprint(f"Unsupported language type: {e}")
379
  return None
380
-
381
- prompt_text = prompt_text.strip("\n")
382
- if (prompt_text[-1] not in splits): prompt_text += "。" if prompt_language != "en" else "."
383
- text = text.strip("\n")
384
- if (text[0] not in splits and len(get_first(text)) < 4): text = "。" + text if text_language != "en" else "." + text
385
- #print(("实际输入的参考文本:"), prompt_text)
386
- #print(("📝实际输入的目标文本:"), text)
387
- zero_wav = np.zeros(
388
- int(hps.data.sampling_rate * 0.3),
389
- dtype=np.float16 if is_half == True else np.float32,
390
- )
391
- with torch.no_grad():
392
- wav16k, sr = librosa.load(ref_wav_path, sr=16000)
393
- if (wav16k.shape[0] > 160000 or wav16k.shape[0] < 48000):
394
- errinfo='参考音频在3~10秒范围外,请更换!'
395
- raise OSError((errinfo))
396
- wav16k = torch.from_numpy(wav16k)
397
- zero_wav_torch = torch.from_numpy(zero_wav)
398
- if is_half == True:
399
- wav16k = wav16k.half().to(device)
400
- zero_wav_torch = zero_wav_torch.half().to(device)
401
- else:
402
- wav16k = wav16k.to(device)
403
- zero_wav_torch = zero_wav_torch.to(device)
404
- wav16k = torch.cat([wav16k, zero_wav_torch])
405
- ssl_content = ssl_model.model(wav16k.unsqueeze(0))[
406
- "last_hidden_state"
407
- ].transpose(
408
- 1, 2
409
- ) # .float()
410
- codes = vq_model.extract_latent(ssl_content)
411
- prompt_semantic = codes[0, 0]
412
- t1 = ttime()
413
-
414
- phones1, word2ph1, norm_text1=get_cleaned_text_final(prompt_text, prompt_language)
415
-
416
- if (how_to_cut == ("Split into groups of 4 sentences")):
417
- text = cut1(text)
418
- elif (how_to_cut == ("Split every 50 characters")):
419
- text = cut2(text)
420
- elif (how_to_cut == ("Split at CN/JP periods (。)")):
421
- text = cut3(text)
422
- elif (how_to_cut == ("Split at English periods (.)")):
423
- text = cut4(text)
424
- elif (how_to_cut == ("Split at punctuation marks")):
425
- text = cut5(text)
426
- while "\n\n" in text:
427
- text = text.replace("\n\n", "\n")
428
- print(f"🧨实际输入的目标文本(切句后):{text}\n")
429
- texts = text.split("\n")
430
- texts = merge_short_text_in_array(texts, 5)
431
- audio_opt = []
432
- bert1=get_bert_final(phones1, word2ph1, norm_text1,prompt_language,device).to(dtype)
433
-
434
- for text in texts:
435
- if (len(text.strip()) == 0):
436
- continue
437
- if (text[-1] not in splits): text += "。" if text_language != "en" else "."
438
- print(("\n🎈实际输入的目标文本(每句):"), text)
439
- phones2, word2ph2, norm_text2 = get_cleaned_text_final(text, text_language)
440
- try:
441
- bert2 = get_bert_final(phones2, word2ph2, norm_text2, text_language, device).to(dtype)
442
- except RuntimeError as e:
443
- wprint(f"The input text does not match the language/输入文本与语言不匹配: {e}")
444
- return None
445
- bert = torch.cat([bert1, bert2], 1)
446
-
447
- all_phoneme_ids = torch.LongTensor(phones1 + phones2).to(device).unsqueeze(0)
448
- bert = bert.to(device).unsqueeze(0)
449
- all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device)
450
- prompt = prompt_semantic.unsqueeze(0).to(device)
451
- t2 = ttime()
452
- with torch.no_grad():
453
- # pred_semantic = t2s_model.model.infer(
454
- pred_semantic, idx = t2s_model.model.infer_panel(
455
- all_phoneme_ids,
456
- all_phoneme_len,
457
- prompt,
458
- bert,
459
- # prompt_phone_len=ph_offset,
460
- top_k=config["inference"]["top_k"],
461
- early_stop_num=hz * max_sec,
462
- )
463
- t3 = ttime()
464
- # print(pred_semantic.shape,idx)
465
- pred_semantic = pred_semantic[:, -idx:].unsqueeze(
466
- 0
467
- ) # .unsqueeze(0)#mq要多unsqueeze一次
468
- refer = get_spepc(hps, ref_wav_path) # .to(device)
469
- if is_half == True:
470
- refer = refer.half().to(device)
471
- else:
472
- refer = refer.to(device)
473
- # audio = vq_model.decode(pred_semantic, all_phoneme_ids, refer).detach().cpu().numpy()[0, 0]
474
- try:
475
- audio = (
476
- vq_model.decode(
477
- pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refer
478
- )
479
- .detach()
480
- .cpu()
481
- .numpy()[0, 0]
482
- )
483
- except RuntimeError as e:
484
- wprint(f"The input text does not match the language/输入文本与语言不匹配: {e}")
485
- return None
486
-
487
- max_audio=np.abs(audio).max()
488
- if max_audio>1:audio/=max_audio
489
- audio_opt.append(audio)
490
- audio_opt.append(zero_wav)
491
- t4 = ttime()
492
- print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3))
493
- #yield hps.data.sampling_rate, (np.concatenate(audio_opt, 0) * 32768).astype(np.int16)
494
- audio_data = (np.concatenate(audio_opt, 0) * 32768).astype(np.int16)
495
-
496
- audio_data = (audio_data.astype(np.float32) * volume_scale).astype(np.int16)
497
- output_wav = "output_audio.wav"
498
- sf.write(output_wav, audio_data, hps.data.sampling_rate)
499
- endTime=timer()
500
- tprint(f'🆗TTS COMPLETE,{round(endTime-startTime,4)}s')
501
- return output_wav
502
-
503
- def split(todo_text):
504
- todo_text = todo_text.replace("……", "。").replace("——", ",")
505
- if todo_text[-1] not in splits:
506
- todo_text += "。"
507
- i_split_head = i_split_tail = 0
508
- len_text = len(todo_text)
509
- todo_texts = []
510
- while 1:
511
- if i_split_head >= len_text:
512
- break
513
- if todo_text[i_split_head] in splits:
514
- i_split_head += 1
515
- todo_texts.append(todo_text[i_split_tail:i_split_head])
516
- i_split_tail = i_split_head
517
- else:
518
- i_split_head += 1
519
- return todo_texts
520
-
521
-
522
- def cut1(inp):
523
- inp = inp.strip("\n")
524
- inps = split(inp)
525
- split_idx = list(range(0, len(inps), 4))
526
- split_idx[-1] = None
527
- if len(split_idx) > 1:
528
- opts = []
529
- for idx in range(len(split_idx) - 1):
530
- opts.append("".join(inps[split_idx[idx]: split_idx[idx + 1]]))
531
- else:
532
- opts = [inp]
533
- return "\n".join(opts)
534
-
535
-
536
- def cut2(inp):
537
- inp = inp.strip("\n")
538
- inps = split(inp)
539
- if len(inps) < 2:
540
- return inp
541
- opts = []
542
- summ = 0
543
- tmp_str = ""
544
- for i in range(len(inps)):
545
- summ += len(inps[i])
546
- tmp_str += inps[i]
547
- if summ > 50:
548
- summ = 0
549
- opts.append(tmp_str)
550
- tmp_str = ""
551
- if tmp_str != "":
552
- opts.append(tmp_str)
553
- # print(opts)
554
- if len(opts) > 1 and len(opts[-1]) < 50:
555
- opts[-2] = opts[-2] + opts[-1]
556
- opts = opts[:-1]
557
- return "\n".join(opts)
558
-
559
-
560
- def cut3(inp):
561
- inp = inp.strip("\n")
562
- return "\n".join(["%s" % item for item in inp.strip("。").split("。")])
563
-
564
-
565
- def cut4(inp):
566
- inp = inp.strip("\n")
567
- return "\n".join(["%s" % item for item in inp.strip(".").split(".")])
568
-
569
-
570
- # contributed by https://github.com/AI-Hobbyist/GPT-SoVITS/blob/main/GPT_SoVITS/inference_webui.py
571
- def cut5(inp):
572
- # if not re.search(r'[^\w\s]', inp[-1]):
573
- # inp += '。'
574
- inp = inp.strip("\n")
575
- punds = r'[,.;?!、,。?!;:…]'
576
- items = re.split(f'({punds})', inp)
577
- mergeitems = ["".join(group) for group in zip(items[::2], items[1::2])]
578
- if len(items)%2 == 1:
579
- mergeitems.append(items[-1])
580
- opt = "\n".join(mergeitems)
581
- return opt
582
-
583
-
584
-
585
- def custom_sort_key(s):
586
- # 使用正则表达式提取字符串中的数字部分和非数字部分
587
- parts = re.split('(\d+)', s)
588
- # 将数字部分转换为整数,非数字部分保持不变
589
- parts = [int(part) if part.isdigit() else part for part in parts]
590
- return parts
591
 
592
  #==========custom functions============
593
 
 
594
  def tprint(text):
595
  now=datetime.now(tz).strftime('%H:%M:%S')
596
  print(f'UTC+8 - {now} - {text}')
@@ -638,7 +194,7 @@ def trim_text(text,language):
638
  return ' '.join(words[:i+1])
639
  return ' '.join(words[:limit_en])
640
 
641
- else:#中文日文
642
  if len(text) <= limit_cj:
643
  return text
644
  for i in range(limit_cj, -1, -1):
@@ -663,10 +219,12 @@ def duration(audio_file_path):
663
  return False
664
 
665
  def update_model(choice):
666
- global gpt_path, sovits_path
667
  model_info = models[choice]
668
  gpt_path = abs_path(model_info["gpt_weight"])
669
  sovits_path = abs_path(model_info["sovits_weight"])
 
 
670
  model_name = choice
671
  tone_info = model_info["tones"]["tone1"]
672
  tone_sample_path = abs_path(tone_info["sample"])
@@ -708,7 +266,7 @@ def transcribe(voice):
708
 
709
  time2=timer()
710
  tprint(f'transcribe COMPLETE,{round(time2-time1,4)}s')
711
- tprint(f'\n🔣转录结果:\n 🔣Language:{language} \n 🔣Text:{text}' )
712
  return text,language
713
 
714
  def clone_voice(user_voice,user_text,user_lang):
@@ -717,31 +275,37 @@ def clone_voice(user_voice,user_text,user_lang):
717
  if user_text == '':
718
  wprint("Please enter text to generate/请输入生成文字")
719
  return None
720
- tprint('⚡Start clone')
721
  user_text=trim_text(user_text,user_lang)
722
- time1=timer()
723
  global gpt_path, sovits_path
724
  gpt_path = abs_path("pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt")
725
  #tprint(f'Model loaded:{gpt_path}')
726
  sovits_path = abs_path("pretrained_models/s2G488k.pth")
727
  #tprint(f'Model loaded:{sovits_path}')
728
  try:
729
- prompt_text, prompt_language = transcribe(user_voice)
730
  except UnboundLocalError as e:
731
  wprint(f"The language in the audio cannot be recognized :{str(e)}")
732
  return None
733
-
734
- output_wav = get_tts_wav(
735
- user_voice,
736
- prompt_text,
737
- prompt_language,
738
- user_text,
739
- user_lang,
740
- how_to_cut="Do not split",
741
- volume_scale=1.0)
742
- time2=timer()
743
- tprint(f'🆗CLONE COMPLETE,{round(time2-time1,4)}s')
744
- return output_wav
 
 
 
 
 
 
 
 
745
 
746
  with open('dummy') as f:
747
  dummy_txt = f.read().strip().splitlines()
@@ -829,15 +393,26 @@ with gr.Blocks(theme='Kasien/ali_theme_custom') as app:
829
 
830
 
831
  with gr.Accordion(label="Additional generation options/附加生成选项", open=False):
832
- how_to_cut = gr.Dropdown(
833
- label=("How to split?"),
834
- choices=[("Do not split"), ("Split into groups of 4 sentences"), ("Split every 50 characters"),
835
- ("Split at CN/JP periods (。)"), ("Split at English periods (.)"), ("Split at punctuation marks"), ],
836
- value=("Split into groups of 4 sentences"),
 
837
  interactive=True,
838
- info='A suitable splitting method can achieve better generation results'
839
  )
840
- volume = gr.Slider(minimum=0.5, maximum=2, value=1, step=0.01, label='Volume/音量')
 
 
 
 
 
 
 
 
 
 
841
 
842
 
843
  gr.HTML('''
@@ -860,8 +435,12 @@ with gr.Blocks(theme='Kasien/ali_theme_custom') as app:
860
  user_voice = gr.Audio(type="filepath", label="(3~10s)Upload or Record audio/上传或录制声音",scale=3)
861
  with gr.Column(scale=7):
862
  user_lang = gr.Textbox(label="Language/生成语言",info='Automatic detection of input language type.',interactive=False)
863
- user_text= gr.Textbox(label="Text for generation/输入想要生成语音的文字", lines=5,
864
- placeholder=plsh,info=limit)
 
 
 
 
865
  user_text.change( lang_detector, user_text, user_lang)
866
 
867
  user_button = gr.Button("✨Clone Voice", variant="primary")
@@ -875,9 +454,23 @@ with gr.Blocks(theme='Kasien/ali_theme_custom') as app:
875
  tone_select.change(update_tone, inputs=[model_name, tone_select], outputs=[inp_ref, prompt_text, tone_sample])
876
 
877
  main_button.click(
878
- get_tts_wav,
879
- inputs=[inp_ref, prompt_text, prompt_language, text, text_language, how_to_cut,volume],
880
- outputs=[output])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
881
 
882
  user_button.click(
883
  clone_voice,
 
29
  logging.getLogger("multipart").setLevel(logging.WARNING)
30
  from download import *
31
  download()
32
+ from TTS_infer_pack.TTS import TTS, TTS_Config
33
+ from TTS_infer_pack.text_segmentation_method import get_method
34
 
35
  if "_CUDA_VISIBLE_DEVICES" in os.environ:
36
  os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
 
66
  os.environ.get("is_half", "True" if torch.cuda.is_available() else "False")
67
  )
68
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
 
70
  dict_language = {
71
+ "中文1": "all_zh",
72
+ "English": "en",
73
+ "日文1": "all_ja",
74
+ "中文": "zh",
75
+ "日本語": "ja",
76
+ "混合": "auto",
77
  }
78
 
79
+ cut_method = {
80
+ "Do not split/不切":"cut0",
81
+ "Split into groups of 4 sentences/四句一切": "cut1",
82
+ "Split every 50 characters/50字一切": "cut2",
83
+ "Split at CN/JP periods (。)/按中日文句号切": "cut3",
84
+ "Split at English periods (.)/按英文句号切": "cut4",
85
+ "Split at punctuation marks/按标点切": "cut5",
86
+ }
87
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
 
89
+ tts_config = TTS_Config("GPT_SoVITS/configs/tts_infer.yaml")
90
+ tts_config.device = device
91
+ tts_config.is_half = is_half
92
+ if gpt_path is not None:
93
+ tts_config.t2s_weights_path = gpt_path
94
+ if sovits_path is not None:
95
+ tts_config.vits_weights_path = sovits_path
96
+ if cnhubert_base_path is not None:
97
+ tts_config.cnhuhbert_base_path = cnhubert_base_path
98
+ if bert_path is not None:
99
+ tts_config.bert_base_path = bert_path
100
 
101
+
102
+ tts_pipline = TTS(tts_config)
103
+ gpt_path = tts_config.t2s_weights_path
104
+ sovits_path = tts_config.vits_weights_path
105
+
106
+
107
+ def inference(text, text_lang,
108
+ ref_audio_path, prompt_text,
109
+ prompt_lang, top_k,
110
+ top_p, temperature,
111
+ text_split_method, batch_size,
112
+ speed_factor, ref_text_free,
113
+ split_bucket,
114
+ volume
115
+ ):
116
+
117
+ if not duration(ref_audio_path):
118
  return None
119
  if text == '':
120
+ wprint("Please input text to generate/请输入生成文字")
121
  return None
122
+ text=trim_text(text,text_language)
 
 
 
 
 
 
 
 
123
  try:
124
+ lang=dict_language[text_lang]
125
+ inputs={
126
+ "text": text,
127
+ "text_lang": lang,
128
+ "ref_audio_path": ref_audio_path,
129
+ "prompt_text": prompt_text if not ref_text_free else "",
130
+ "prompt_lang": dict_language[prompt_lang],
131
+ "top_k": top_k,
132
+ "top_p": top_p,
133
+ "temperature": temperature,
134
+ "text_split_method": cut_method[text_split_method],
135
+ "batch_size":int(batch_size),
136
+ "speed_factor":float(speed_factor),
137
+ "split_bucket":split_bucket,
138
+ "volume":volume,
139
+ "return_fragment":False,
140
+ }
141
+
142
+ yield next(tts_pipline.run(inputs))
143
  except KeyError as e:
144
+ wprint(f'Unsupported language type:{e}')
145
  return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
146
 
147
  #==========custom functions============
148
 
149
+ splits = {",", "。", "?", "!", ",", ".", "?", "!", "~", ":", ":", "—", "…", }
150
  def tprint(text):
151
  now=datetime.now(tz).strftime('%H:%M:%S')
152
  print(f'UTC+8 - {now} - {text}')
 
194
  return ' '.join(words[:i+1])
195
  return ' '.join(words[:limit_en])
196
 
197
+ else:
198
  if len(text) <= limit_cj:
199
  return text
200
  for i in range(limit_cj, -1, -1):
 
219
  return False
220
 
221
  def update_model(choice):
222
+ #global tts_config.vits_weights_path, tts_config.t2s_weights_path
223
  model_info = models[choice]
224
  gpt_path = abs_path(model_info["gpt_weight"])
225
  sovits_path = abs_path(model_info["sovits_weight"])
226
+ tts_pipline.init_vits_weights(sovits_path)
227
+ tts_pipline.init_t2s_weights(gpt_path)
228
  model_name = choice
229
  tone_info = model_info["tones"]["tone1"]
230
  tone_sample_path = abs_path(tone_info["sample"])
 
266
 
267
  time2=timer()
268
  tprint(f'transcribe COMPLETE,{round(time2-time1,4)}s')
269
+ tprint(f' \nTranscribe result:\n 🔣Language:{language} \n 🔣Text:{text}' )
270
  return text,language
271
 
272
  def clone_voice(user_voice,user_text,user_lang):
 
275
  if user_text == '':
276
  wprint("Please enter text to generate/请输入生成文字")
277
  return None
 
278
  user_text=trim_text(user_text,user_lang)
 
279
  global gpt_path, sovits_path
280
  gpt_path = abs_path("pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt")
281
  #tprint(f'Model loaded:{gpt_path}')
282
  sovits_path = abs_path("pretrained_models/s2G488k.pth")
283
  #tprint(f'Model loaded:{sovits_path}')
284
  try:
285
+ prompt_text, prompt_lang = transcribe(user_voice)
286
  except UnboundLocalError as e:
287
  wprint(f"The language in the audio cannot be recognized :{str(e)}")
288
  return None
289
+ tts_pipline.init_vits_weights(sovits_path)
290
+ tts_pipline.init_t2s_weights(gpt_path)
291
+ inputs={
292
+ "text": user_text,
293
+ "text_lang": dict_language[user_lang],
294
+ "ref_audio_path": user_voice,
295
+ "prompt_text": prompt_text,
296
+ "prompt_lang": dict_language[prompt_lang],
297
+ "top_k": 5,
298
+ "top_p": 1,
299
+ "temperature": 1,
300
+ "text_split_method": "cut1",
301
+ "batch_size":20,
302
+ "speed_factor":1.0,
303
+ "split_bucket":True,
304
+ "volume":1.0,
305
+ "return_fragment":False,
306
+ }
307
+
308
+ yield next(tts_pipline.run(inputs))
309
 
310
  with open('dummy') as f:
311
  dummy_txt = f.read().strip().splitlines()
 
393
 
394
 
395
  with gr.Accordion(label="Additional generation options/附加生成选项", open=False):
396
+ with gr.Row():
397
+ how_to_cut = gr.Dropdown(
398
+ label=("How to split input text?/如何对输入文字切片"),
399
+ choices=[("Do not split/不切"), ("Split into groups of 4 sentences/四句一切"), ("Split every 50 characters/50字一切"),
400
+ ("Split at CN/JP periods (。)/按中日文句号切"), ("Split at English periods (.)/按英文句号切"), ("Split at punctuation marks/按标点切"), ],
401
+ value=("Split into groups of 4 sentences/四句一切"),
402
  interactive=True,
403
+ info='A suitable splitting method can achieve better generation results/适合的切片方法会得到更好的效果'
404
  )
405
+ split_bucket = gr.Checkbox(label="Split bucket/数据分桶", value=True, info='Speed up the inference process/提升推理速度')
406
+ with gr.Row():
407
+ volume = gr.Slider(minimum=0.5, maximum=5, value=1, step=0.1, label='Volume/音量',info='audio distortion due to excessive volume/大了要爆音')
408
+ speed_factor = gr.Slider(minimum=0.25,maximum=4,step=0.05,label="Speed factor",value=1.0,info='Playback speed/播放速度')
409
+ batch_size = gr.Slider(minimum=1,maximum=100,step=1,label="Batch size",value=20,info='The number of sentences for batch inference./并行推理的句子数量')
410
+ with gr.Row():
411
+ top_k = gr.Slider(minimum=1,maximum=100,step=1,label="top_k",value=5)
412
+ top_p = gr.Slider(minimum=0,maximum=1,step=0.05,label="top_p",value=1)
413
+ temperature = gr.Slider(minimum=0,maximum=1,step=0.05,label="temperature",value=1)
414
+ ref_text_free = gr.Checkbox(label="REF_TEXT_FREE", value=False, visible=False)
415
+
416
 
417
 
418
  gr.HTML('''
 
435
  user_voice = gr.Audio(type="filepath", label="(3~10s)Upload or Record audio/上传或录制声音",scale=3)
436
  with gr.Column(scale=7):
437
  user_lang = gr.Textbox(label="Language/生成语言",info='Automatic detection of input language type.',interactive=False)
438
+ with gr.Row():
439
+ user_text= gr.Textbox(label="Text for generation/输入想要生成语音的文字", lines=5,placeholder=plsh,info=limit)
440
+ dddice= gr.Button('🎲', variant='tool',min_width=0,scale=0)
441
+
442
+ dddice.click(dice, outputs=[user_text, dddice])
443
+
444
  user_text.change( lang_detector, user_text, user_lang)
445
 
446
  user_button = gr.Button("✨Clone Voice", variant="primary")
 
454
  tone_select.change(update_tone, inputs=[model_name, tone_select], outputs=[inp_ref, prompt_text, tone_sample])
455
 
456
  main_button.click(
457
+ inference,
458
+ inputs=[text,
459
+ text_language,
460
+ inp_ref,
461
+ prompt_text,
462
+ prompt_language,
463
+ top_k,
464
+ top_p,
465
+ temperature,
466
+ how_to_cut,
467
+ batch_size,
468
+ speed_factor,
469
+ ref_text_free,
470
+ split_bucket,
471
+ volume],
472
+ outputs=[output]
473
+ )
474
 
475
  user_button.click(
476
  clone_voice,
feature_extractor/__pycache__/__init__.cpython-310.pyc CHANGED
Binary files a/feature_extractor/__pycache__/__init__.cpython-310.pyc and b/feature_extractor/__pycache__/__init__.cpython-310.pyc differ
 
feature_extractor/__pycache__/cnhubert.cpython-310.pyc CHANGED
Binary files a/feature_extractor/__pycache__/cnhubert.cpython-310.pyc and b/feature_extractor/__pycache__/cnhubert.cpython-310.pyc differ
 
feature_extractor/__pycache__/whisper_enc.cpython-310.pyc CHANGED
Binary files a/feature_extractor/__pycache__/whisper_enc.cpython-310.pyc and b/feature_extractor/__pycache__/whisper_enc.cpython-310.pyc differ
 
feature_extractor/cnhubert.py CHANGED
@@ -4,9 +4,9 @@ import librosa
4
  import torch
5
  import torch.nn.functional as F
6
  import soundfile as sf
7
- #import logging
8
 
9
- #logging.getLogger("numba").setLevel(logging.WARNING)
10
 
11
  from transformers import (
12
  Wav2Vec2FeatureExtractor,
@@ -20,13 +20,16 @@ cnhubert_base_path = None
20
 
21
 
22
  class CNHubert(nn.Module):
23
- def __init__(self):
24
  super().__init__()
25
- self.model = HubertModel.from_pretrained(cnhubert_base_path)
 
 
26
  self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
27
- cnhubert_base_path
28
  )
29
 
 
30
  def forward(self, x):
31
  input_values = self.feature_extractor(
32
  x, return_tensors="pt", sampling_rate=16000
 
4
  import torch
5
  import torch.nn.functional as F
6
  import soundfile as sf
7
+ import logging
8
 
9
+ logging.getLogger("numba").setLevel(logging.WARNING)
10
 
11
  from transformers import (
12
  Wav2Vec2FeatureExtractor,
 
20
 
21
 
22
  class CNHubert(nn.Module):
23
+ def __init__(self, base_path:str=None):
24
  super().__init__()
25
+ if base_path is None:
26
+ base_path = cnhubert_base_path
27
+ self.model = HubertModel.from_pretrained(base_path)
28
  self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
29
+ base_path
30
  )
31
 
32
+
33
  def forward(self, x):
34
  input_values = self.feature_extractor(
35
  x, return_tensors="pt", sampling_rate=16000
module/__pycache__/__init__.cpython-310.pyc CHANGED
Binary files a/module/__pycache__/__init__.cpython-310.pyc and b/module/__pycache__/__init__.cpython-310.pyc differ
 
module/__pycache__/attentions.cpython-310.pyc CHANGED
Binary files a/module/__pycache__/attentions.cpython-310.pyc and b/module/__pycache__/attentions.cpython-310.pyc differ
 
module/__pycache__/commons.cpython-310.pyc CHANGED
Binary files a/module/__pycache__/commons.cpython-310.pyc and b/module/__pycache__/commons.cpython-310.pyc differ
 
module/__pycache__/core_vq.cpython-310.pyc CHANGED
Binary files a/module/__pycache__/core_vq.cpython-310.pyc and b/module/__pycache__/core_vq.cpython-310.pyc differ
 
module/__pycache__/mel_processing.cpython-310.pyc CHANGED
Binary files a/module/__pycache__/mel_processing.cpython-310.pyc and b/module/__pycache__/mel_processing.cpython-310.pyc differ
 
module/__pycache__/models.cpython-310.pyc CHANGED
Binary files a/module/__pycache__/models.cpython-310.pyc and b/module/__pycache__/models.cpython-310.pyc differ
 
module/__pycache__/modules.cpython-310.pyc CHANGED
Binary files a/module/__pycache__/modules.cpython-310.pyc and b/module/__pycache__/modules.cpython-310.pyc differ
 
module/__pycache__/mrte_model.cpython-310.pyc CHANGED
Binary files a/module/__pycache__/mrte_model.cpython-310.pyc and b/module/__pycache__/mrte_model.cpython-310.pyc differ
 
module/__pycache__/quantize.cpython-310.pyc CHANGED
Binary files a/module/__pycache__/quantize.cpython-310.pyc and b/module/__pycache__/quantize.cpython-310.pyc differ
 
module/__pycache__/transforms.cpython-310.pyc CHANGED
Binary files a/module/__pycache__/transforms.cpython-310.pyc and b/module/__pycache__/transforms.cpython-310.pyc differ
 
module/models.py CHANGED
@@ -1,5 +1,6 @@
1
  import copy
2
  import math
 
3
  import torch
4
  from torch import nn
5
  from torch.nn import functional as F
@@ -228,6 +229,7 @@ class TextEncoder(nn.Module):
228
  )
229
 
230
  y = self.ssl_proj(y * y_mask) * y_mask
 
231
  y = self.encoder_ssl(y * y_mask, y_mask)
232
 
233
  text_mask = torch.unsqueeze(
@@ -958,11 +960,13 @@ class SynthesizerTrn(nn.Module):
958
 
959
  @torch.no_grad()
960
  def decode(self, codes, text, refer, noise_scale=0.5):
961
- refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device)
962
- refer_mask = torch.unsqueeze(
963
- commons.sequence_mask(refer_lengths, refer.size(2)), 1
964
- ).to(refer.dtype)
965
- ge = self.ref_enc(refer * refer_mask, refer_mask)
 
 
966
 
967
  y_lengths = torch.LongTensor([codes.size(2) * 2]).to(codes.device)
968
  text_lengths = torch.LongTensor([text.size(-1)]).to(text.device)
@@ -982,6 +986,55 @@ class SynthesizerTrn(nn.Module):
982
 
983
  o = self.dec((z * y_mask)[:, :, :], g=ge)
984
  return o
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
985
 
986
  def extract_latent(self, x):
987
  ssl = self.ssl_proj(x)
 
1
  import copy
2
  import math
3
+ from typing import List
4
  import torch
5
  from torch import nn
6
  from torch.nn import functional as F
 
229
  )
230
 
231
  y = self.ssl_proj(y * y_mask) * y_mask
232
+
233
  y = self.encoder_ssl(y * y_mask, y_mask)
234
 
235
  text_mask = torch.unsqueeze(
 
960
 
961
  @torch.no_grad()
962
  def decode(self, codes, text, refer, noise_scale=0.5):
963
+ ge = None
964
+ if refer is not None:
965
+ refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device)
966
+ refer_mask = torch.unsqueeze(
967
+ commons.sequence_mask(refer_lengths, refer.size(2)), 1
968
+ ).to(refer.dtype)
969
+ ge = self.ref_enc(refer * refer_mask, refer_mask)
970
 
971
  y_lengths = torch.LongTensor([codes.size(2) * 2]).to(codes.device)
972
  text_lengths = torch.LongTensor([text.size(-1)]).to(text.device)
 
986
 
987
  o = self.dec((z * y_mask)[:, :, :], g=ge)
988
  return o
989
+
990
+
991
+ @torch.no_grad()
992
+ def batched_decode(self, codes, y_lengths, text, text_lengths, refer, noise_scale=0.5):
993
+ ge = None
994
+ if refer is not None:
995
+ refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device)
996
+ refer_mask = torch.unsqueeze(
997
+ commons.sequence_mask(refer_lengths, refer.size(2)), 1
998
+ ).to(refer.dtype)
999
+ ge = self.ref_enc(refer * refer_mask, refer_mask)
1000
+
1001
+ # y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, codes.size(2)), 1).to(
1002
+ # codes.dtype
1003
+ # )
1004
+ y_lengths = (y_lengths * 2).long().to(codes.device)
1005
+ text_lengths = text_lengths.long().to(text.device)
1006
+ # y_lengths = torch.LongTensor([codes.size(2) * 2]).to(codes.device)
1007
+ # text_lengths = torch.LongTensor([text.size(-1)]).to(text.device)
1008
+
1009
+ # 假设padding之后再decode没有问题, 影响未知,但听起来好像没问题?
1010
+ quantized = self.quantizer.decode(codes)
1011
+ if self.semantic_frame_rate == "25hz":
1012
+ quantized = F.interpolate(
1013
+ quantized, size=int(quantized.shape[-1] * 2), mode="nearest"
1014
+ )
1015
+
1016
+ x, m_p, logs_p, y_mask = self.enc_p(
1017
+ quantized, y_lengths, text, text_lengths, ge
1018
+ )
1019
+ z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
1020
+
1021
+ z = self.flow(z_p, y_mask, g=ge, reverse=True)
1022
+ z_masked = (z * y_mask)[:, :, :]
1023
+
1024
+ # 串行。把padding部分去掉再decode
1025
+ o_list:List[torch.Tensor] = []
1026
+ for i in range(z_masked.shape[0]):
1027
+ z_slice = z_masked[i, :, :y_lengths[i]].unsqueeze(0)
1028
+ o = self.dec(z_slice, g=ge)[0, 0, :].detach()
1029
+ o_list.append(o)
1030
+
1031
+ # 并行(会有问题)。先decode,再把padding的部分去掉
1032
+ # o = self.dec(z_masked, g=ge)
1033
+ # upsample_rate = int(math.prod(self.upsample_rates))
1034
+ # o_lengths = y_lengths*upsample_rate
1035
+ # o_list = [o[i, 0, :idx].detach() for i, idx in enumerate(o_lengths)]
1036
+
1037
+ return o_list
1038
 
1039
  def extract_latent(self, x):
1040
  ssl = self.ssl_proj(x)