Spaces: Running on Zero
Update TTSModelV1 to use v1 voices from Kokoro-82M repository
Browse files
- tts_model_v1.py +20 -17
tts_model_v1.py
CHANGED
@@ -6,6 +6,7 @@ from typing import Tuple, List
|
|
6 |
import soundfile as sf
|
7 |
from kokoro import KPipeline
|
8 |
import spaces
|
|
|
9 |
|
10 |
class TTSModelV1:
|
11 |
"""KPipeline-based TTS model for v1.0.0"""
|
@@ -13,7 +14,8 @@ class TTSModelV1:
|
|
13 |
def __init__(self):
|
14 |
self.pipeline = None
|
15 |
self.model_repo = "hexgrad/Kokoro-82M"
|
16 |
-
|
|
|
17 |
|
18 |
def initialize(self) -> bool:
|
19 |
"""Initialize KPipeline and verify voices"""
|
@@ -22,9 +24,11 @@ class TTSModelV1:
|
|
22 |
|
23 |
self.pipeline = None # cannot be initialized outside of GPU decorator
|
24 |
|
25 |
-
#
|
26 |
-
|
27 |
-
|
|
|
|
|
28 |
|
29 |
# Verify voices were downloaded successfully
|
30 |
available_voices = self.list_voices()
|
@@ -43,8 +47,9 @@ class TTSModelV1:
|
|
43 |
def list_voices(self) -> List[str]:
|
44 |
"""List available voices"""
|
45 |
voices = []
|
46 |
-
|
47 |
-
|
|
|
48 |
if file.endswith(".pt"):
|
49 |
voice_name = file[:-3]
|
50 |
voices.append(voice_name)
|
@@ -76,7 +81,7 @@ class TTSModelV1:
|
|
76 |
t_voices = []
|
77 |
for voice in voice_names:
|
78 |
try:
|
79 |
-
voice_path = os.path.join(self.voices_dir, f"{voice}.pt")
|
80 |
try:
|
81 |
voicepack = torch.load(voice_path, weights_only=True)
|
82 |
except Exception as e:
|
@@ -90,18 +95,16 @@ class TTSModelV1:
|
|
90 |
voicepack = torch.mean(torch.stack(t_voices), dim=0)
|
91 |
voice_name = "_".join(voice_names)
|
92 |
# Save mixed voice temporarily
|
93 |
-
mixed_voice_path = os.path.join(self.voices_dir, f"{voice_name}.pt")
|
94 |
torch.save(voicepack, mixed_voice_path)
|
95 |
else:
|
96 |
voice_name = voice_names[0]
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
split_pattern=r'\n+' # Default chunking pattern
|
104 |
-
)
|
105 |
|
106 |
# Initialize tracking
|
107 |
audio_chunks = []
|
@@ -114,7 +117,7 @@ class TTSModelV1:
|
|
114 |
text,
|
115 |
voice=voice_name,
|
116 |
speed=speed,
|
117 |
-
split_pattern=r'\n+'
|
118 |
)
|
119 |
|
120 |
# Process chunks
|
|
|
6 |
import soundfile as sf
|
7 |
from kokoro import KPipeline
|
8 |
import spaces
|
9 |
+
from lib.file_utils import download_voice_files, ensure_dir
|
10 |
|
11 |
class TTSModelV1:
|
12 |
"""KPipeline-based TTS model for v1.0.0"""
|
|
|
14 |
def __init__(self):
|
15 |
self.pipeline = None
|
16 |
self.model_repo = "hexgrad/Kokoro-82M"
|
17 |
+
# Use v1 voices from Kokoro-82M repo
|
18 |
+
self.voices_dir = os.path.join(os.path.dirname(__file__), "voices")
|
19 |
|
20 |
def initialize(self) -> bool:
|
21 |
"""Initialize KPipeline and verify voices"""
|
|
|
24 |
|
25 |
self.pipeline = None # cannot be initialized outside of GPU decorator
|
26 |
|
27 |
+
# Download v1 voices if needed
|
28 |
+
ensure_dir(self.voices_dir)
|
29 |
+
if not os.path.exists(os.path.join(self.voices_dir, "voices")):
|
30 |
+
print("Downloading v1 voices...")
|
31 |
+
download_voice_files(self.model_repo, "voices", self.voices_dir)
|
32 |
|
33 |
# Verify voices were downloaded successfully
|
34 |
available_voices = self.list_voices()
|
|
|
47 |
def list_voices(self) -> List[str]:
|
48 |
"""List available voices"""
|
49 |
voices = []
|
50 |
+
voices_dir = os.path.join(self.voices_dir, "voices")
|
51 |
+
if os.path.exists(voices_dir):
|
52 |
+
for file in os.listdir(voices_dir):
|
53 |
if file.endswith(".pt"):
|
54 |
voice_name = file[:-3]
|
55 |
voices.append(voice_name)
|
|
|
81 |
t_voices = []
|
82 |
for voice in voice_names:
|
83 |
try:
|
84 |
+
voice_path = os.path.join(self.voices_dir, "voices", f"{voice}.pt")
|
85 |
try:
|
86 |
voicepack = torch.load(voice_path, weights_only=True)
|
87 |
except Exception as e:
|
|
|
95 |
voicepack = torch.mean(torch.stack(t_voices), dim=0)
|
96 |
voice_name = "_".join(voice_names)
|
97 |
# Save mixed voice temporarily
|
98 |
+
mixed_voice_path = os.path.join(self.voices_dir, "voices", f"{voice_name}.pt")
|
99 |
torch.save(voicepack, mixed_voice_path)
|
100 |
else:
|
101 |
voice_name = voice_names[0]
|
102 |
+
voice_path = os.path.join(self.voices_dir, "voices", f"{voice_name}.pt")
|
103 |
+
try:
|
104 |
+
voicepack = torch.load(voice_path, weights_only=True)
|
105 |
+
except Exception as e:
|
106 |
+
print(f"Warning: weights_only load failed, attempting full load: {str(e)}")
|
107 |
+
voicepack = torch.load(voice_path, weights_only=False)
|
|
|
|
|
108 |
|
109 |
# Initialize tracking
|
110 |
audio_chunks = []
|
|
|
117 |
text,
|
118 |
voice=voice_name,
|
119 |
speed=speed,
|
120 |
+
split_pattern=r'\n+' # Default chunking pattern
|
121 |
)
|
122 |
|
123 |
# Process chunks
|