# tts_model_v1.py: v1.0.0 model support with KPipeline (Kokoro-TTS-Zero)
import os
import torch
import numpy as np
import time
from typing import Tuple, List
import soundfile as sf
from kokoro import KPipeline
import spaces


class TTSModelV1:
    """KPipeline-based TTS model for v1.0.0."""

    SAMPLE_RATE = 24000  # Kokoro outputs 24 kHz audio

    def __init__(self):
        self.pipeline = None
        self.voices_dir = "voices"
        self.model_repo = "hexgrad/Kokoro-82M"

    def initialize(self) -> bool:
        """Initialize KPipeline and verify voices."""
        try:
            print("Initializing v1.0.0 model...")
            # Initialize KPipeline with American English
            self.pipeline = KPipeline(lang_code='a')
            # Verify local voice files are available
            voices_dir = os.path.join(self.voices_dir, "voices")
            if not os.path.exists(voices_dir):
                raise ValueError(f"Voice directory not found: {voices_dir}")
            # Verify voices are present
            available_voices = self.list_voices()
            if not available_voices:
                print("Warning: No voices found after initialization")
            else:
                print(f"Found {len(available_voices)} voices")
            print("Model initialization complete")
            return True
        except Exception as e:
            print(f"Error initializing model: {e}")
            return False

    def list_voices(self) -> List[str]:
        """List available voices (.pt files in the voices directory)."""
        voices = []
        voices_subdir = os.path.join(self.voices_dir, "voices")
        if os.path.exists(voices_subdir):
            for file in os.listdir(voices_subdir):
                if file.endswith(".pt"):
                    voice_name = file[:-3]  # Strip the ".pt" extension
                    voices.append(voice_name)
        return voices
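
    # Illustrative helper (hypothetical, not called elsewhere in this file):
    # a minimal sketch of the mean-blending that generate_speech performs
    # inline below, assuming all voicepacks are same-shape tensors.
    @staticmethod
    def _mix_voicepacks(voice_paths: List[str]) -> torch.Tensor:
        """Blend several voicepack tensors by element-wise averaging."""
        packs = [torch.load(path, weights_only=True) for path in voice_paths]
        return torch.mean(torch.stack(packs), dim=0)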

    @spaces.GPU(duration=None)  # Duration will be set by the UI
    def generate_speech(self, text: str, voice_names: list[str], speed: float = 1.0, gpu_timeout: int = 60, progress_callback=None, progress_state=None, progress=None) -> Tuple[np.ndarray, float, dict]:
        """Generate speech from text using KPipeline.

        Args:
            text: Input text to convert to speech
            voice_names: List of voice names to use (mixed if multiple)
            speed: Speech speed multiplier
            gpu_timeout: GPU allocation timeout in seconds
            progress_callback: Optional callback function
            progress_state: Dictionary tracking generation progress metrics
            progress: Progress callback from Gradio

        Returns:
            Tuple of (audio array, duration in seconds, metrics dict)
        """
        try:
            start_time = time.time()
            if not text or not voice_names:
                raise ValueError("Text and at least one voice name are required")
            # Handle voice mixing
            if isinstance(voice_names, list) and len(voice_names) > 1:
                t_voices = []
                for voice in voice_names:
                    try:
                        voice_path = os.path.join(self.voices_dir, "voices", f"{voice}.pt")
                        try:
                            voicepack = torch.load(voice_path, weights_only=True)
                        except Exception as e:
                            print(f"Warning: weights_only load failed, attempting full load: {e}")
                            voicepack = torch.load(voice_path, weights_only=False)
                        t_voices.append(voicepack)
                    except Exception as e:
                        print(f"Warning: Failed to load voice {voice}: {e}")
                if not t_voices:
                    raise ValueError("None of the requested voices could be loaded")
                # Combine voices by taking the element-wise mean of their voicepacks
                voicepack = torch.mean(torch.stack(t_voices), dim=0)
                voice_name = "_".join(voice_names)
                # Save the mixed voice temporarily so KPipeline can load it by name
                mixed_voice_path = os.path.join(self.voices_dir, "voices", f"{voice_name}.pt")
                torch.save(voicepack, mixed_voice_path)
            else:
                voice_name = voice_names[0]
            # Generate speech using KPipeline
            generator = self.pipeline(
                text,
                voice=voice_name,
                speed=speed,
                split_pattern=r'\n+'  # Default chunking pattern
            )
            # Process chunks and collect metrics
            audio_chunks = []
            chunk_times = []
            chunk_sizes = []
            total_tokens = 0
            chunk_start = time.time()
            for i, (gs, ps, audio) in enumerate(generator):
                # The pipeline generates each chunk inside next(), so measure
                # from the end of the previous iteration
                chunk_time = time.time() - chunk_start
                # Store chunk audio
                audio_chunks.append(audio)
                # Record metrics
                chunk_times.append(chunk_time)
                chunk_sizes.append(len(gs))  # Use grapheme length as chunk size
                total_tokens += len(gs)
                # Update progress if callback provided
                if progress_callback:
                    chunk_duration = len(audio) / self.SAMPLE_RATE
                    rtf = chunk_time / chunk_duration
                    progress_callback(
                        i + 1,
                        -1,  # Total chunks unknown with a generator
                        len(gs) / chunk_time,  # tokens/sec
                        rtf,
                        progress_state,
                        start_time,
                        gpu_timeout,
                        progress
                    )
                print(f"Chunk {i+1} processed in {chunk_time:.2f}s")
                print(f"Graphemes: {gs}")
                print(f"Phonemes: {ps}")
                chunk_start = time.time()
            # Concatenate audio chunks
            audio = np.concatenate(audio_chunks)
            # Clean up the temporary mixed voice if one was created
            if len(voice_names) > 1:
                try:
                    os.remove(mixed_voice_path)
                except OSError:
                    pass
            # Return audio, duration, and metrics
            return (
                audio,
                len(audio) / self.SAMPLE_RATE,
                {
                    "chunk_times": chunk_times,
                    "chunk_sizes": chunk_sizes,
                    "tokens_per_sec": [float(x) for x in progress_state["tokens_per_sec"]] if progress_state else [],
                    "rtf": [float(x) for x in progress_state["rtf"]] if progress_state else [],
                    "total_tokens": total_tokens,
                    "total_time": time.time() - start_time
                }
            )
        except Exception as e:
            print(f"Error generating speech: {e}")
            raise
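

# A minimal usage sketch, assuming the `spaces` package is importable (outside
# a Hugging Face Space its GPU decorator should be a no-op) and that voicepack
# files already exist under voices/voices. The output filename is illustrative.
if __name__ == "__main__":
    model = TTSModelV1()
    if model.initialize():
        voices = model.list_voices()
        # Mix the first two voices when available, otherwise use the first
        selection = voices[:2] if len(voices) > 1 else voices[:1]
        audio, duration, metrics = model.generate_speech(
            "Hello from Kokoro.", selection, speed=1.0
        )
        print(f"Generated {duration:.2f}s of audio in {metrics['total_time']:.2f}s")
        sf.write("output.wav", audio, TTSModelV1.SAMPLE_RATE)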