File size: 6,677 Bytes
27f8803
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
import os
import torch
import numpy as np
import time
from typing import Tuple, List
import soundfile as sf
from kokoro import KPipeline
import spaces

class TTSModelV1:
    """KPipeline-based TTS model for v1.0.0.

    Wraps ``kokoro.KPipeline`` for American-English speech synthesis and
    manages local ``.pt`` voice packs stored under ``<voices_dir>/voices``.
    """

    # Sample rate (Hz) of audio produced by the Kokoro pipeline
    # (presumably fixed by the model — the original code hard-coded 24000).
    SAMPLE_RATE = 24000

    def __init__(self):
        self.pipeline = None            # KPipeline instance, set by initialize()
        self.voices_dir = "voices"      # root dir containing the "voices" subdirectory
        self.model_repo = "hexgrad/Kokoro-82M"

    def initialize(self) -> bool:
        """Initialize KPipeline and verify voices.

        Returns:
            True on success, False if any step failed (error is printed).
        """
        try:
            print("Initializing v1.0.0 model...")

            # Initialize KPipeline with American English
            self.pipeline = KPipeline(lang_code='a')

            # Verify local voice files are available
            voices_dir = os.path.join(self.voices_dir, "voices")
            if not os.path.exists(voices_dir):
                raise ValueError("Voice files not found")

            # Verify voices were downloaded successfully
            available_voices = self.list_voices()
            if not available_voices:
                print("Warning: No voices found after initialization")
            else:
                print(f"Found {len(available_voices)} voices")

            print("Model initialization complete")
            return True

        except Exception as e:
            print(f"Error initializing model: {str(e)}")
            return False

    def list_voices(self) -> List[str]:
        """Return the names of available voices (``.pt`` files, extension stripped)."""
        voices = []
        voices_subdir = os.path.join(self.voices_dir, "voices")
        if os.path.exists(voices_subdir):
            for file in os.listdir(voices_subdir):
                if file.endswith(".pt"):
                    voices.append(file[:-3])  # strip ".pt"
        return voices

    def _load_voicepack(self, voice: str):
        """Load a single voice tensor, preferring the safe ``weights_only`` path."""
        voice_path = os.path.join(self.voices_dir, "voices", f"{voice}.pt")
        try:
            return torch.load(voice_path, weights_only=True)
        except Exception as e:
            # Older voice packs may not be loadable with weights_only=True;
            # fall back to a full (unsafe) load for local trusted files.
            print(f"Warning: weights_only load failed, attempting full load: {str(e)}")
            return torch.load(voice_path, weights_only=False)

    @spaces.GPU(duration=None)  # Duration will be set by the UI
    def generate_speech(self, text: str, voice_names: list[str], speed: float = 1.0, gpu_timeout: int = 60, progress_callback=None, progress_state=None, progress=None) -> Tuple[np.ndarray, float, dict]:
        """Generate speech from text using KPipeline.

        Args:
            text: Input text to convert to speech.
            voice_names: List of voice names to use (mean-mixed if multiple).
                A bare string is accepted and treated as a single voice name.
            speed: Speech speed multiplier.
            gpu_timeout: GPU time budget (seconds), forwarded to the callback.
            progress_callback: Optional callback invoked once per chunk.
            progress_state: Dictionary tracking generation progress metrics.
            progress: Progress callback from Gradio.

        Returns:
            Tuple of (audio samples, duration in seconds, metrics dict).
            Note: despite the original annotation, this has always been a
            3-tuple; the annotation now reflects the actual return value.

        Raises:
            ValueError: If text or voice_names is empty.
        """
        try:
            start_time = time.time()

            if not text or not voice_names:
                raise ValueError("Text and voice name are required")

            # Defensive: a bare string would otherwise be indexed, yielding
            # only its first character as the voice name.
            if isinstance(voice_names, str):
                voice_names = [voice_names]

            # Handle voice mixing; mixed_voice_path doubles as the "was a
            # temporary mixed voice created?" flag for cleanup below.
            mixed_voice_path = None
            if len(voice_names) > 1:
                t_voices = []
                for voice in voice_names:
                    try:
                        t_voices.append(self._load_voicepack(voice))
                    except Exception as e:
                        print(f"Warning: Failed to load voice {voice}: {str(e)}")

                # Combine voices by taking mean
                voicepack = torch.mean(torch.stack(t_voices), dim=0)
                voice_name = "_".join(voice_names)
                # Save mixed voice temporarily so the pipeline can load it by name
                mixed_voice_path = os.path.join(self.voices_dir, "voices", f"{voice_name}.pt")
                torch.save(voicepack, mixed_voice_path)
            else:
                voice_name = voice_names[0]

            # Generate speech using KPipeline
            generator = self.pipeline(
                text,
                voice=voice_name,
                speed=speed,
                split_pattern=r'\n+'  # Default chunking pattern
            )

            # Process chunks and collect metrics
            audio_chunks = []
            chunk_times = []
            chunk_sizes = []
            total_tokens = 0

            # Start the timer BEFORE pulling from the generator: synthesis
            # happens while the generator produces each item, not inside the
            # loop body (timing inside the body would measure ~0s per chunk).
            chunk_start = time.time()
            for i, (gs, ps, audio) in enumerate(generator):
                chunk_time = time.time() - chunk_start

                # Store chunk audio
                audio_chunks.append(audio)

                # Calculate metrics
                chunk_times.append(chunk_time)
                chunk_sizes.append(len(gs))  # Use grapheme length as chunk size
                total_tokens += len(gs)

                # Update progress if callback provided
                if progress_callback:
                    chunk_duration = len(audio) / self.SAMPLE_RATE
                    # Guard against zero-length chunks / sub-resolution timings
                    rtf = chunk_time / chunk_duration if chunk_duration > 0 else 0.0
                    tokens_per_sec = len(gs) / chunk_time if chunk_time > 0 else 0.0
                    progress_callback(
                        i + 1,
                        -1,  # Total chunks unknown with generator
                        tokens_per_sec,
                        rtf,
                        progress_state,
                        start_time,
                        gpu_timeout,
                        progress
                    )

                print(f"Chunk {i+1} processed in {chunk_time:.2f}s")
                print(f"Graphemes: {gs}")
                print(f"Phonemes: {ps}")

                # Restart the timer for the next chunk's generation
                chunk_start = time.time()

            # Concatenate audio chunks
            audio = np.concatenate(audio_chunks)

            # Cleanup temporary mixed voice if created (best-effort)
            if mixed_voice_path is not None:
                try:
                    os.remove(mixed_voice_path)
                except OSError:
                    pass

            # Return audio and metrics
            return (
                audio,
                len(audio) / self.SAMPLE_RATE,
                {
                    "chunk_times": chunk_times,
                    "chunk_sizes": chunk_sizes,
                    "tokens_per_sec": [float(x) for x in progress_state["tokens_per_sec"]] if progress_state else [],
                    "rtf": [float(x) for x in progress_state["rtf"]] if progress_state else [],
                    "total_tokens": total_tokens,
                    "total_time": time.time() - start_time
                }
            )

        except Exception as e:
            print(f"Error generating speech: {str(e)}")
            raise