Remsky committed on
Commit
bb43905
·
1 Parent(s): 372ebd3

Update TTSModelV1 to use v1 voices from Kokoro-82M repository

Browse files
Files changed (1) hide show
  1. tts_model_v1.py +20 -17
tts_model_v1.py CHANGED
@@ -6,6 +6,7 @@ from typing import Tuple, List
6
  import soundfile as sf
7
  from kokoro import KPipeline
8
  import spaces
 
9
 
10
  class TTSModelV1:
11
  """KPipeline-based TTS model for v1.0.0"""
@@ -13,7 +14,8 @@ class TTSModelV1:
13
  def __init__(self):
14
  self.pipeline = None
15
  self.model_repo = "hexgrad/Kokoro-82M"
16
- self.voices_dir = os.path.join(os.path.dirname(__file__), "reference", "reference_other_repo", "voices")
 
17
 
18
  def initialize(self) -> bool:
19
  """Initialize KPipeline and verify voices"""
@@ -22,9 +24,11 @@ class TTSModelV1:
22
 
23
  self.pipeline = None # cannot be initialized outside of GPU decorator
24
 
25
- # Verify voices directory exists
26
- if not os.path.exists(self.voices_dir):
27
- raise ValueError(f"Voice files not found at {self.voices_dir}")
 
 
28
 
29
  # Verify voices were downloaded successfully
30
  available_voices = self.list_voices()
@@ -43,8 +47,9 @@ class TTSModelV1:
43
  def list_voices(self) -> List[str]:
44
  """List available voices"""
45
  voices = []
46
- if os.path.exists(self.voices_dir):
47
- for file in os.listdir(self.voices_dir):
 
48
  if file.endswith(".pt"):
49
  voice_name = file[:-3]
50
  voices.append(voice_name)
@@ -76,7 +81,7 @@ class TTSModelV1:
76
  t_voices = []
77
  for voice in voice_names:
78
  try:
79
- voice_path = os.path.join(self.voices_dir, f"{voice}.pt")
80
  try:
81
  voicepack = torch.load(voice_path, weights_only=True)
82
  except Exception as e:
@@ -90,18 +95,16 @@ class TTSModelV1:
90
  voicepack = torch.mean(torch.stack(t_voices), dim=0)
91
  voice_name = "_".join(voice_names)
92
  # Save mixed voice temporarily
93
- mixed_voice_path = os.path.join(self.voices_dir, f"{voice_name}.pt")
94
  torch.save(voicepack, mixed_voice_path)
95
  else:
96
  voice_name = voice_names[0]
97
-
98
- # Generate speech using KPipeline
99
- generator = self.pipeline(
100
- text,
101
- voice=voice_name,
102
- speed=speed,
103
- split_pattern=r'\n+' # Default chunking pattern
104
- )
105
 
106
  # Initialize tracking
107
  audio_chunks = []
@@ -114,7 +117,7 @@ class TTSModelV1:
114
  text,
115
  voice=voice_name,
116
  speed=speed,
117
- split_pattern=r'\n+'
118
  )
119
 
120
  # Process chunks
 
6
  import soundfile as sf
7
  from kokoro import KPipeline
8
  import spaces
9
+ from lib.file_utils import download_voice_files, ensure_dir
10
 
11
  class TTSModelV1:
12
  """KPipeline-based TTS model for v1.0.0"""
 
14
  def __init__(self):
15
  self.pipeline = None
16
  self.model_repo = "hexgrad/Kokoro-82M"
17
+ # Use v1 voices from Kokoro-82M repo
18
+ self.voices_dir = os.path.join(os.path.dirname(__file__), "voices")
19
 
20
  def initialize(self) -> bool:
21
  """Initialize KPipeline and verify voices"""
 
24
 
25
  self.pipeline = None # cannot be initialized outside of GPU decorator
26
 
27
+ # Download v1 voices if needed
28
+ ensure_dir(self.voices_dir)
29
+ if not os.path.exists(os.path.join(self.voices_dir, "voices")):
30
+ print("Downloading v1 voices...")
31
+ download_voice_files(self.model_repo, "voices", self.voices_dir)
32
 
33
  # Verify voices were downloaded successfully
34
  available_voices = self.list_voices()
 
47
  def list_voices(self) -> List[str]:
48
  """List available voices"""
49
  voices = []
50
+ voices_dir = os.path.join(self.voices_dir, "voices")
51
+ if os.path.exists(voices_dir):
52
+ for file in os.listdir(voices_dir):
53
  if file.endswith(".pt"):
54
  voice_name = file[:-3]
55
  voices.append(voice_name)
 
81
  t_voices = []
82
  for voice in voice_names:
83
  try:
84
+ voice_path = os.path.join(self.voices_dir, "voices", f"{voice}.pt")
85
  try:
86
  voicepack = torch.load(voice_path, weights_only=True)
87
  except Exception as e:
 
95
  voicepack = torch.mean(torch.stack(t_voices), dim=0)
96
  voice_name = "_".join(voice_names)
97
  # Save mixed voice temporarily
98
+ mixed_voice_path = os.path.join(self.voices_dir, "voices", f"{voice_name}.pt")
99
  torch.save(voicepack, mixed_voice_path)
100
  else:
101
  voice_name = voice_names[0]
102
+ voice_path = os.path.join(self.voices_dir, "voices", f"{voice_name}.pt")
103
+ try:
104
+ voicepack = torch.load(voice_path, weights_only=True)
105
+ except Exception as e:
106
+ print(f"Warning: weights_only load failed, attempting full load: {str(e)}")
107
+ voicepack = torch.load(voice_path, weights_only=False)
 
 
108
 
109
  # Initialize tracking
110
  audio_chunks = []
 
117
  text,
118
  voice=voice_name,
119
  speed=speed,
120
+ split_pattern=r'\n+' # Default chunking pattern
121
  )
122
 
123
  # Process chunks