okewunmi committed
Commit c816d1a · verified · 1 Parent(s): 79e0b37

Create yarngpt_utils.py

Files changed (1)
  1. yarngpt_utils.py +48 -0
yarngpt_utils.py ADDED
@@ -0,0 +1,48 @@
+ # yarngpt_utils.py
+ import re
+ import torch
+ import torchaudio
+ from outetts.wav_tokenizer.decoder import WavTokenizer
+ from transformers import AutoTokenizer
+
+ class AudioTokenizer:
+     def __init__(self, hf_path, wav_tokenizer_model_path, wav_tokenizer_config_path):
+         self.device = "cuda" if torch.cuda.is_available() else "cpu"
+         self.tokenizer = AutoTokenizer.from_pretrained(hf_path)
+         self.wav_tokenizer = WavTokenizer(
+             checkpoint_path=wav_tokenizer_model_path,
+             config_path=wav_tokenizer_config_path,
+             device=self.device
+         )
+         self.speakers = ["idera", "emma", "jude", "osagie", "tayo", "zainab",
+                          "joke", "regina", "remi", "umar", "chinenye"]
+
+     def create_prompt(self, text, speaker_name=None):
+         # Fall back to a random speaker when none (or an unknown one) is given
+         if speaker_name is None or speaker_name not in self.speakers:
+             speaker_name = self.speakers[torch.randint(0, len(self.speakers), (1,)).item()]
+
+         # Create a prompt similar to the original YarnGPT
+         prompt = f"<|system|>\nYou are a helpful assistant that speaks in {speaker_name}'s voice.\n<|user|>\nSpeak this text: {text}\n<|assistant|>"
+         return prompt
+
+     def tokenize_prompt(self, prompt):
+         # Tokenize the prompt and move it to the model's device
+         input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
+         return input_ids
+
+     def get_codes(self, output):
+         # Decode the generated sequence back to text
+         decoded_str = self.tokenizer.decode(output[0])
+
+         # Extract the part after <|assistant|>
+         speech_part = decoded_str.split("<|assistant|>")[-1].strip()
+
+         # Extract code tokens - assuming a format like "<audio_001>"
+         audio_codes = []
+         for match in re.finditer(r"<audio_(\d+)>", speech_part):
+             code = int(match.group(1))
+             audio_codes.append(code)
+
+         return audio_codes
+
+     def get_audio(self, codes):
+         # Decode the audio codes back into a waveform tensor
+         audio = self.wav_tokenizer.decode(torch.tensor(codes, device=self.device))
+         return audio
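
For context, a minimal usage sketch of the class added in this commit might look like the following. The checkpoint name, the WavTokenizer file paths, and the generation settings are illustrative assumptions, not values taken from this repository.

# Hypothetical usage sketch; the model name, file paths, and generation
# settings below are placeholders, not values from this commit.
import torchaudio
from transformers import AutoModelForCausalLM

from yarngpt_utils import AudioTokenizer

hf_path = "saheedniyi/YarnGPT"                           # assumed checkpoint name
wav_model_path = "path/to/wavtokenizer_checkpoint.ckpt"  # placeholder path
wav_config_path = "path/to/wavtokenizer_config.yaml"     # placeholder path

audio_tokenizer = AudioTokenizer(hf_path, wav_model_path, wav_config_path)
model = AutoModelForCausalLM.from_pretrained(hf_path).to(audio_tokenizer.device)

# Build a prompt, generate audio tokens, then decode them to a waveform.
prompt = audio_tokenizer.create_prompt("Hello, how are you today?", speaker_name="idera")
input_ids = audio_tokenizer.tokenize_prompt(prompt)
output = model.generate(input_ids, max_new_tokens=2000, do_sample=True, temperature=0.9)

codes = audio_tokenizer.get_codes(output)
audio = audio_tokenizer.get_audio(codes)

# torchaudio.save expects a (channels, samples) tensor; 24 kHz is an assumed
# sample rate for the WavTokenizer output.
torchaudio.save("output.wav", audio.detach().cpu(), sample_rate=24000)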