import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import GPT2Tokenizer, GPT2Model
from torchaudio.transforms import MelSpectrogram, InverseMelScale, GriffinLim
import torchaudio
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.amp import GradScaler, autocast
class TextToSpeechDataset(Dataset):
def __init__(self, text_files, audio_files, tokenizer, mel_transform, max_length=512):
self.text_files = text_files
self.audio_files = audio_files
self.tokenizer = tokenizer
self.mel_transform = mel_transform
self.max_length = max_length
def __len__(self):
return len(self.text_files)
def __getitem__(self, idx):
# Load text
with open(self.text_files[idx], 'r') as f:
text = f.read().strip()
# Tokenize text
text_tokens = self.tokenizer.encode(
text,
truncation=True,
padding='max_length',
max_length=self.max_length,
return_tensors="pt"
).squeeze(0)
        # Load audio, downmix to mono, and resample to the mel transform's rate
        waveform, sample_rate = torchaudio.load(self.audio_files[idx])
        if waveform.size(0) > 1:
            waveform = waveform.mean(dim=0, keepdim=True)
        if sample_rate != self.mel_transform.sample_rate:
            waveform = torchaudio.functional.resample(waveform, sample_rate, self.mel_transform.sample_rate)
        # Mel spectrogram of shape (n_mels, time)
        mel_spec = self.mel_transform(waveform)
        return text_tokens, mel_spec.squeeze(0)
def collate_fn(batch):
text_tokens, mel_specs = zip(*batch)
# Pad text tokens
max_text_len = max(tokens.size(0) for tokens in text_tokens)
text_tokens_padded = torch.stack([
torch.cat([tokens, torch.zeros(max_text_len - tokens.size(0), dtype=tokens.dtype)], dim=0)
if tokens.size(0) < max_text_len
else tokens[:max_text_len]
for tokens in text_tokens
])
# Pad mel spectrograms
max_mel_len = max(spec.size(1) for spec in mel_specs)
mel_specs_padded = torch.stack([
F.pad(spec, (0, max_mel_len - spec.size(1)))
if spec.size(1) < max_mel_len
else spec[:, :max_mel_len]
for spec in mel_specs
])
return text_tokens_padded, mel_specs_padded
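# Minimal sanity-check sketch for the data pipeline (assumes a few paired files under
# ./texts and ./audio, the same layout train_model expects; kept commented out so the
# script does not require that data at import time):
#
# tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
# tokenizer.pad_token = tokenizer.eos_token
# mel_transform = MelSpectrogram(sample_rate=16000, n_mels=80, n_fft=1024, hop_length=256)
# texts = sorted(os.path.join('./texts', f) for f in os.listdir('./texts') if f.endswith('.txt'))
# audios = sorted(os.path.join('./audio', f) for f in os.listdir('./audio') if f.endswith('.wav'))
# loader = DataLoader(TextToSpeechDataset(texts, audios, tokenizer, mel_transform),
#                     batch_size=2, collate_fn=collate_fn)
# tokens, mels = next(iter(loader))
# print(tokens.shape, mels.shape)  # e.g. torch.Size([2, 512]) and torch.Size([2, 80, T])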
class VAEDecoder(nn.Module):
    def __init__(self, latent_dim, mel_channels=80, n_frames=80):
        super().__init__()
        # Latent distribution parameters (mean and log-variance)
        self.fc_mu = nn.Linear(latent_dim, latent_dim)
        self.fc_var = nn.Linear(latent_dim, latent_dim)
        # Decoder: latent vector -> fixed-size mel spectrogram (mel_channels x n_frames)
        self.decoder_layers = nn.Sequential(
            nn.Linear(latent_dim, 512),
            nn.ReLU(),
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, mel_channels * n_frames),
            nn.Unflatten(1, (mel_channels, n_frames))
        )
def reparameterize(self, mu, log_var):
std = torch.exp(0.5 * log_var)
eps = torch.randn_like(std)
return mu + eps * std
def forward(self, z):
mu = self.fc_mu(z)
log_var = self.fc_var(z)
# Reparameterization trick
z = self.reparameterize(mu, log_var)
# Decode
mel_spec = self.decoder_layers(z)
return mel_spec, mu, log_var
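# Shape sanity check for the decoder (a sketch; the batch size and latent_dim below
# are arbitrary example values, not used elsewhere in this script):
#
# decoder = VAEDecoder(latent_dim=256)
# mel, mu, log_var = decoder(torch.randn(4, 256))
# print(mel.shape, mu.shape, log_var.shape)  # (4, 80, 80), (4, 256), (4, 256)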
class TextToSpeechModel(nn.Module):
def __init__(self, text_encoder, vae_decoder, latent_dim=256):
super().__init__()
self.text_encoder = text_encoder
self.vae_decoder = vae_decoder
# Projection layer to map encoder output to latent space
self.projection = nn.Linear(text_encoder.config.hidden_size, latent_dim)
def forward(self, text_tokens):
# Encode text
encoder_output = self.text_encoder(text_tokens).last_hidden_state
# Mean pooling of encoder output
text_embedding = encoder_output.mean(dim=1)
# Project to latent space
latent_z = self.projection(text_embedding)
# Decode to mel spectrogram
mel_spec, mu, log_var = self.vae_decoder(latent_z)
return mel_spec, mu, log_var
def vae_loss(reconstruction, target, mu, log_var):
    # The decoder emits a fixed number of frames, so align it with the target's
    # time dimension before computing the reconstruction loss
    if reconstruction.size(-1) != target.size(-1):
        reconstruction = F.interpolate(reconstruction, size=target.size(-1), mode='linear', align_corners=False)
    # Reconstruction loss (MSE)
    recon_loss = F.mse_loss(reconstruction, target, reduction='mean')
    # KL divergence loss (summed over latent dims, averaged over the batch)
    kl_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp(), dim=1))
    return recon_loss + 0.001 * kl_loss
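# Note on the loss above: the KL term is the closed form for
# KL(N(mu, sigma^2) || N(0, 1)) = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2),
# summed over latent dimensions and averaged over the batch. The 0.001 weight is the
# factor used by this script to keep the KL term small relative to the MSE
# reconstruction loss.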
def train_model(num_epochs=10, accumulation_steps=16):
# Tokenizer and mel spectrogram transform
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token
# Mel spectrogram configuration
mel_transform = MelSpectrogram(
sample_rate=16000,
n_mels=80,
n_fft=1024,
hop_length=256
)
# Data preparation
text_folder = './texts'
audio_folder = './audio'
    # Load text and audio files, sorted so each text lines up with its audio by filename
    text_files = sorted(os.path.join(text_folder, f) for f in os.listdir(text_folder) if f.endswith('.txt'))
    audio_files = sorted(os.path.join(audio_folder, f) for f in os.listdir(audio_folder) if f.endswith('.wav'))
# Split dataset
train_texts, val_texts, train_audios, val_audios = train_test_split(
text_files, audio_files, test_size=0.1, random_state=42
)
# Create datasets and dataloaders
train_dataset = TextToSpeechDataset(train_texts, train_audios, tokenizer, mel_transform)
val_dataset = TextToSpeechDataset(val_texts, val_audios, tokenizer, mel_transform)
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False, collate_fn=collate_fn)
# Model components
text_encoder = GPT2Model.from_pretrained('gpt2')
vae_decoder = VAEDecoder(latent_dim=256)
# Combine into full model
model = TextToSpeechModel(text_encoder, vae_decoder)
# Device setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
# Optimizer and scheduler
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=1e-6)
    # Gradient scaler for mixed-precision training (no-op when CUDA is unavailable)
    scaler = GradScaler(enabled=device.type == 'cuda')
best_val_loss = float('inf')
    # Training loop
    for epoch in range(num_epochs):
        model.train()
        train_loss = 0
        optimizer.zero_grad()
        for batch_idx, (text_tokens, mel_specs) in enumerate(tqdm(train_loader, desc=f"Epoch {epoch+1}")):
            text_tokens = text_tokens.to(device)
            mel_specs = mel_specs.to(device)
            with autocast(device_type=device.type, dtype=torch.float16, enabled=device.type == 'cuda'):
                # Forward pass
                reconstructed_mel, mu, log_var = model(text_tokens)
                # Compute loss
                loss = vae_loss(reconstructed_mel, mel_specs, mu, log_var)
            # Scale the loss for gradient accumulation and backpropagate outside autocast
            loss = loss / accumulation_steps
            scaler.scale(loss).backward()
            if (batch_idx + 1) % accumulation_steps == 0 or (batch_idx + 1) == len(train_loader):
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()
            # Undo the accumulation scaling when logging
            train_loss += loss.item() * accumulation_steps
# Validation
model.eval()
val_loss = 0
with torch.no_grad():
for text_tokens, mel_specs in tqdm(val_loader, desc=f"Validation {epoch+1}"):
text_tokens = text_tokens.to(device)
mel_specs = mel_specs.to(device)
reconstructed_mel, mu, log_var = model(text_tokens)
loss = vae_loss(reconstructed_mel, mel_specs, mu, log_var)
val_loss += loss.item()
# Scheduler step
scheduler.step()
        # Print epoch summary
        avg_train_loss = train_loss / len(train_loader)
        avg_val_loss = val_loss / len(val_loader)
        print(f'Epoch {epoch+1}: Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}')
        # Save the best checkpoint by validation loss
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), 'best_tts_model.pth')
return model
# Run training only when executed as a script
if __name__ == "__main__":
    trained_model = train_model()
# Optional: Inference function for generating mel spectrograms
def generate_mel_spectrogram(text, model, tokenizer, device):
model.eval()
with torch.no_grad():
# Tokenize input text
text_tokens = tokenizer.encode(
text,
return_tensors="pt",
truncation=True,
padding='max_length',
max_length=512
).to(device)
# Generate mel spectrogram
mel_spec, _, _ = model(text_tokens)
return mel_spec
# Optional: Convert mel spectrogram back to audio with Griffin-Lim
def mel_to_audio(mel_spec, sample_rate=16000, n_fft=1024, hop_length=256):
    # n_fft and hop_length must match the MelSpectrogram used during training
    inverse_mel = InverseMelScale(n_stft=n_fft // 2 + 1, n_mels=mel_spec.size(-2), sample_rate=sample_rate)
    griffin_lim = GriffinLim(n_fft=n_fft, hop_length=hop_length)
    # Recover a linear spectrogram from the mel bins, then estimate phase
    waveform = griffin_lim(inverse_mel(mel_spec))
    return waveform