zwxl commited on
Commit
daf4e8e
·
1 Parent(s): 125d85c
app.py CHANGED
@@ -1,7 +1,45 @@
1
  import gradio as gr
 
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- demo.launch()
 
1
  import gradio as gr
2
+ import sys
3
+ from viitor_voice.inference.transformers_engine import TransformersEngine
4
+ import spaces
5
 
6
+ if __name__ == '__main__':
7
+ # Initialize your OfflineInference class with the appropriate paths
8
+ offline_inference = TransformersEngine("ZzWater/viitor-voice-mix")
9
+
10
+
11
+ @spaces.GPU
12
+ def clone_batch(text_list, prompt_audio, prompt_text):
13
+ print(prompt_audio.name)
14
+ try:
15
+ audios = offline_inference.batch_infer(
16
+ text_list=[text_list],
17
+ prompt_audio_path=prompt_audio.name, # Use uploaded file's path
18
+ prompt_text=prompt_text,
19
+ )
20
+ return 24000, audios[0].cpu().numpy()[0].astype('float32')
21
+ except Exception as e:
22
+ return str(e)
23
+
24
+
25
+ with gr.Blocks() as demo:
26
+ gr.Markdown("# TTS Inference Interface")
27
+ with gr.Tab("Batch Clone"):
28
+ gr.Markdown("### Batch Clone TTS")
29
+
30
+ text_list_clone = gr.Textbox(label="Input Text List (Comma-Separated)",
31
+ placeholder="Enter text1, text2, text3...")
32
+ prompt_audio = gr.File(label="Upload Prompt Audio")
33
+ prompt_text = gr.Textbox(label="Prompt Text", placeholder="Enter the prompt text")
34
+
35
+ clone_button = gr.Button("Run Batch Clone")
36
+ clone_output = gr.Audio(label="Generated Audios", type="numpy")
37
+
38
+ clone_button.click(
39
+ fn=clone_batch,
40
+ inputs=[text_list_clone, prompt_audio, prompt_text],
41
+ outputs=clone_output
42
+ )
43
+
44
+ demo.launch()
45
 
 
 
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ requests
2
+ accelerate==1.1.1
3
+ datasets==3.1.0
4
+ transformers
5
+ tokenizers
6
+ snac
7
+ torch==2.4.0
8
+ torchaudio==2.4.0
9
+ soundfile
viitor_voice/inference/common.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ from io import BytesIO
4
+ from urllib.parse import urlparse
5
+ import requests
6
+ import torchaudio
7
+
8
+
9
+ def load_audio(source):
10
+ def is_url(path):
11
+ try:
12
+ result = urlparse(path)
13
+ return all([result.scheme, result.netloc])
14
+ except Exception:
15
+ return False
16
+
17
+ if is_url(source):
18
+ # 从 URL 加载音频
19
+ response = requests.get(source)
20
+ response.raise_for_status() # 检查请求状态
21
+ audio_data = BytesIO(response.content) # 转为类文件对象
22
+ else:
23
+ # 从本地文件加载音频
24
+ if not os.path.exists(source):
25
+ raise FileNotFoundError(f"File not found: {source}")
26
+ audio_data = source # 本地路径可以直接传递给 torchaudio.load
27
+
28
+ # 使用 torchaudio 加载音频
29
+ waveform, sample_rate = torchaudio.load(audio_data)
30
+ return waveform, sample_rate
31
+
32
+
33
+ pattern = re.compile(r"<\|speech-(\d+)\|>")
34
+
35
+
36
+ def combine_sequences(first_elements, second_elements, third_elements):
37
+ group_size = 7
38
+ sequence = []
39
+
40
+ second_index = 0
41
+ third_index = 0
42
+
43
+ for first in first_elements:
44
+ group = [None] * group_size
45
+
46
+ # Assign the first element
47
+ group[0] = first
48
+
49
+ # Assign the second and fifth elements if they exist
50
+ if second_index < len(second_elements):
51
+ group[1] = second_elements[second_index]
52
+ second_index += 1
53
+ if second_index < len(second_elements):
54
+ group[4] = second_elements[second_index]
55
+ second_index += 1
56
+
57
+ # Assign the remaining elements from third_elements if they exist
58
+ for j in [2, 3, 5, 6]:
59
+ if third_index < len(third_elements):
60
+ group[j] = third_elements[third_index]
61
+ third_index += 1
62
+
63
+ # Remove None values at the end of the group if the group is incomplete
64
+ sequence.extend([x for x in group if x is not None])
65
+
66
+ return sequence
67
+
68
+
69
+ def split_sequence(sequence):
70
+ group_size = 7
71
+ first_elements = []
72
+ second_elements = []
73
+ third_elements = []
74
+
75
+ # Iterate over the sequence in chunks of 7
76
+ for i in range(0, len(sequence), group_size):
77
+ group = sequence[i:i + group_size]
78
+
79
+ # Add elements to the respective lists based on their position in the group
80
+ if len(group) >= 1:
81
+ first_elements.append(group[0])
82
+ if len(group) >= 5:
83
+ second_elements.extend([group[1], group[4]])
84
+ if len(group) >= 7:
85
+ third_elements.extend([group[2], group[3], group[5], group[6]])
86
+ else:
87
+ third_elements.extend(group[2:])
88
+
89
+ return first_elements, second_elements, third_elements
90
+
viitor_voice/inference/transformers_engine.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torchaudio
4
+ from snac import SNAC
5
+ from transformers import AutoTokenizer, AutoModelForCausalLM
6
+ from viitor_voice.inference.common import combine_sequences, load_audio, pattern, split_sequence
7
+
8
+
9
+ class TransformersEngine:
10
+ def __init__(self, model_path, device='cuda'):
11
+ self.device = device
12
+ self.tokenizer = AutoTokenizer.from_pretrained(model_path)
13
+ self.model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.bfloat16).to(device)
14
+ self.snac_model = SNAC.from_pretrained('hubertsiuzdak/snac_24khz').eval().to(device)
15
+
16
+ def batch_infer(self, text_list, prompt_audio_path, prompt_text, flattened_snac_encode=None):
17
+ if flattened_snac_encode is None:
18
+ prompt_audio, sr = load_audio(prompt_audio_path)
19
+ if sr != 24000:
20
+ prompt_audio = torchaudio.functional.resample(prompt_audio, sr, 24000)
21
+
22
+ snac_encode = self.snac_model.encode(prompt_audio[None,].to(self.device))
23
+ first_elements, second_elements, third_elements = \
24
+ snac_encode[0].cpu().numpy().tolist(), snac_encode[1].cpu().numpy().tolist(), snac_encode[
25
+ 2].cpu().numpy().tolist()
26
+ flattened_snac_encode = combine_sequences(first_elements[0], second_elements[0], third_elements[0])
27
+ prompt_snac_texts = ''.join(
28
+ ['<|speech-{}|>'.format(i) if j % 7 != 0 else '<|SEP_AUDIO|><|speech-{}|>'.format(i) for
29
+ j, i in
30
+ enumerate(flattened_snac_encode)])
31
+
32
+ prompts = [
33
+ '<|START_TEXT|>' + prompt_text + x + '<|END_TEXT|>' + '<|START_AUDIO|>' + prompt_snac_texts + '<|SEP_AUDIO|>'
34
+ for x in text_list]
35
+ prompt_ids_list = self.tokenizer(prompts, add_special_tokens=False).input_ids
36
+ results = []
37
+ for prompt_ids in prompt_ids_list:
38
+ prompt_ids = torch.tensor([prompt_ids], dtype=torch.int64).to(self.device)
39
+ output_ids = self.model.generate(prompt_ids, eos_token_id=156008, no_repeat_ngram_size=0, num_beams=1,
40
+ do_sample=False, repetition_penalty=1.3,
41
+ suppress_tokens=list(range(151641)))
42
+ output_ids = output_ids[0, prompt_ids.shape[-1]:].cpu().numpy().tolist()
43
+ generated_text = self.tokenizer.batch_decode([output_ids], skip_special_tokens=False)
44
+ snac_tokens = pattern.findall(generated_text)
45
+ snac_tokens = [int(x) for x in snac_tokens]
46
+ results.append(snac_tokens)
47
+ audios = self.batch_decode_audios(results)
48
+ return audios
49
+
50
+ def batch_decode_audios(self, snac_tokens_list):
51
+ audios = []
52
+ with torch.no_grad():
53
+ for snac_tokens in snac_tokens_list:
54
+ try:
55
+ first_elements, second_elements, third_elements = split_sequence(snac_tokens)
56
+ codes = [torch.from_numpy(np.array(x).astype(np.int32)[None,]).to(self.device) for x in
57
+ [first_elements, second_elements, third_elements]]
58
+ audio_hat_all = self.snac_model.decode(codes)[0].cpu()
59
+ audios.append(audio_hat_all.to(torch.float32))
60
+ except:
61
+ audios.append('error')
62
+ print('error')
63
+ return audios
64
+