File size: 4,095 Bytes
070e26e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import os
import h5py
import torch
import random
import yaml
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from diffusion import create_diffusion
from models import DiT
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter  # TensorBoard

# Load hyperparameters from YAML file.
# The encoding is pinned so the config is decoded the same way on every
# platform instead of depending on the locale's default encoding.
with open('config/train.yaml', 'r', encoding='utf-8') as file:
    config = yaml.safe_load(file)

# Create TensorBoard writer (no arguments, so SummaryWriter's default log
# directory is used)
writer = SummaryWriter()

class MelMetaDataset(Dataset):
    """Dataset of fixed-length mel segments paired with their meta latents.

    Each top-level key of the HDF5 file is one item. For every item the
    ``'mel'`` and ``'meta'`` entries stored under that key are read as float
    tensors. The mel tensor is cropped at a random offset (or zero-padded)
    along its third axis to ``mel_frames`` frames and then rescaled.

    Args:
        h5_file: Path of the HDF5 file to read from.
        mel_frames: Number of frames every returned mel segment has when the
            stored mel is at least that long or needs padding.
    """

    def __init__(self, h5_file, mel_frames):
        self.h5_file = h5_file
        self.mel_frames = mel_frames
        # Only the key list is cached here; the file itself is reopened on
        # every item access.
        with h5py.File(h5_file, 'r') as handle:
            self.keys = list(handle.keys())

    def __len__(self):
        """Return the number of top-level keys found in the HDF5 file."""
        return len(self.keys)

    def pad_mel(self, mel_segment, total_frames):
        """Right-pad ``mel_segment`` with zeros on its last axis.

        Args:
            mel_segment: Tensor to pad.
            total_frames: Current number of frames in ``mel_segment``.

        Returns:
            The tensor padded by ``self.mel_frames - total_frames`` zeros on
            the last axis, or the input object unchanged when it is not
            shorter than ``self.mel_frames``.
        """
        missing = self.mel_frames - total_frames
        if missing <= 0:
            return mel_segment
        return F.pad(mel_segment, (0, missing), mode='constant', value=0)

    def __getitem__(self, idx):
        """Return ``(mel_segment, meta_latent)`` for the ``idx``-th key."""
        name = self.keys[idx]
        with h5py.File(self.h5_file, 'r') as handle:
            entry = handle[name]
            mel = torch.FloatTensor(entry['mel'][:])
            meta_latent = torch.FloatTensor(entry['meta'][:])

        # Frames are counted on the third axis of the stored mel tensor.
        n_frames = mel.shape[2]
        window = self.mel_frames
        if n_frames <= window:
            segment = self.pad_mel(mel, n_frames)
        else:
            # Random crop: the offset is drawn inclusively from every
            # position that still leaves a full window.
            offset = random.randint(0, n_frames - window)
            segment = mel[:, :, offset:offset + window]

        # Affine rescale; presumably maps values in [-10, 10] onto [0, 1] --
        # TODO confirm the stored mel range.
        # NOTE(review): zero padding is applied before this rescale, so padded
        # frames end up with the value 0.5.
        return (segment + 10) / 20, meta_latent

# Dataset & DataLoader (shuffling enabled)
dataset = MelMetaDataset(config['h5_file_path'], mel_frames=config['mel_frames'])
dataloader = DataLoader(dataset, shuffle=True, batch_size=config['batch_size'])

# Device: the configured device is only honoured when CUDA is available;
# otherwise everything runs on the CPU.
if torch.cuda.is_available():
    device = config['device']
else:
    device = "cpu"

# Model: every DiT hyperparameter is taken straight from the config; only
# input_size is converted to a tuple first.
model = DiT(
    input_size=tuple(config['input_size']),
    **{
        field: config[field]
        for field in ('patch_size', 'in_channels', 'hidden_size', 'depth', 'num_heads')
    },
)
model.to(device)

# Diffusion process, created with an empty timestep_respacing string
diffusion = create_diffusion(timestep_respacing="")

# AdamW over all model parameters, learning rate from the config
optimizer = optim.AdamW(model.parameters(), lr=config['lr'])

# Make sure the checkpoint directory exists before training starts
os.makedirs(config['checkpoint_dir'], exist_ok=True)

# Training function
def train_model(model, dataloader, optimizer, diffusion, num_epochs, sample_interval):
    """Train ``model`` with the diffusion loss and save periodic checkpoints.

    For each batch, one timestep per sample is drawn uniformly from
    ``[0, diffusion.num_timesteps)``, the loss returned by
    ``diffusion.training_losses`` is averaged over the batch, and a single
    optimizer step is taken. After every epoch the mean batch loss is printed
    and logged under ``Loss/epoch``.

    The function reads the module-level ``device``, ``writer`` and ``config``.

    Args:
        model: Network handed to ``diffusion.training_losses``; switched to
            training mode before the first epoch.
        dataloader: Iterable yielding ``(mel_segment, meta_latent)`` batches.
            It does not need to define ``__len__``.
        optimizer: Optimizer that updates the model parameters.
        diffusion: Object providing ``num_timesteps`` and ``training_losses``.
        num_epochs: Number of passes over ``dataloader``.
        sample_interval: A checkpoint is written every ``sample_interval``
            epochs.
    """
    model.train()
    for epoch in range(num_epochs):
        total_loss = 0.0
        # Count the batches actually seen instead of calling len(dataloader):
        # this works for loaders without __len__ and cannot divide by zero.
        num_batches = 0
        for mel_segment, meta_latent in tqdm(dataloader, desc=f"Epoch {epoch + 1}/{num_epochs}"):
            mel_segment = mel_segment.to(device)
            meta_latent = meta_latent.to(device)
            # One random diffusion timestep per sample in the batch.
            t = torch.randint(0, diffusion.num_timesteps, (mel_segment.shape[0],), device=device)
            # The meta latent is passed to the model as its ``y`` argument.
            model_kwargs = dict(y=meta_latent)
            loss_dict = diffusion.training_losses(model, mel_segment, t, model_kwargs)
            loss = loss_dict["loss"].mean()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            num_batches += 1

        # An epoch without any batch has no meaningful average; report NaN
        # rather than raising ZeroDivisionError.
        avg_loss = total_loss / num_batches if num_batches else float("nan")
        print(f"Epoch {epoch + 1}/{num_epochs}: Average Loss: {avg_loss:.4f}")
        writer.add_scalar('Loss/epoch', avg_loss, epoch + 1)

        if (epoch + 1) % sample_interval == 0:
            # Optimizer state is stored alongside the weights and the epoch
            # number.
            checkpoint = {
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
            }
            checkpoint_path = f"{config['checkpoint_dir']}/model_epoch_{epoch + 1}.pt"
            torch.save(checkpoint, checkpoint_path)
            print(f"Model checkpoint saved at epoch {epoch + 1}")

# Start training
try:
    train_model(model, dataloader, optimizer, diffusion, num_epochs=config['num_epochs'], sample_interval=config['sample_interval'])
finally:
    # Close TensorBoard writer even if training raises or is interrupted, so
    # it is never left open.
    writer.close()