Staticaliza committed
Commit 18b9cec
Parent: 91f5a22

Create model/trainer.py

Files changed (1)
  1. model/trainer.py +250 -0
model/trainer.py ADDED
@@ -0,0 +1,250 @@
+ from __future__ import annotations
+
+ import os
+ import gc
+ from tqdm import tqdm
+ import wandb
+
+ import torch
+ from torch.optim import AdamW
+ from torch.utils.data import DataLoader, Dataset, SequentialSampler
+ from torch.optim.lr_scheduler import LinearLR, SequentialLR
+
+ from einops import rearrange
+
+ from accelerate import Accelerator
+ from accelerate.utils import DistributedDataParallelKwargs
+
+ from ema_pytorch import EMA
+
+ from model import CFM
+ from model.utils import exists, default
+ from model.dataset import DynamicBatchSampler, collate_fn
+
+
+ # trainer
+
+ class Trainer:
+     def __init__(
+         self,
+         model: CFM,
+         epochs,
+         learning_rate,
+         num_warmup_updates = 20000,
+         save_per_updates = 1000,
+         checkpoint_path = None,
+         batch_size = 32,
+         batch_size_type: str = "sample",
+         max_samples = 32,
+         grad_accumulation_steps = 1,
+         max_grad_norm = 1.0,
+         noise_scheduler: str | None = None,
+         duration_predictor: torch.nn.Module | None = None,
+         wandb_project = "test_e2-tts",
+         wandb_run_name = "test_run",
+         wandb_resume_id: str | None = None,
+         last_per_steps = None,
+         accelerate_kwargs: dict = dict(),
+         ema_kwargs: dict = dict()
+     ):
+
+         ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters = True)
+
+         self.accelerator = Accelerator(
+             log_with = "wandb",
+             kwargs_handlers = [ddp_kwargs],
+             gradient_accumulation_steps = grad_accumulation_steps,
+             **accelerate_kwargs
+         )
+
+         if exists(wandb_resume_id):
+             init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name, "id": wandb_resume_id}}
+         else:
+             init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name}}
+         self.accelerator.init_trackers(
+             project_name = wandb_project,
+             init_kwargs = init_kwargs,
+             config = {"epochs": epochs,
+                       "learning_rate": learning_rate,
+                       "num_warmup_updates": num_warmup_updates,
+                       "batch_size": batch_size,
+                       "batch_size_type": batch_size_type,
+                       "max_samples": max_samples,
+                       "grad_accumulation_steps": grad_accumulation_steps,
+                       "max_grad_norm": max_grad_norm,
+                       "gpus": self.accelerator.num_processes,
+                       "noise_scheduler": noise_scheduler}
+         )
+
+         self.model = model
+
+         if self.is_main:
+             self.ema_model = EMA(
+                 model,
+                 include_online_model = False,
+                 **ema_kwargs
+             )
+
+             self.ema_model.to(self.accelerator.device)
+
+         self.epochs = epochs
+         self.num_warmup_updates = num_warmup_updates
+         self.save_per_updates = save_per_updates
+         self.last_per_steps = default(last_per_steps, save_per_updates * grad_accumulation_steps)
+         self.checkpoint_path = default(checkpoint_path, 'ckpts/test_e2-tts')
+
+         self.batch_size = batch_size
+         self.batch_size_type = batch_size_type
+         self.max_samples = max_samples
+         self.grad_accumulation_steps = grad_accumulation_steps
+         self.max_grad_norm = max_grad_norm
+
+         self.noise_scheduler = noise_scheduler
+
+         self.duration_predictor = duration_predictor
+
+         self.optimizer = AdamW(model.parameters(), lr=learning_rate)
+         self.model, self.optimizer = self.accelerator.prepare(
+             self.model, self.optimizer
+         )
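+         # the LR scheduler is created later, in train(), once the dataloader length is known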
+
+     @property
+     def is_main(self):
+         return self.accelerator.is_main_process
+
+     def save_checkpoint(self, step, last=False):
+         self.accelerator.wait_for_everyone()
+         if self.is_main:
+             checkpoint = dict(
+                 model_state_dict = self.accelerator.unwrap_model(self.model).state_dict(),
+                 optimizer_state_dict = self.accelerator.unwrap_model(self.optimizer).state_dict(),
+                 ema_model_state_dict = self.ema_model.state_dict(),
+                 scheduler_state_dict = self.scheduler.state_dict(),
+                 step = step
+             )
+             if not os.path.exists(self.checkpoint_path):
+                 os.makedirs(self.checkpoint_path)
+             if last:
+                 self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_last.pt")
+                 print(f"Saved last checkpoint at step {step}")
+             else:
+                 self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_{step}.pt")
+
+     def load_checkpoint(self):
+         if not exists(self.checkpoint_path) or not os.path.exists(self.checkpoint_path) or not os.listdir(self.checkpoint_path):
+             return 0
+
+         self.accelerator.wait_for_everyone()
+         if "model_last.pt" in os.listdir(self.checkpoint_path):
+             latest_checkpoint = "model_last.pt"
+         else:
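+             # otherwise pick the intermediate checkpoint whose filename carries the largest step number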
+             latest_checkpoint = sorted([f for f in os.listdir(self.checkpoint_path) if f.endswith('.pt')], key=lambda x: int(''.join(filter(str.isdigit, x))))[-1]
+         # checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location=self.accelerator.device)  # rather use accelerator.load_state ಥ_ಥ
+         checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", weights_only=True, map_location="cpu")
+
+         if self.is_main:
+             self.ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
+
+         if 'step' in checkpoint:
+             self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint['model_state_dict'])
+             self.accelerator.unwrap_model(self.optimizer).load_state_dict(checkpoint['optimizer_state_dict'])
+             if self.scheduler:
+                 self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
+             step = checkpoint['step']
+         else:
+             # EMA-only checkpoint: strip the "ema_model." prefix and load the EMA weights into the online model
+             checkpoint['model_state_dict'] = {k.replace("ema_model.", ""): v for k, v in checkpoint['ema_model_state_dict'].items() if k not in ["initted", "step"]}
+             self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint['model_state_dict'])
+             step = 0
+
+         del checkpoint; gc.collect()
+         return step
+
+     def train(self, train_dataset: Dataset, num_workers=16, resumable_with_seed: int | None = None):
+
+         if exists(resumable_with_seed):
+             generator = torch.Generator()
+             generator.manual_seed(resumable_with_seed)
+         else:
+             generator = None
+
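+         # "sample" batching fixes the number of utterances per batch;
+         # "frame" batching treats batch_size as a per-batch mel-frame budget (capped at max_samples utterances) via DynamicBatchSampler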
+         if self.batch_size_type == "sample":
+             train_dataloader = DataLoader(train_dataset, collate_fn=collate_fn, num_workers=num_workers, pin_memory=True, persistent_workers=True,
+                                           batch_size=self.batch_size, shuffle=True, generator=generator)
+         elif self.batch_size_type == "frame":
+             self.accelerator.even_batches = False
+             sampler = SequentialSampler(train_dataset)
+             batch_sampler = DynamicBatchSampler(sampler, self.batch_size, max_samples=self.max_samples, random_seed=resumable_with_seed, drop_last=False)
+             train_dataloader = DataLoader(train_dataset, collate_fn=collate_fn, num_workers=num_workers, pin_memory=True, persistent_workers=True,
+                                           batch_sampler=batch_sampler)
+         else:
+             raise ValueError(f"batch_size_type must be either 'sample' or 'frame', but received {self.batch_size_type}")
+
+         # accelerator.prepare() dispatches batches to the devices,
+         # so the dataloader length calculated before prepare() should account for the number of devices
+         warmup_steps = self.num_warmup_updates * self.accelerator.num_processes  # keep a fixed number of warmup updates under accelerate multi-GPU DDP;
+         # otherwise, with the default split_batches=False, the warmup steps would change with num_processes
+         total_steps = len(train_dataloader) * self.epochs / self.grad_accumulation_steps
+         decay_steps = total_steps - warmup_steps
+         warmup_scheduler = LinearLR(self.optimizer, start_factor=1e-8, end_factor=1.0, total_iters=warmup_steps)
+         decay_scheduler = LinearLR(self.optimizer, start_factor=1.0, end_factor=1e-8, total_iters=decay_steps)
+         self.scheduler = SequentialLR(self.optimizer,
+                                       schedulers=[warmup_scheduler, decay_scheduler],
+                                       milestones=[warmup_steps])
+         train_dataloader, self.scheduler = self.accelerator.prepare(train_dataloader, self.scheduler)  # actual steps = single-GPU steps / number of GPUs
+         start_step = self.load_checkpoint()
+         global_step = start_step
+
+         if exists(resumable_with_seed):
+             orig_epoch_step = len(train_dataloader)
+             skipped_epoch = int(start_step // orig_epoch_step)
+             skipped_batch = start_step % orig_epoch_step
+             skipped_dataloader = self.accelerator.skip_first_batches(train_dataloader, num_batches=skipped_batch)
+         else:
+             skipped_epoch = 0
+
+         for epoch in range(skipped_epoch, self.epochs):
+             self.model.train()
+             if exists(resumable_with_seed) and epoch == skipped_epoch:
+                 progress_bar = tqdm(skipped_dataloader, desc=f"Epoch {epoch+1}/{self.epochs}", unit="step", disable=not self.accelerator.is_local_main_process,
+                                     initial=skipped_batch, total=orig_epoch_step)
+             else:
+                 progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{self.epochs}", unit="step", disable=not self.accelerator.is_local_main_process)
+
+             for batch in progress_bar:
+                 with self.accelerator.accumulate(self.model):
+                     text_inputs = batch['text']
+                     mel_spec = rearrange(batch['mel'], 'b d n -> b n d')
+                     mel_lengths = batch["mel_lengths"]
+
+                     # TODO: add duration predictor training
+                     if self.duration_predictor is not None and self.accelerator.is_local_main_process:
+                         dur_loss = self.duration_predictor(mel_spec, lens=batch.get('durations'))
+                         self.accelerator.log({"duration loss": dur_loss.item()}, step=global_step)
+
+                     loss, cond, pred = self.model(mel_spec, text=text_inputs, lens=mel_lengths, noise_scheduler=self.noise_scheduler)
+                     self.accelerator.backward(loss)
+
+                     # only clip once gradients are actually synchronized, i.e. on real optimizer steps under accumulation
+                     if self.max_grad_norm > 0 and self.accelerator.sync_gradients:
+                         self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
+
+                     self.optimizer.step()
+                     self.scheduler.step()
+                     self.optimizer.zero_grad()
+
+                 if self.is_main:
+                     self.ema_model.update()
+
+                 global_step += 1
+
+                 if self.accelerator.is_local_main_process:
+                     self.accelerator.log({"loss": loss.item(), "lr": self.scheduler.get_last_lr()[0]}, step=global_step)
+
+                 progress_bar.set_postfix(step=str(global_step), loss=loss.item())
+
+                 if global_step % (self.save_per_updates * self.grad_accumulation_steps) == 0:
+                     self.save_checkpoint(global_step)
+
+                 if global_step % self.last_per_steps == 0:
+                     self.save_checkpoint(global_step, last=True)
+
+         self.accelerator.end_training()
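
For reference, the schedule assembled in train() is a linear warmup from near zero up to the base learning rate, followed by a linear decay back to near zero over the remaining steps. The sketch below reproduces that LinearLR + SequentialLR wiring in isolation so it can be run and inspected on its own; it is not part of this commit, the step counts and learning rate are illustrative stand-ins for the trainer's computed values, and the dummy Linear module merely stands in for the prepared CFM model.

import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import LinearLR, SequentialLR

model = torch.nn.Linear(4, 4)       # stand-in for the prepared CFM model
optimizer = AdamW(model.parameters(), lr=7.5e-5)

warmup_steps = 10                   # trainer: num_warmup_updates * num_processes
total_steps = 100                   # trainer: len(train_dataloader) * epochs / grad_accumulation_steps
decay_steps = total_steps - warmup_steps

warmup = LinearLR(optimizer, start_factor=1e-8, end_factor=1.0, total_iters=warmup_steps)
decay = LinearLR(optimizer, start_factor=1.0, end_factor=1e-8, total_iters=decay_steps)
scheduler = SequentialLR(optimizer, schedulers=[warmup, decay], milestones=[warmup_steps])

for step in range(total_steps):
    optimizer.step()                # in the trainer this follows backward() on the CFM loss
    scheduler.step()
    if step % 10 == 0:
        print(step, scheduler.get_last_lr()[0])

Per the inline comments in train(), warmup_steps is multiplied by accelerator.num_processes so that the warmup spans the same number of actual optimizer updates regardless of how many GPUs the run is sharded across.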