Ichigo-llama3.1-s-v0.4 / generate_audio.py
bachvudinh's picture
initial commit
3c72012
raw
history blame
2.68 kB
import torchaudio
from whisperspeech.pipeline import Pipeline
import argparse
def parse_args():
    """Read this script's command-line options.

    Returns:
        argparse.Namespace: The parsed options; ``text`` carries the value
        of the mandatory ``--text`` flag.
    """
    cli = argparse.ArgumentParser(description="Convert text to audio.")
    cli.add_argument(
        "--text",
        required=True,
        type=str,
        help="The text to convert to audio.",
    )
    namespace = cli.parse_args()
    return namespace
def convert_text_to_audio(pipe: Pipeline, text: str, speaker=None):
    """Convert text to audio.

    Args:
        pipe (Pipeline): The pipeline to use for text-to-speech.
        text (str): The text to convert to audio.
        speaker: Optional reference voice (e.g. an embedding returned by
            ``TTSProcessor.get_reference_voice_embedding``). When omitted,
            ``pipe.generate`` is called exactly as before, with the text only.

    Returns:
        torch.Tensor: The generated audio.
    """
    # Only forward `speaker` when the caller supplied one, so existing
    # callers keep the exact same `pipe.generate(text)` call.
    if speaker is None:
        return pipe.generate(text)
    return pipe.generate(text, speaker=speaker)
def convert_text_to_audio_file(pipe: Pipeline, text: str, output_path: str, speaker=None):
    """Convert text to audio and save it to a file.

    Args:
        pipe (Pipeline): The pipeline to use for text-to-speech.
        text (str): The text to convert to audio.
        output_path (str): The path to save the audio file.
        speaker: Optional reference voice (e.g. an embedding returned by
            ``TTSProcessor.get_reference_voice_embedding``). When omitted,
            ``pipe.generate_to_file`` is called exactly as before.
    """
    # Note the argument order expected by generate_to_file: path first, then text.
    if speaker is None:
        pipe.generate_to_file(output_path, text)
    else:
        pipe.generate_to_file(output_path, text, speaker=speaker)
class TTSProcessor:
    """Thin wrapper around a WhisperSpeech ``Pipeline`` bound to one device."""

    # Model reference used for the pipeline's ``s2a_ref`` unless the caller
    # overrides it.
    DEFAULT_S2A_REF: str = "collabora/whisperspeech:s2a-q4-tiny-en+pl.model"

    def __init__(self, device: str, s2a_ref: str = DEFAULT_S2A_REF):
        """Initialize the TTS Processor with a specified device.

        Args:
            device (str): Device passed to the pipeline (e.g. ``"cuda"``).
            s2a_ref (str): Model reference forwarded to ``Pipeline`` as
                ``s2a_ref``; defaults to ``DEFAULT_S2A_REF``.
        """
        self.pipe = Pipeline(s2a_ref=s2a_ref, device=device)

    def get_reference_voice_embedding(self, path: str):
        """Get the reference voice embedding from the given audio file.

        Args:
            path (str): The path to the audio file.

        Returns:
            torch.Tensor: The reference voice embedding, moved to the CPU.
        """
        return self.pipe.extract_spk_emb(path).cpu()

    def convert_text_to_audio(self, text: str, speaker=None):
        """Convert text to audio.

        Args:
            text (str): The text to convert to audio.
            speaker: Optional reference voice, e.g. the value returned by
                ``get_reference_voice_embedding``; forwarded to the pipeline.

        Returns:
            torch.Tensor: The generated audio.
        """
        return self.pipe.generate(text, speaker=speaker)

    def convert_text_to_audio_file(self, text: str, output_path: str, speaker=None):
        """Convert text to audio and save it to a file.

        Args:
            text (str): The text to convert to audio.
            output_path (str): The path to save the audio file.
            speaker: Optional reference voice, e.g. the value returned by
                ``get_reference_voice_embedding``; forwarded to the pipeline.
        """
        # Note the argument order expected by generate_to_file: path first, then text.
        self.pipe.generate_to_file(output_path, text, speaker=speaker)
if __name__ == "__main__":
args = parse_args()
processor = TTSProcessor("cuda")
text = args.text
text = text.lower()
text_split = "_".join(text.lower().split(" "))
# remove the last character if it is a period
if text_split[-1] == ".":
text_split = text_split[:-1]
print(text_split)
path = f"./examples/{text_split}.wav"
processor.convert_text_to_audio_file(text, path)