File size: 4,259 Bytes
c690ade 78e760c b3d61a3 78e760c e3db752 b3d61a3 e3db752 78e760c b3d61a3 e3db752 b3d61a3 78e760c b3d61a3 c690ade 364da54 c690ade 78e760c b3d61a3 78e760c b3d61a3 78e760c b3d61a3 b945617 b3d61a3 b945617 b3d61a3 b945617 b3d61a3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 |
import copy
from collections import namedtuple
import soundfile as sf
import torch
from loguru import logger
from parler_tts import ParlerTTSForConditionalGeneration
from replicate import Client
from transformers import AutoTokenizer
from kitt.skills.common import config
# Module-level Replicate API client, authenticated with the key from the
# project config; used by run_tts_replicate.
replicate = Client(api_token=config.REPLICATE_API_KEY)
# A selectable voice: display name, reference samples for the neutral and
# (optional) angry variants, and a speed value.
# The typename matches the binding so repr() reads ``Voice(...)`` and pickle
# can resolve the class by name.
Voice = namedtuple("Voice", ["name", "neutral", "angry", "speed"])

# Voices offered by this module. ``angry`` is None when a voice has no angry
# reference sample.
voices_replicate = [
    Voice(
        "Fast",
        # "empty" looks like a placeholder rather than a sample URL --
        # TODO confirm how callers treat it.
        neutral="empty",
        angry=None,
        speed=1.0,
    ),
    Voice(
        "Attenborough",
        neutral="https://zebel.ams3.digitaloceanspaces.com/xtts/short/attenborough-neutral.wav",
        angry=None,
        speed=1.2,
    ),
    Voice(
        "Rick",
        neutral="https://zebel.ams3.digitaloceanspaces.com/xtts/short/rick-neutral.wav",
        angry=None,
        speed=1.2,
    ),
    Voice(
        "Freeman",
        neutral="https://zebel.ams3.digitaloceanspaces.com/xtts/short/freeman-neutral.wav",
        angry="https://zebel.ams3.digitaloceanspaces.com/xtts/short/freeman-angry.wav",
        speed=1.1,
    ),
    Voice(
        "Walken",
        neutral="https://zebel.ams3.digitaloceanspaces.com/xtts/short/walken-neutral.wav",
        angry=None,
        speed=1.1,
    ),
    Voice(
        # NOTE(review): the name is part of the label matched by
        # voice_from_text ("Darth Wader - Neutral"), so it is kept verbatim.
        "Darth Wader",
        neutral="https://zebel.ams3.digitaloceanspaces.com/xtts/short/darth-neutral.wav",
        angry=None,
        speed=1.15,
    ),
]
def prep_for_tts(text: str) -> str:
    """Spell out unit abbreviations in *text* so they can be read aloud.

    The substitutions are applied in order; "km/h" is expanded before the
    bare "km" so that speeds are not turned into "kilometers/h".

    Args:
        text: The text to prepare.

    Returns:
        A new string with "km/h", "°C", "°F" and "km" replaced by their
        spoken forms. The argument itself is not modified.
    """
    # str is immutable and str.replace returns a new string, so no defensive
    # copy of the input is needed.
    replacements = (
        ("km/h", " kilometers per hour"),
        ("°C", " degree Celsius"),
        ("°F", " degree Fahrenheit"),
        ("km", " kilometers"),
    )
    for abbreviation, spoken in replacements:
        text = text.replace(abbreviation, spoken)
    return text
def voice_from_text(voice, voices):
    """Look up the reference sample for a voice label.

    Args:
        voice: Label of the form "<name> - Neutral" or "<name> - Angry".
        voices: Iterable of objects with ``name``, ``neutral`` and ``angry``
            attributes.

    Returns:
        The ``neutral`` or ``angry`` attribute of the first entry whose label
        matches. This may be None when that variant is not set.

    Raises:
        ValueError: If no entry matches *voice*.
    """
    for candidate in voices:
        # Neutral is checked before Angry; the attribute is only read once
        # its label has matched.
        for mood in ("neutral", "angry"):
            if voice == f"{candidate.name} - {mood.capitalize()}":
                return getattr(candidate, mood)
    raise ValueError(f"Voice {voice} not found.")
def speed_from_text(voice, voices):
    """Return the speed configured for the voice named by a label.

    Args:
        voice: Label of the form "<name> - Neutral" or "<name> - Angry".
        voices: Iterable of objects with ``name`` and ``speed`` attributes.

    Returns:
        The ``speed`` attribute of the first entry whose label matches.

    Raises:
        ValueError: If no entry matches *voice*. This mirrors
            voice_from_text instead of silently returning None.
    """
    for v in voices:
        # Both variants of a voice share a single speed value.
        if voice in (f"{v.name} - Neutral", f"{v.name} - Angry"):
            return v.speed
    raise ValueError(f"Voice {voice} not found.")
def run_tts_replicate(text: str, voice_character: str):
    """Synthesize *text* through the XTTS v2 model hosted on Replicate.

    Args:
        text: Text to speak.
        voice_character: Voice label ("<name> - Neutral" or "<name> - Angry")
            resolved against ``voices_replicate``.

    Returns:
        Whatever ``replicate.run`` returns for this model; it is also logged.
        Presumably a reference to the generated audio -- TODO confirm against
        the model's output schema.

    Raises:
        ValueError: If *voice_character* does not match any known voice.
    """
    # Alternative model, currently disabled:
    # "afiaka87/tortoise-tts:e9658de4b325863c4fcdc12d94bb7c9b54cbfe351b7ca1b36860008172b91c71"
    model_ref = "lucataco/xtts-v2:684bc3855b37866c0c65add2ff39c78f3dea3f4ff103a436465326e0f438d55e"
    speaker = voice_from_text(voice_character, voices_replicate)
    # Named ``payload`` rather than ``input`` to avoid shadowing the builtin.
    payload = {"text": text, "speaker": speaker, "cleanup_voice": True}
    output = replicate.run(model_ref, input=payload)
    logger.info(f"sound output: {output}")
    return output
def get_fast_tts():
    """Load the Parler-TTS checkpoint together with its tokenizer.

    Returns:
        A ``(model, tokenizer, device)`` tuple. ``device`` is "cuda:0" when
        CUDA is available and "cpu" otherwise, and the model has been moved
        to it.
    """
    checkpoint = "parler-tts/parler-tts-mini-expresso"
    if torch.cuda.is_available():
        target_device = "cuda:0"
    else:
        target_device = "cpu"
    tts_model = ParlerTTSForConditionalGeneration.from_pretrained(checkpoint)
    tts_model = tts_model.to(target_device)
    text_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    return tts_model, text_tokenizer, target_device
# Load the fast TTS model, tokenizer and device once, when the module is
# imported; run_tts_fast unpacks this tuple on every call.
fast_tts = get_fast_tts()
def run_tts_fast(text: str):
    """Synthesize *text* with the model held in ``fast_tts``.

    Args:
        text: Text to speak.

    Returns:
        A pair ``((sampling_rate, audio), info)``. ``audio`` is the generated
        output moved to the CPU, converted to a NumPy array and squeezed;
        ``info`` is a dict holding the input text and the voice name
        "Thomas".
    """
    tts_model, text_tokenizer, target_device = fast_tts

    def encode(sentence):
        # Tokenize to PyTorch tensors and place the ids on the model's device.
        return text_tokenizer(sentence, return_tensors="pt").input_ids.to(target_device)

    # The description is passed as ``input_ids`` and the text to speak as
    # ``prompt_input_ids``.
    description = "Thomas speaks moderately slowly in a sad tone with emphasis and high quality audio."
    style_ids = encode(description)
    spoken_ids = encode(text)

    generated = tts_model.generate(input_ids=style_ids, prompt_input_ids=spoken_ids)
    waveform = generated.cpu().numpy().squeeze()
    sampling_rate = tts_model.config.sampling_rate
    return (sampling_rate, waveform), {"text": text, "voice": "Thomas"}
def load_melo_tts():
    """Create an English MeloTTS model on CUDA when available, else on CPU.

    ``melo`` is imported inside the function, so a missing package raises
    ImportError only when this function is called.

    Returns:
        The constructed MeloTTS model.
    """
    from melo.api import TTS as MeloTTS

    use_cuda = torch.cuda.is_available()
    return MeloTTS(language="EN", device="cuda" if use_cuda else "cpu")
# MeloTTS is optional: if importing it fails, log the error and leave
# ``melo_tts`` as None so run_melo_tts can report that it is unavailable.
# Only ImportError is handled; any other failure while loading propagates.
try:
    melo_tts = load_melo_tts()
except ImportError as e:
    logger.error(f"Error loading MeloTTS: {e}")
    melo_tts = None
def run_melo_tts(text: str, voice: str):
    """Synthesize *text* with the module-level MeloTTS model.

    Args:
        text: Text to speak.
        voice: Name of a MeloTTS speaker, i.e. a key of the model's
            ``spk2id`` mapping. Any other value falls back to "EN-Default",
            which is the speaker every call used before this parameter was
            honoured.

    Returns:
        A ``(sampling_rate, audio)`` tuple, where ``audio`` is what
        ``tts_to_file`` returns when called with no output path.

    Raises:
        ValueError: If MeloTTS could not be loaded.
    """
    if melo_tts is None:
        raise ValueError("MeloTTS not available.")
    speed = 1.0
    speaker_ids = melo_tts.hps.data.spk2id
    # ``voice`` used to be ignored. Use it only when it names a known
    # speaker, so callers passing unrelated labels keep the default voice.
    # NOTE(review): assumes ``spk2id`` supports ``in`` -- confirm.
    if isinstance(voice, str) and voice in speaker_ids:
        speaker = voice
    else:
        speaker = "EN-Default"
    audio = melo_tts.tts_to_file(text, speaker_ids[speaker], None, speed=speed)
    return melo_tts.hps.data.sampling_rate, audio
|