Create yarngpt_utils.py
Browse files- yarngpt_utils.py +48 -0
yarngpt_utils.py
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# yarngpt_utils.py
import re

import torch
import torchaudio
from outetts.wav_tokenizer.decoder import WavTokenizer
from transformers import AutoTokenizer
|
6 |
+
|
7 |
+
class AudioTokenizer:
    """Bridge between a Hugging Face text tokenizer and a WavTokenizer codec.

    Builds YarnGPT-style chat prompts, tokenizes them for generation,
    extracts ``<audio_NNN>`` code tokens from model output, and decodes
    those codes back into a waveform.
    """

    def __init__(self, hf_path, wav_tokenizer_model_path, wav_tokenizer_config_path):
        """Load the text tokenizer and the audio codec.

        Args:
            hf_path: Hugging Face model id or local path for ``AutoTokenizer``.
            wav_tokenizer_model_path: checkpoint file for ``WavTokenizer``.
            wav_tokenizer_config_path: config file for ``WavTokenizer``.
        """
        # Prefer GPU when available; every tensor created by this class
        # is placed on this device.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tokenizer = AutoTokenizer.from_pretrained(hf_path)
        self.wav_tokenizer = WavTokenizer(
            checkpoint_path=wav_tokenizer_model_path,
            config_path=wav_tokenizer_config_path,
            device=self.device,
        )
        # Voices accepted by create_prompt(); anything else falls back to random.
        self.speakers = ["idera", "emma", "jude", "osagie", "tayo", "zainab",
                         "joke", "regina", "remi", "umar", "chinenye"]

    def create_prompt(self, text, speaker_name=None):
        """Return a chat-formatted prompt asking the model to speak *text*.

        If *speaker_name* is ``None`` or not one of ``self.speakers``, a
        random known speaker is chosen instead.
        """
        if speaker_name is None or speaker_name not in self.speakers:
            # torch.randint keeps speaker choice on torch's RNG, so it is
            # reproducible via torch.manual_seed like the rest of generation.
            speaker_name = self.speakers[torch.randint(0, len(self.speakers), (1,)).item()]

        # Create a prompt similar to the original YarnGPT
        prompt = f"<|system|>\nYou are a helpful assistant that speaks in {speaker_name}'s voice.\n<|user|>\nSpeak this text: {text}\n<|assistant|>"
        return prompt

    def tokenize_prompt(self, prompt):
        """Encode *prompt* to a tensor of input ids on ``self.device``."""
        input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
        return input_ids

    def get_codes(self, output):
        """Extract integer audio codes from a generated batch.

        Only ``output[0]`` (the first sequence) is decoded. Codes are the
        integers in ``<audio_NNN>`` tokens appearing after the final
        ``<|assistant|>`` marker; returns an empty list if none are found.
        """
        decoded_str = self.tokenizer.decode(output[0])

        # Keep only the assistant's reply; split on the last marker so any
        # literal "<|assistant|>" echoed earlier in the prompt is skipped.
        speech_part = decoded_str.split("<|assistant|>")[-1].strip()

        # Collect every "<audio_NNN>" token as an int, in order of appearance.
        return [int(m.group(1)) for m in re.finditer(r"<audio_(\d+)>", speech_part)]

    def get_audio(self, codes):
        """Decode a list of audio codes into a waveform tensor.

        NOTE(review): assumes WavTokenizer.decode accepts a raw 1-D code
        tensor — confirm against the outetts decoder API.
        """
        audio = self.wav_tokenizer.decode(torch.tensor(codes, device=self.device))
        return audio
|