bachvudinh commited on
Commit
87736a3
1 Parent(s): c4b9526

initial commit

Browse files
app copy.py ADDED
@@ -0,0 +1,254 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import torchaudio
4
+ from encodec import EncodecModel
5
+ from whisperspeech.vq_stoks import RQBottleneckTransformer
6
+ from encodec.utils import convert_audio
7
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, pipeline
8
+ from transformers import StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
9
+ from threading import Thread
10
+ import logging
11
+ import os
12
+ from generate_audio import (
13
+ TTSProcessor,
14
+ )
15
+ import uuid
16
+
17
+ device = "cuda" if torch.cuda.is_available() else "cpu"
18
+ vq_model = RQBottleneckTransformer.load_model(
19
+ "whisper-vq-stoks-medium-en+pl-fixed.model"
20
+ ).to(device)
21
+ vq_model.ensure_whisper(device)
22
+
23
+ def audio_to_sound_tokens_whisperspeech(audio_path):
24
+ wav, sr = torchaudio.load(audio_path)
25
+ if sr != 16000:
26
+ wav = torchaudio.functional.resample(wav, sr, 16000)
27
+ with torch.no_grad():
28
+ codes = vq_model.encode_audio(wav.to(device))
29
+ codes = codes[0].cpu().tolist()
30
+
31
+ result = ''.join(f'<|sound_{num:04d}|>' for num in codes)
32
+ return f'<|sound_start|>{result}<|sound_end|>'
33
+ def audio_to_sound_tokens_whisperspeech_transcribe(audio_path):
34
+ wav, sr = torchaudio.load(audio_path)
35
+ if sr != 16000:
36
+ wav = torchaudio.functional.resample(wav, sr, 16000)
37
+ with torch.no_grad():
38
+ codes = vq_model.encode_audio(wav.to(device))
39
+ codes = codes[0].cpu().tolist()
40
+
41
+ result = ''.join(f'<|sound_{num:04d}|>' for num in codes)
42
+ return f'<|reserved_special_token_69|><|sound_start|>{result}<|sound_end|>'
43
+ def audio_to_sound_tokens(audio_path, target_bandwidth=1.5, device="cuda"):
44
+ model = EncodecModel.encodec_model_24khz()
45
+ model.set_target_bandwidth(target_bandwidth)
46
+ model.to(device)
47
+
48
+ wav, sr = torchaudio.load(audio_path)
49
+ wav = convert_audio(wav, sr, model.sample_rate, model.channels)
50
+ wav = wav.unsqueeze(0).to(device)
51
+
52
+ with torch.no_grad():
53
+ encoded_frames = model.encode(wav)
54
+ codes = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1)
55
+
56
+ audio_code1, audio_code2 = codes[0][0], codes[0][1]
57
+ flatten_tokens = torch.stack((audio_code1, audio_code2), dim=1).flatten().tolist()
58
+ result = ''.join(f'<|sound_{num:04d}|>' for num in flatten_tokens)
59
+ return f'<|sound_start|>{result}<|sound_end|>'
60
+
61
+ def setup_pipeline(model_path, use_4bit=False, use_8bit=False):
62
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
63
+ model_kwargs = {"device_map": "auto"}
64
+ if use_8bit:
65
+ model_kwargs["quantization_config"] = BitsAndBytesConfig(
66
+ load_in_8bit=True,
67
+ llm_int8_enable_fp32_cpu_offload=False,
68
+ llm_int8_has_fp16_weight=False,
69
+ )
70
+ else:
71
+ model_kwargs["torch_dtype"] = torch.bfloat16
72
+ model = AutoModelForCausalLM.from_pretrained(model_path, **model_kwargs)
73
+ return pipeline("text-generation", model=model, tokenizer=tokenizer)
74
+
75
+ tts = TTSProcessor(device)
76
+ llm_path = "homebrewltd/Llama3.1-s-instruct-2024-08-19-epoch-3"
77
+ pipe = setup_pipeline(llm_path, use_8bit=False)
78
+ tokenizer = pipe.tokenizer
79
+ model = pipe.model
80
+ # print(tokenizer.encode("<|sound_0001|>", add_special_tokens=False))# return the audio tensor
81
+ # print(tokenizer.eos_token)
82
+ def text_to_audio_file(text):
83
+ # gen a random id for the audio file
84
+ id = str(uuid.uuid4())
85
+ temp_file = f"./user_audio/{id}_temp_audio.wav"
86
+ text = text
87
+ text_split = "_".join(text.lower().split(" "))
88
+ # remove the last character if it is a period
89
+ if text_split[-1] == ".":
90
+ text_split = text_split[:-1]
91
+ tts.convert_text_to_audio_file(text, temp_file)
92
+ # logging.info(f"Saving audio to {temp_file}")
93
+ # torchaudio.save(temp_file, audio.cpu(), sample_rate=24000)
94
+ print(f"Saved audio to {temp_file}")
95
+ return temp_file
96
+ def process_input(input_type, text_input=None, audio_file=None):
97
+ # if input_type == "text":
98
+ # audio_file = "temp_audio.wav"
99
+
100
+ for partial_message in process_audio(audio_file):
101
+ yield partial_message
102
+
103
+ # if input_type == "text":
104
+ # os.remove(audio_file)
105
+ def process_transcribe_input(input_type, text_input=None, audio_file=None):
106
+ # if input_type == "text":
107
+ # audio_file = "temp_audio.wav"
108
+
109
+ for partial_message in process_audio(audio_file, transcript=True):
110
+ yield partial_message
111
+
112
+ # if input_type == "text":
113
+ # os.remove(audio_file)
114
+ class StopOnTokens(StoppingCriteria):
115
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
116
+ # encode </s> token
117
+ stop_ids = [tokenizer.eos_token_id, 128009] # Adjust this based on your model's tokenizer
118
+ for stop_id in stop_ids:
119
+ if input_ids[0][-1] == stop_id:
120
+ return True
121
+ return False
122
+ def process_audio(audio_file, transcript=False):
123
+ if audio_file is None:
124
+ raise ValueError("No audio file provided")
125
+
126
+ logging.info(f"Audio file received: {audio_file}")
127
+ logging.info(f"Audio file type: {type(audio_file)}")
128
+
129
+ sound_tokens = audio_to_sound_tokens_whisperspeech_transcribe(audio_file) if transcript else audio_to_sound_tokens_whisperspeech(audio_file)
130
+ logging.info("Sound tokens generated successfully")
131
+ # logging.info(f"audio_file: {audio_file.name}")
132
+ messages = [
133
+ {"role": "user", "content": sound_tokens},
134
+ ]
135
+
136
+ stop = StopOnTokens()
137
+ input_str = tokenizer.apply_chat_template(messages, tokenize=False)
138
+ input_ids = tokenizer.encode(input_str, return_tensors="pt")
139
+ input_ids = input_ids.to(model.device)
140
+
141
+ streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
142
+ generation_kwargs = dict(
143
+ input_ids=input_ids,
144
+ streamer=streamer,
145
+ max_new_tokens=1024,
146
+ do_sample=False,
147
+ stopping_criteria=StoppingCriteriaList([stop])
148
+ )
149
+
150
+ thread = Thread(target=model.generate, kwargs=generation_kwargs)
151
+ thread.start()
152
+
153
+ partial_message = ""
154
+ for new_token in streamer:
155
+ partial_message += new_token
156
+ if tokenizer.eos_token in partial_message:
157
+ break
158
+ partial_message = partial_message.replace("assistant\n\n", "")
159
+ yield partial_message
160
+ # def stop_generation():
161
+ # # This is a placeholder. Implement actual stopping logic here if needed.
162
+ # return "Generation stopped.", gr.Button.update(interactive=False)
163
+ # take all the examples from the examples folder
164
+ good_examples = []
165
+ for file in os.listdir("./examples"):
166
+ if file.endswith(".wav"):
167
+ good_examples.append([f"./examples/{file}"])
168
+ bad_examples = []
169
+ for file in os.listdir("./bad_examples"):
170
+ if file.endswith(".wav"):
171
+ bad_examples.append([f"./bad_examples/{file}"])
172
+ examples = []
173
+ examples.extend(good_examples)
174
+ examples.extend(bad_examples)
175
+ # with gr.Blocks() as iface:
176
+ # gr.Markdown("# Llama3-S: A Speech & Text Fusion Model Checkpoint from Homebrew")
177
+ # gr.Markdown("Enter text or upload a .wav file to generate text based on its content.")
178
+
179
+ # with gr.Row():
180
+ # input_type = gr.Radio(["text", "audio"], label="Input Type", value="audio")
181
+ # text_input = gr.Textbox(label="Text Input", visible=False)
182
+ # audio_input = gr.Audio(sources=["upload"], type="filepath", label="Upload audio", visible=True)
183
+
184
+ # output = gr.Textbox(label="Generated Text")
185
+
186
+ # submit_button = gr.Button("Submit")
187
+
188
+ # input_type.change(
189
+ # update_visibility,
190
+ # inputs=[input_type],
191
+ # outputs=[text_input, audio_input]
192
+ # )
193
+
194
+ # submit_button.click(
195
+ # process_input,
196
+ # inputs=[input_type, text_input, audio_input],
197
+ # outputs=[output]
198
+ # )
199
+
200
+ # gr.Examples(examples, inputs=[audio_input])
201
+
202
+ # iface.launch(server_name="127.0.0.1", server_port=8080)
203
+ with gr.Blocks() as iface:
204
+ gr.Markdown("# Llama3-1-S: checkpoint Aug 19, 2024")
205
+ gr.Markdown("Enter text to convert to audio, then submit the audio to generate text or Upload Audio")
206
+
207
+ with gr.Row():
208
+ input_type = gr.Radio(["text", "audio"], label="Input Type", value="audio")
209
+ text_input = gr.Textbox(label="Text Input", visible=False)
210
+ audio_input = gr.Audio(label="Audio", type="filepath", visible=True)
211
+ # audio_output = gr.Audio(label="Converted Audio", type="filepath", visible=False)
212
+
213
+ convert_button = gr.Button("Convert to Audio", visible=False)
214
+ submit_button = gr.Button("Submit for Processing")
215
+ transcrip_button = gr.Button("Please Transcribe the audio for me")
216
+
217
+ text_output = gr.Textbox(label="Generated Text")
218
+
219
+ def update_visibility(input_type):
220
+ return (gr.update(visible=input_type == "text"),
221
+ gr.update(visible=input_type == "text"))
222
+ def convert_and_display(text):
223
+ audio_file = text_to_audio_file(text)
224
+ return audio_file
225
+ def process_example(file_path):
226
+ return update_visibility("audio")
227
+ input_type.change(
228
+ update_visibility,
229
+ inputs=[input_type],
230
+ outputs=[text_input, convert_button]
231
+ )
232
+
233
+ convert_button.click(
234
+ convert_and_display,
235
+ inputs=[text_input],
236
+ outputs=[audio_input]
237
+ )
238
+
239
+ submit_button.click(
240
+ process_input,
241
+ inputs=[input_type, text_input, audio_input],
242
+ outputs=[text_output]
243
+ )
244
+ transcrip_button.click(
245
+ process_transcribe_input,
246
+ inputs=[input_type, text_input, audio_input],
247
+ outputs=[text_output]
248
+ )
249
+
250
+ gr.Examples(examples, inputs=[audio_input],outputs=[audio_input], fn=process_example)
251
+ iface.queue()
252
+ iface.launch()
253
+ # launch locally
254
+ # iface.launch(server_name="0.0.0.0")
bad_examples/bad-What-is-Love.wav ADDED
Binary file (41.7 kB). View file
 
bad_examples/bad-who-bears-Obama.wav ADDED
Binary file (64.7 kB). View file
 
examples/Can-you-write-a-registration-letter.wav ADDED
Binary file (109 kB). View file
 
examples/Hello.wav ADDED
Binary file (18.6 kB). View file
 
examples/Who-is-Harry-Potter.wav ADDED
Binary file (62.8 kB). View file
 
examples/Write-an-email.wav ADDED
Binary file (45.5 kB). View file
 
examples/codeapythonscript.wav ADDED
Binary file (61 kB). View file
 
examples/generate_3_questions_you_can_ask_an_interviewer.wav ADDED
Binary file (302 kB). View file
 
examples/story.wav ADDED
Binary file (41.5 kB). View file
 
examples/what-is-the-color-of-the-elephant.wav ADDED
Binary file (107 kB). View file
 
examples/what-is-the-color-of-the-ocean.wav ADDED
Binary file (97.4 kB). View file
 
generate_audio.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torchaudio
2
+
3
+ from whisperspeech.pipeline import Pipeline
4
+ import argparse
5
+
6
+ def parse_args():
7
+ parser = argparse.ArgumentParser(description="Convert text to audio.")
8
+ parser.add_argument(
9
+ "--text",
10
+ type=str,
11
+ required=True,
12
+ help="The text to convert to audio.",
13
+ )
14
+ return parser.parse_args()
15
+
16
+ def convert_text_to_audio(pipe: Pipeline, text: str):
17
+ """Convert text to audio.
18
+
19
+ Args:
20
+ pipe (Pipeline): The pipeline to use for text-to-speech.
21
+ text (str): The text to convert to audio.
22
+
23
+ Returns:
24
+ torch.Tensor: The generated audio.
25
+ """
26
+ return pipe.generate(text)
27
+
28
+
29
+ def convert_text_to_audio_file(pipe: Pipeline, text: str, output_path: str):
30
+ """Convert text to audio and save it to a file.
31
+
32
+ Args:
33
+ pipe (Pipeline): The pipeline to use for text-to-speech.
34
+ text (str): The text to convert to audio.
35
+ output_path (str): The path to save the audio file.
36
+ """
37
+ pipe.generate_to_file(output_path, text)
38
+
39
+
40
+ class TTSProcessor:
41
+ def __init__(self, device: str):
42
+ """Initialize the TTS Processor with a specified device."""
43
+ self.pipe = Pipeline(
44
+ s2a_ref="collabora/whisperspeech:s2a-q4-tiny-en+pl.model", device=device
45
+ )
46
+
47
+ def get_reference_voice_embedding(self, path: str):
48
+ """Get the reference voice embedding from the given audio file.
49
+
50
+ Args:
51
+ path (str): The path to the audio file.
52
+ Returns:
53
+ torch.Tensor: The reference voice embedding."""
54
+ return self.pipe.extract_spk_emb(path).cpu()
55
+
56
+ def convert_text_to_audio(self, text: str, speaker=None):
57
+ """Convert text to audio.
58
+
59
+ Args:
60
+ text (str): The text to convert to audio.
61
+
62
+ Returns:
63
+ torch.Tensor: The generated audio.
64
+ """
65
+ return self.pipe.generate(text, speaker=speaker)
66
+
67
+ def convert_text_to_audio_file(self, text: str, output_path: str, speaker=None):
68
+ """Convert text to audio and save it to a file.
69
+
70
+ Args:
71
+ text (str): The text to convert to audio.
72
+ output_path (str): The path to save the audio file.
73
+ """
74
+ self.pipe.generate_to_file(output_path, text, speaker=speaker)
75
+ if __name__ == "__main__":
76
+ args = parse_args()
77
+ processor = TTSProcessor("cuda")
78
+ text = args.text
79
+ text = text.lower()
80
+ text_split = "_".join(text.lower().split(" "))
81
+ # remove the last character if it is a period
82
+ if text_split[-1] == ".":
83
+ text_split = text_split[:-1]
84
+ print(text_split)
85
+ path = f"./examples/{text_split}.wav"
86
+ processor.convert_text_to_audio_file(text, path)
87
+
requirements.txt ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ openai-whisper==20231117
2
+ IPython
3
+ peft
4
+ huggingface_hub
5
+ matplotlib
6
+ pyarrow
7
+ datasets
8
+ encodec
9
+ soundfile
10
+ gradio==4.39.0
11
+ transformers
12
+ bitsandbytes
13
+ torchvision
14
+ vector_quantize_pytorch
15
+ webdataset
16
+ git+https://github.com/homebrewltd/WhisperSpeech.git
17
+ --extra-index-url https://download.pytorch.org/whl/cu121
18
+ torch==2.2.0
19
+ torchaudio==2.2.0
user_audio/0bf62a35-94bb-43f0-9a5f-9691c1691859_temp_audio.wav ADDED
Binary file (147 kB). View file
 
whisper-vq-stoks-medium-en+pl-fixed.model ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ee935a1cd19e78900ffbace1c87dd79ab8e9c414bf1d5bd00fd497d82d9b5dba
3
+ size 90919761