PHBJT commited on
Commit
276c4d0
·
verified ·
1 Parent(s): 6ca328f

Update app.py

Browse files

Updated the model repo_id and removed the large option.

Files changed (1) hide show
  1. app.py +6 -14
app.py CHANGED
@@ -12,11 +12,9 @@ from transformers import AutoTokenizer, AutoFeatureExtractor, set_seed
12
  device = "cuda:0" if torch.cuda.is_available() else "cpu"
13
 
14
 
15
- repo_id = "parler-tts/parler-tts-mini-v1"
16
- repo_id_large = "ylacombe/parler-large-v1-og"
17
 
18
  model = ParlerTTSForConditionalGeneration.from_pretrained(repo_id).to(device)
19
- model_large = ParlerTTSForConditionalGeneration.from_pretrained(repo_id_large).to(device)
20
  tokenizer = AutoTokenizer.from_pretrained(repo_id)
21
  feature_extractor = AutoFeatureExtractor.from_pretrained(repo_id)
22
 
@@ -76,19 +74,14 @@ def preprocess(text):
76
  return text
77
 
78
  @spaces.GPU
79
- def gen_tts(text, description, use_large=False):
80
  inputs = tokenizer(description.strip(), return_tensors="pt").to(device)
81
  prompt = tokenizer(preprocess(text), return_tensors="pt").to(device)
82
 
83
  set_seed(SEED)
84
- if use_large:
85
- generation = model_large.generate(
86
- input_ids=inputs.input_ids, prompt_input_ids=prompt.input_ids, attention_mask=inputs.attention_mask, prompt_attention_mask=prompt.attention_mask, do_sample=True, temperature=1.0
87
- )
88
- else:
89
- generation = model.generate(
90
- input_ids=inputs.input_ids, prompt_input_ids=prompt.input_ids, attention_mask=inputs.attention_mask, prompt_attention_mask=prompt.attention_mask, do_sample=True, temperature=1.0
91
- )
92
  audio_arr = generation.cpu().numpy().squeeze()
93
 
94
  return SAMPLE_RATE, audio_arr
@@ -163,12 +156,11 @@ with gr.Blocks(css=css) as block:
163
  with gr.Column():
164
  input_text = gr.Textbox(label="Input Text", lines=2, value=default_text, elem_id="input_text")
165
  description = gr.Textbox(label="Description", lines=2, value=default_description, elem_id="input_description")
166
- use_large = gr.Checkbox(value=False, label="Use Large checkpoint", info="Generate with Parler-TTS Large v1 instead of Mini v1 - Better but way slower.")
167
  run_button = gr.Button("Generate Audio", variant="primary")
168
  with gr.Column():
169
  audio_out = gr.Audio(label="Parler-TTS generation", type="numpy", elem_id="audio_out")
170
 
171
- inputs = [input_text, description, use_large]
172
  outputs = [audio_out]
173
  run_button.click(fn=gen_tts, inputs=inputs, outputs=outputs, queue=True)
174
  gr.Examples(examples=examples, fn=gen_tts, inputs=inputs, outputs=outputs, cache_examples=True)
 
12
  device = "cuda:0" if torch.cuda.is_available() else "cpu"
13
 
14
 
15
+ repo_id = "PHBJT/parler_french_tts_mini_v0.1"
 
16
 
17
  model = ParlerTTSForConditionalGeneration.from_pretrained(repo_id).to(device)
 
18
  tokenizer = AutoTokenizer.from_pretrained(repo_id)
19
  feature_extractor = AutoFeatureExtractor.from_pretrained(repo_id)
20
 
 
74
  return text
75
 
76
  @spaces.GPU
77
+ def gen_tts(text, description):
78
  inputs = tokenizer(description.strip(), return_tensors="pt").to(device)
79
  prompt = tokenizer(preprocess(text), return_tensors="pt").to(device)
80
 
81
  set_seed(SEED)
82
+ generation = model.generate(
83
+ input_ids=inputs.input_ids, prompt_input_ids=prompt.input_ids, attention_mask=inputs.attention_mask, prompt_attention_mask=prompt.attention_mask, do_sample=True, temperature=1.0
84
+ )
 
 
 
 
 
85
  audio_arr = generation.cpu().numpy().squeeze()
86
 
87
  return SAMPLE_RATE, audio_arr
 
156
  with gr.Column():
157
  input_text = gr.Textbox(label="Input Text", lines=2, value=default_text, elem_id="input_text")
158
  description = gr.Textbox(label="Description", lines=2, value=default_description, elem_id="input_description")
 
159
  run_button = gr.Button("Generate Audio", variant="primary")
160
  with gr.Column():
161
  audio_out = gr.Audio(label="Parler-TTS generation", type="numpy", elem_id="audio_out")
162
 
163
+ inputs = [input_text, description
164
  outputs = [audio_out]
165
  run_button.click(fn=gen_tts, inputs=inputs, outputs=outputs, queue=True)
166
  gr.Examples(examples=examples, fn=gen_tts, inputs=inputs, outputs=outputs, cache_examples=True)