bachvudinh committed
Commit 92a440c · 1 Parent(s): dbf9701

try to make text-to-speech work on ZeroGPU

Files changed (1): app.py (+17 -44)
app.py CHANGED
@@ -20,10 +20,24 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
 vq_model = RQBottleneckTransformer.load_model(
     "whisper-vq-stoks-medium-en+pl-fixed.model"
 ).to(device)
-# vq_model.ensure_whisper(device)
+# tts = TTSProcessor('cpu')
+use_8bit = False
+llm_path = "homebrewltd/Llama3.1-s-instruct-2024-08-19-epoch-3"
+tokenizer = AutoTokenizer.from_pretrained(llm_path)
+model_kwargs = {}
+if use_8bit:
+    model_kwargs["quantization_config"] = BitsAndBytesConfig(
+        load_in_8bit=True,
+        llm_int8_enable_fp32_cpu_offload=False,
+        llm_int8_has_fp16_weight=False,
+    )
+else:
+    model_kwargs["torch_dtype"] = torch.bfloat16
+model = AutoModelForCausalLM.from_pretrained(llm_path, **model_kwargs).to(device)
 
 @spaces.GPU
 def audio_to_sound_tokens_whisperspeech(audio_path):
+    vq_model.ensure_whisper('cuda')
     wav, sr = torchaudio.load(audio_path)
     if sr != 16000:
         wav = torchaudio.functional.resample(wav, sr, 16000)
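
Note: this hunk follows the usual ZeroGPU pattern. On a Spaces ZeroGPU host, module-level code runs on a CPU-only machine, and a GPU is attached only for the duration of a call into a function decorated with @spaces.GPU, so heavyweight loading happens at import while CUDA-only work is deferred into the decorated functions. A minimal sketch of the pattern, assuming the `spaces` package available on Hugging Face Spaces (the model ID is taken from the diff; `generate` is an illustrative function name, not part of this app):

import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

llm_path = "homebrewltd/Llama3.1-s-instruct-2024-08-19-epoch-3"

# Module scope runs on a CPU-only host: load weights here, defer CUDA work.
tokenizer = AutoTokenizer.from_pretrained(llm_path)
model = AutoModelForCausalLM.from_pretrained(llm_path, torch_dtype=torch.bfloat16)

@spaces.GPU  # a GPU is attached only while this function executes
def generate(prompt: str) -> str:
    model.to("cuda")  # safe here: CUDA is available inside the decorated call
    inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
    output_ids = model.generate(**inputs, max_new_tokens=64)
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)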
@@ -36,6 +50,7 @@ def audio_to_sound_tokens_whisperspeech(audio_path):
 
 @spaces.GPU
 def audio_to_sound_tokens_whisperspeech_transcribe(audio_path):
+    vq_model.ensure_whisper('cuda')
     wav, sr = torchaudio.load(audio_path)
     if sr != 16000:
         wav = torchaudio.functional.resample(wav, sr, 16000)
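
Note: both token-extraction entry points now call vq_model.ensure_whisper('cuda') on entry. The "ensure_" naming suggests an idempotent lazy initializer, which is the right shape for ZeroGPU: the first GPU call pays the load cost and later calls return immediately. A sketch of what such a method typically looks like (illustrative only; the real implementation lives in RQBottleneckTransformer, and load_whisper is a placeholder loader):

class RQBottleneckTransformerSketch:
    def __init__(self):
        self.whmodel = None  # Whisper is not loaded until first use

    def ensure_whisper(self, device):
        # Idempotent: load once, then every later call is a no-op, so it
        # is cheap to invoke at the top of each @spaces.GPU function.
        if self.whmodel is None:
            self.whmodel = load_whisper(device)  # placeholder loader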
@@ -45,21 +60,6 @@ def audio_to_sound_tokens_whisperspeech_transcribe(audio_path):
 
     result = ''.join(f'<|sound_{num:04d}|>' for num in codes)
     return f'<|reserved_special_token_69|><|sound_start|>{result}<|sound_end|>'
-
-tts = TTSProcessor(device)
-use_8bit = False
-llm_path = "homebrewltd/Llama3.1-s-instruct-2024-08-19-epoch-3"
-tokenizer = AutoTokenizer.from_pretrained(llm_path)
-model_kwargs = {}
-if use_8bit:
-    model_kwargs["quantization_config"] = BitsAndBytesConfig(
-        load_in_8bit=True,
-        llm_int8_enable_fp32_cpu_offload=False,
-        llm_int8_has_fp16_weight=False,
-    )
-else:
-    model_kwargs["torch_dtype"] = torch.bfloat16
-model = AutoModelForCausalLM.from_pretrained(llm_path, **model_kwargs).to(device)
 # print(tokenizer.encode("<|sound_0001|>", add_special_tokens=False))  # return the audio tensor
 # print(tokenizer.eos_token)
 
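Note: this hunk does not delete the LLM setup, it relocates it: the same tokenizer/model initialization now runs at module scope right after vq_model (first hunk), so all heavyweight loading happens once at import. One caveat should use_8bit ever be flipped on: an 8-bit bitsandbytes model is dispatched by accelerate at load time and does not support being moved afterwards with .to(device), so that branch would need to skip the trailing .to(device). A sketch of the 8-bit path under that assumption (same model ID as the diff; requires bitsandbytes):

from transformers import AutoModelForCausalLM, BitsAndBytesConfig

model_8bit = AutoModelForCausalLM.from_pretrained(
    "homebrewltd/Llama3.1-s-instruct-2024-08-19-epoch-3",
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
    device_map="auto",  # let accelerate place the weights; no .to(device) afterwards
)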
@@ -74,6 +74,7 @@ def text_to_audio_file(text):
     # remove the last character if it is a period
     if text_split[-1] == ".":
         text_split = text_split[:-1]
+    tts = TTSProcessor("cuda")
     tts.convert_text_to_audio_file(text, temp_file)
     # logging.info(f"Saving audio to {temp_file}")
     # torchaudio.save(temp_file, audio.cpu(), sample_rate=24000)
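
Note: TTSProcessor is now constructed on "cuda" inside the request handler instead of once at import (compare the commented-out `# tts = TTSProcessor('cpu')` left at module scope in the first hunk). On ZeroGPU that is the conservative choice: the GPU is only guaranteed to exist inside a @spaces.GPU call, and a CUDA-backed object cached across calls may outlive its allocation. The cost is re-initializing the engine on every request; if profiling shows that matters, a lazy module-level cache is a common middle ground (illustrative sketch; get_tts is not part of this app):

_tts = None

def get_tts():
    # Build the CUDA-backed TTS engine on first use inside a GPU call,
    # then reuse it; only worthwhile if the runtime keeps the handle valid.
    global _tts
    if _tts is None:
        _tts = TTSProcessor("cuda")
    return _tts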
@@ -165,34 +166,6 @@ for file in os.listdir("./bad_examples"):
 examples = []
 examples.extend(good_examples)
 examples.extend(bad_examples)
-# with gr.Blocks() as iface:
-#     gr.Markdown("# Llama3-S: A Speech & Text Fusion Model Checkpoint from Homebrew")
-#     gr.Markdown("Enter text or upload a .wav file to generate text based on its content.")
-
-#     with gr.Row():
-#         input_type = gr.Radio(["text", "audio"], label="Input Type", value="audio")
-#         text_input = gr.Textbox(label="Text Input", visible=False)
-#         audio_input = gr.Audio(sources=["upload"], type="filepath", label="Upload audio", visible=True)
-
-#     output = gr.Textbox(label="Generated Text")
-
-#     submit_button = gr.Button("Submit")
-
-#     input_type.change(
-#         update_visibility,
-#         inputs=[input_type],
-#         outputs=[text_input, audio_input]
-#     )
-
-#     submit_button.click(
-#         process_input,
-#         inputs=[input_type, text_input, audio_input],
-#         outputs=[output]
-#     )
-
-#     gr.Examples(examples, inputs=[audio_input])
-
-#     iface.launch(server_name="127.0.0.1", server_port=8080)
 with gr.Blocks() as iface:
     gr.Markdown("# Llama3-1-S: checkpoint Aug 19, 2024")
     gr.Markdown("Enter text to convert to audio, then submit the audio to generate text or Upload Audio")