wasmdashai committed
Commit 7d0e0a6 · verified · 1 Parent(s): 1e1ef94

Update app.py

Files changed (1)
  1. app.py +320 -320
app.py CHANGED
@@ -38,8 +38,6 @@ from torch.cuda.amp import autocast, GradScaler
38
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
39
 
40
 
41
- feature_extractor = VitsFeatureExtractor()
42
-
43
  # sgl=get_state_grad_loss(k1=True,#generator=False,
44
  # discriminator=False,
45
  # duration=False
@@ -182,188 +180,63 @@ def get_data_loader(train_dataset_dirs,eval_dataset_dir,full_generation_dir,dev
182
  device = device)
183
  return ctrain_datasets,eval_dataset,full_generation_dataset
184
  global_step=0
185
- @spaces.GPU
186
- def trainer_to_cuda(self,
187
- ctrain_datasets = None,
188
- eval_dataset = None,
189
- full_generation_dataset = None,
190
- feature_extractor = VitsFeatureExtractor(),
191
- training_args = None,
192
- full_generation_sample_index= 0,
193
- project_name = "Posterior_Decoder_Finetuning",
194
- wandbKey = "782b6a6e82bbb5a5348de0d3c7d40d1e76351e79",
195
- is_used_text_encoder=True,
196
- is_used_posterior_encode=True,
197
- dict_state_grad_loss=None,
198
- nk=1,
199
- path_save_model='./',
200
- maf=None,
201
- n_back_save_model=3000,
202
- start_speeker=0,
203
- end_speeker=1,
204
- n_epoch=0,
205
-
206
-
207
-
208
- ):
209
-
210
-
211
- # os.makedirs(training_args.output_dir,exist_ok=True)
212
- # logger = logging.getLogger(f"{__name__} Training")
213
- # log_level = training_args.get_process_log_level()
214
- # logger.setLevel(log_level)
215
-
216
- # # wandb.login(key= wandbKey)
217
- # # wandb.init(project= project_name,config = training_args.to_dict())
218
- if dict_state_grad_loss is None:
219
- dict_state_grad_loss=get_state_grad_loss()
220
- global global_step
221
-
222
-
223
-
224
- set_seed(training_args.seed)
225
- scaler = GradScaler(enabled=training_args.fp16)
226
- self.config.save_pretrained(training_args.output_dir)
227
- len_db=len(ctrain_datasets)
228
- self.full_generation_sample = full_generation_dataset[full_generation_sample_index]
229
-
230
- # init optimizer, lr_scheduler
231
- for disc in self.discriminator.discriminators:
232
- disc.apply_weight_norm()
233
- self.decoder.apply_weight_norm()
234
- # torch.nn.utils.weight_norm(self.decoder.conv_pre)
235
- # torch.nn.utils.weight_norm(self.decoder.conv_post)
236
- for flow in self.flow.flows:
237
- torch.nn.utils.weight_norm(flow.conv_pre)
238
- torch.nn.utils.weight_norm(flow.conv_post)
239
-
240
- discriminator=self.discriminator
241
- self.discriminator=None
242
-
243
- optimizer = torch.optim.AdamW(
244
- self.parameters(),
245
- training_args.learning_rate,
246
- betas=[training_args.adam_beta1, training_args.adam_beta2],
247
- eps=training_args.adam_epsilon,
248
- )
249
-
250
- # hack to be able to train on multiple device
251
-
252
-
253
- disc_optimizer = torch.optim.AdamW(
254
- discriminator.parameters(),
255
- training_args.d_learning_rate,
256
- betas=[training_args.d_adam_beta1, training_args.d_adam_beta2],
257
- eps=training_args.adam_epsilon,
258
- )
259
- lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(
260
- optimizer, gamma=training_args.lr_decay, last_epoch=-1
261
- )
262
- disc_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(
263
- disc_optimizer, gamma=training_args.lr_decay, last_epoch=-1)
264
-
265
-
266
- # logger.info("***** Running training *****")
267
- # logger.info(f" Num Epochs = {training_args.num_train_epochs}")
268
-
269
-
270
- #.......................loop training............................
271
-
272
-
273
-
274
- for epoch in range(training_args.num_train_epochs):
275
- train_losses_sum = 0
276
- loss_gen=0
277
- loss_des=0
278
- loss_durationsall=0
279
- loss_melall=0
280
- loss_klall=0
281
- loss_fmapsall=0
282
- lr_scheduler.step()
283
-
284
- disc_lr_scheduler.step()
285
- train_dataset,speaker_id=ctrain_datasets[epoch%len_db]
286
- print(f" Num Epochs = {int((epoch+n_epoch)/len_db)}, speaker_id DB ={speaker_id}")
287
- num_div_proc=int(len(train_dataset)/10)+1
288
- print(' -process traning : [',end='')
289
-
290
-
291
-
292
-
293
-
294
-
295
-
296
- for step, batch in enumerate(train_dataset):
297
- # if speaker_id==None:
298
- # if step<3 :continue
299
-
300
- # if step>200:break
301
-
302
-
303
- batch=covert_cuda_batch(batch)
304
- displayloss={}
305
-
306
- with autocast(enabled=training_args.fp16):
307
- speaker_embeddings=get_embed_speaker(self,batch["speaker_id"] if speaker_id ==None else speaker_id )
308
-
309
-
310
- waveform,ids_slice,log_duration,prior_latents,posterior_log_variances,prior_means,prior_log_variances,labels_padding_mask = self.forward_train(
311
- input_ids=batch["input_ids"],
312
- attention_mask=batch["attention_mask"],
313
- labels=batch["labels"],
314
- labels_attention_mask=batch["labels_attention_mask"],
315
- text_encoder_output =None ,
316
- posterior_encode_output=None ,
317
- return_dict=True,
318
- monotonic_alignment_function= maf,
319
- speaker_embeddings=speaker_embeddings
320
- )
321
 
322
- mel_scaled_labels = batch["mel_scaled_input_features"]
323
- mel_scaled_target = self.slice_segments(mel_scaled_labels, ids_slice,self.segment_size)
324
- mel_scaled_generation = feature_extractor._torch_extract_fbank_features(waveform.squeeze(1))[1]
325
 
326
- target_waveform = batch["waveform"].transpose(1, 2)
327
- target_waveform = self.slice_segments(
328
- target_waveform,
329
- ids_slice * feature_extractor.hop_length,
330
- self.config.segment_size
331
- )
332
-
333
- discriminator_target, fmaps_target = discriminator(target_waveform)
334
- discriminator_candidate, fmaps_candidate = discriminator(waveform.detach())
335
- with autocast(enabled=False):
336
- if dict_state_grad_loss['discriminator']:
337
 
338
 
339
  loss_disc, loss_real_disc, loss_fake_disc = discriminator_loss(
340
  discriminator_target, discriminator_candidate
341
  )
342
 
343
- dk={"step_loss_disc": loss_disc.detach().item(),
344
- "step_loss_real_disc": loss_real_disc.detach().item(),
345
- "step_loss_fake_disc": loss_fake_disc.detach().item()}
346
- displayloss['dict_loss_discriminator']=dk
347
  loss_dd = loss_disc# + loss_real_disc + loss_fake_disc
348
 
349
  # loss_dd.backward()
350
 
351
- disc_optimizer.zero_grad()
352
- scaler.scale(loss_dd).backward()
353
- scaler.unscale_(disc_optimizer )
354
- grad_norm_d = clip_grad_value_(discriminator.parameters(), None)
355
- scaler.step(disc_optimizer)
356
- loss_des+=grad_norm_d
357
-
358
 
359
- with autocast(enabled=training_args.fp16):
360
 
361
  # backpropagate
362
 
363
-
364
-
365
-
366
-
367
  discriminator_target, fmaps_target = discriminator(target_waveform)
368
 
369
  discriminator_candidate, fmaps_candidate = discriminator(waveform.detach())
@@ -377,30 +250,30 @@ def trainer_to_cuda(self,
377
  labels_padding_mask,
378
  )
379
  loss_kl=loss_kl*training_args.weight_kl
380
- loss_klall+=loss_kl.detach().item()
381
  #if displayloss['loss_kl']>=0:
382
  # loss_kl.backward()
383
 
384
  if dict_state_grad_loss['mel']:
385
  loss_mel = torch.nn.functional.l1_loss(mel_scaled_target, mel_scaled_generation)*training_args.weight_mel
386
- loss_melall+= loss_mel.detach().item()
387
  # train_losses_sum = train_losses_sum + displayloss['loss_mel']
388
  # if displayloss['loss_mel']>=0:
389
  # loss_mel.backward()
390
 
391
  if dict_state_grad_loss['duration']:
392
  loss_duration=torch.sum(log_duration)*training_args.weight_duration
393
- loss_durationsall+=loss_duration.detach().item()
394
  # if displayloss['loss_duration']>=0:
395
  # loss_duration.backward()
396
  if dict_state_grad_loss['generator']:
397
  loss_fmaps = feature_loss(fmaps_target, fmaps_candidate)
398
  loss_gen, losses_gen = generator_loss(discriminator_candidate)
399
  loss_gen=loss_gen * training_args.weight_gen
400
- displayloss['loss_gen'] = loss_gen.detach().item()
401
  # loss_gen.backward(retain_graph=True)
402
  loss_fmaps=loss_fmaps * training_args.weight_fmaps
403
- displayloss['loss_fmaps'] = loss_fmaps.detach().item()
404
  # loss_fmaps.backward(retain_graph=True)
405
  total_generator_loss = (
406
  loss_duration
@@ -410,111 +283,250 @@ def trainer_to_cuda(self,
410
  + loss_gen
411
  )
412
  # total_generator_loss.backward()
413
- optimizer.zero_grad()
414
- scaler.scale(total_generator_loss).backward()
415
- scaler.unscale_(optimizer)
416
- grad_norm_g = clip_grad_value_(self.parameters(), None)
417
- scaler.step(optimizer)
418
- scaler.update()
419
- loss_gen+=grad_norm_g
420
-
421

422

423

424

425

426
- # optimizer.step()
427

428

429

430

431
- # print(f"TRAINIG - batch {step}, waveform {(batch['waveform'].shape)}, lr {lr_scheduler.get_last_lr()[0]}... ")
432
- # print(f"display loss function enable :{displayloss}")
433
 
434
- global_step +=1
435
- if step%num_div_proc==0:
436
- print('==',end='')
437
 
438
- # validation
439
 
440
- do_eval = training_args.do_eval and (global_step % training_args.eval_steps == 0)
441
- if do_eval:
442
- speaker_id_c=int(torch.randint(start_speeker,end_speeker,size=(1,))[0])
443
- logger.info("Running validation... ")
444
- eval_losses_sum = 0
445
- cc=0;
446
- for step, batch in enumerate(eval_dataset):
447
- break
448
- if cc>2: break
449
- cc+=1
450
- with torch.no_grad():
451
- model_outputs = self.forward(
452
- input_ids=batch["input_ids"],
453
- attention_mask=batch["attention_mask"],
454
- labels=batch["labels"],
455
- labels_attention_mask=batch["labels_attention_mask"],
456
- speaker_id=batch["speaker_id"],
457
 
 
458
 
459
- return_dict=True,
460
 
461
- )
462
 
463
- mel_scaled_labels = batch["mel_scaled_input_features"]
464
- mel_scaled_target = self.slice_segments(mel_scaled_labels, model_outputs.ids_slice,self.segment_size)
465
- mel_scaled_generation = feature_extractor._torch_extract_fbank_features(model_outputs.waveform.squeeze(1))[1]
466
- loss = loss_mel.detach().item()
467
- eval_losses_sum +=loss
468
 
469
- loss_mel = torch.nn.functional.l1_loss(mel_scaled_target, mel_scaled_generation)
470
- print(f"VALIDATION - batch {step}, waveform {(batch['waveform'].shape)}, step_loss_mel {loss} ... ")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
471
 
472
 
 
473
 
474
- with torch.no_grad():
475
- full_generation_sample = self.full_generation_sample
476
- full_generation =self.forward(
477
- input_ids =full_generation_sample["input_ids"],
478
- attention_mask=full_generation_sample["attention_mask"],
479
- speaker_id=speaker_id_c
480
- )
481
 
482
- full_generation_waveform = full_generation.waveform.cpu().numpy()
483
 
484
- wandb.log({
485
- "eval_losses": eval_losses_sum,
486
- "full generations samples": [
487
- wandb.Audio(w.reshape(-1), caption=f"Full generation sample {epoch}", sample_rate=16000)
488
- for w in full_generation_waveform],})
489
- step+=1
490
- # wandb.log({"train_losses":loss_melall})
491
- wandb.log({"loss_gen":loss_gen/step})
492
- wandb.log({"loss_des":loss_des/step})
493
- wandb.log({"loss_duration":loss_durationsall/step})
494
- wandb.log({"loss_mel":loss_melall/step})
495
- wandb.log({f"loss_kl_db{speaker_id}":loss_klall/step})
496
- print(']',end='')
497

498

499

500

501
- # self.save_pretrained(path_save_model)
502

503

504
- self.discriminator=discriminator
505
- for disc in self.discriminator.discriminators:
506
- disc.remove_weight_norm()
507
- self.decoder.remove_weight_norm()
508
- # torch.nn.utils.remove_weight_norm(self.decoder.conv_pre)
509
- # torch.nn.utils.remove_weight_norm(self.decoder.conv_post)
510
- for flow in self.flow.flows:
511
- torch.nn.utils.remove_weight_norm(flow.conv_pre)
512
- torch.nn.utils.remove_weight_norm(flow.conv_post)
513
 
514
- self.save_pretrained(path_save_model)
515
 
516
- # logger.info("Running final full generations samples... ")
 
517
 
 
518
 
519
 
520
  # logger.info("***** Training / Inference Done *****")
@@ -560,85 +572,73 @@ train_dataset_dirs=[
560
 
561
 
562
  dir_model='wasmdashai/vits-ar-huba-fine'
563
 
564
-
565
-
566
- global_step=0
567
- wandb.login(key= "782b6a6e82bbb5a5348de0d3c7d40d1e76351e79")
568
-
569
-
570
- ctrain_datasets,eval_dataset,full_generation_dataset=get_data_loader(train_dataset_dirs = train_dataset_dirs,
571
- eval_dataset_dir = os.path.join(dataset_dir,'eval'),
572
- full_generation_dir = os.path.join(dataset_dir,'full_generation'),
573
- device="cuda")
574
-
575
- print('load Data')
576
- wandb.init(project= 'AZ')
577
-
578
- print('wandb')
579
- model=VitsModel.from_pretrained(dir_model,token=token).to("cuda")
580
- print('loadeed')
581
  @spaces.GPU
582
- def greet(text,id):
583
- global GK
584
- parser = HfArgumentParser((ModelArguments, DataTrainingArguments, VITSTrainingArguments))
585
- json_file = os.path.abspath('VitsModelSplit/finetune_config_ara.json')
586
- model_args, data_args, training_args = parser.parse_json_file(json_file = json_file)
587
- print('start')
588
- sgl=get_state_grad_loss(mel=True,
589
- # generator=False,
590
- # discriminator=False,
591
- duration=False)
592
-
593
- print(training_args)
594
- training_args.num_train_epochs=1000
595
- training_args.fp16=True
596
- training_args.eval_steps=300
597
- training_args.weight_kl=1
598
- training_args.d_learning_rate=2e-4
599
- training_args.learning_rate=2e-4
600
- training_args.weight_mel=45
601
- training_args.num_train_epochs=4
602
- training_args.eval_steps=1000
603
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
604
- print(device)
605
-
606
-
607
- for i in range(10000):
608
- # model.train(True)
609
- print(f'clcye epochs ={i}')
610
- yield f'clcye epochs ={i}'
611
-
612
- model=VitsModel.from_pretrained(dir_model,token=token).to("cuda")
613
 
614
- #dir_model_save=dir_model+'/vend'
615
-
616
-
617
- trainer_to_cuda(model,
618
- ctrain_datasets = ctrain_datasets,
619
- eval_dataset = eval_dataset,
620
- full_generation_dataset = ctrain_datasets[0][0],
621
- feature_extractor = VitsFeatureExtractor(),
622
- training_args = training_args,
623
- full_generation_sample_index= -1,
624
- project_name = "AZ",
625
- wandbKey = "782b6a6e82bbb5a5348de0d3c7d40d1e76351e79",
626
- is_used_text_encoder=True,
627
- is_used_posterior_encode=True,
628
- # dict_state_grad_loss=sgl,
629
- nk=50,
630
- path_save_model=dir_model,
631
- maf=monotonic_align.maximum_path,
632
 
633
- n_back_save_model=3000,
634
- start_speeker=0,
635
- end_speeker=1,
636
- n_epoch=i*training_args.num_train_epochs,
637

638
 
639
- )
640
 
 
641
 
642
- demo = gr.Interface(fn=greet, inputs=["text","text"], outputs="text")
643
- demo.launch()
644
 
 
38
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
39
 
40
 
 
 
41
  # sgl=get_state_grad_loss(k1=True,#generator=False,
42
  # discriminator=False,
43
  # duration=False
 
180
  device = device)
181
  return ctrain_datasets,eval_dataset,full_generation_dataset
182
  global_step=0
183

184

185
+ def train_step(batch,models=[],optimizers=[], training_args=None,tools=[]):
186
+ self,discriminator=models
187
+ optimizer,disc_optimizer,scaler=optimizers
188
+ feature_extractor,maf,dict_state_grad_loss=tools
189
+
190
+ with autocast(enabled=training_args.fp16):
191
+ speaker_embeddings=get_embed_speaker(model,batch["speaker_id"])
192
+ waveform,ids_slice,log_duration,prior_latents,posterior_log_variances,prior_means,prior_log_variances,labels_padding_mask = self.forward_train(
193
+ input_ids=batch["input_ids"],
194
+ attention_mask=batch["attention_mask"],
195
+ labels=batch["labels"],
196
+ labels_attention_mask=batch["labels_attention_mask"],
197
+ text_encoder_output =None ,
198
+ posterior_encode_output=None ,
199
+ return_dict=True,
200
+ monotonic_alignment_function=maf,
201
+ speaker_embeddings=speaker_embeddings
202
+
203
+ )
204
+ mel_scaled_labels = batch["mel_scaled_input_features"]
205
+ mel_scaled_target = self.slice_segments(mel_scaled_labels, ids_slice,self.segment_size)
206
+ mel_scaled_generation = feature_extractor._torch_extract_fbank_features(waveform.squeeze(1))[1]
207
+
208
+ target_waveform = batch["waveform"].transpose(1, 2)
209
+ target_waveform = self.slice_segments(
210
+ target_waveform,
211
+ ids_slice * feature_extractor.hop_length,
212
+ self.config.segment_size
213
+ )
214
+
215
+ discriminator_target, fmaps_target = discriminator(target_waveform)
216
+ discriminator_candidate, fmaps_candidate = discriminator(waveform.detach())
217
+ with autocast(enabled=False):
218
+ if dict_state_grad_loss['discriminator']:
219
 
220
 
221
  loss_disc, loss_real_disc, loss_fake_disc = discriminator_loss(
222
  discriminator_target, discriminator_candidate
223
  )
224
 
 
 
 
 
225
  loss_dd = loss_disc# + loss_real_disc + loss_fake_disc
226
 
227
  # loss_dd.backward()
228
 
229
+ disc_optimizer.zero_grad()
230
+ scaler.scale(loss_dd).backward()
231
+ scaler.unscale_(disc_optimizer )
232
+ grad_norm_d = clip_grad_value_(discriminator.parameters(), None)
233
+ scaler.step(disc_optimizer)
234
+ loss_des=grad_norm_d
 
235
 
236
+ with autocast(enabled=training_args.fp16):
237
 
238
  # backpropagate
239
 
 
 
 
 
240
  discriminator_target, fmaps_target = discriminator(target_waveform)
241
 
242
  discriminator_candidate, fmaps_candidate = discriminator(waveform.detach())
 
250
  labels_padding_mask,
251
  )
252
  loss_kl=loss_kl*training_args.weight_kl
253
+ loss_klall=loss_kl.detach().item()
254
  #if displayloss['loss_kl']>=0:
255
  # loss_kl.backward()
256
 
257
  if dict_state_grad_loss['mel']:
258
  loss_mel = torch.nn.functional.l1_loss(mel_scaled_target, mel_scaled_generation)*training_args.weight_mel
259
+ loss_melall= loss_mel.detach().item()
260
  # train_losses_sum = train_losses_sum + displayloss['loss_mel']
261
  # if displayloss['loss_mel']>=0:
262
  # loss_mel.backward()
263
 
264
  if dict_state_grad_loss['duration']:
265
  loss_duration=torch.sum(log_duration)*training_args.weight_duration
266
+ loss_durationsall=loss_duration.detach().item()
267
  # if displayloss['loss_duration']>=0:
268
  # loss_duration.backward()
269
  if dict_state_grad_loss['generator']:
270
  loss_fmaps = feature_loss(fmaps_target, fmaps_candidate)
271
  loss_gen, losses_gen = generator_loss(discriminator_candidate)
272
  loss_gen=loss_gen * training_args.weight_gen
273
+
274
  # loss_gen.backward(retain_graph=True)
275
  loss_fmaps=loss_fmaps * training_args.weight_fmaps
276
+
277
  # loss_fmaps.backward(retain_graph=True)
278
  total_generator_loss = (
279
  loss_duration
 
283
  + loss_gen
284
  )
285
  # total_generator_loss.backward()
286
+ optimizer.zero_grad()
287
+ scaler.scale(total_generator_loss).backward()
288
+ scaler.unscale_(optimizer)
289
+ grad_norm_g = clip_grad_value_(self.parameters(), None)
290
+ scaler.step(optimizer)
291
+ scaler.update()
292
+ loss_gen=grad_norm_g
293
+
294
+ return loss_gen,loss_des,loss_durationsall,loss_melall,loss_klall
295
+
296
+
297
+
298
+ def train_epoch(obtrainer,index_db=0,epoch=0,idspeakers=[],full_generation_sample_index=-1):
299
+ train_losses_sum = 0
300
+ loss_genall=0
301
+ loss_desall=0
302
+ loss_durationsall=0
303
+ loss_melall=0
304
+ loss_klall=0
305
+ loss_fmapsall=0
306
+ start_speeker,end_speeker=idspeakers
307
+
308
 
309
+ datatrain=obtrainer.DataSets['train'][index_db]
310
+ lr_scheduler,disc_lr_scheduler=obtrainer.lr_schedulers
311
+ lr_scheduler.step()
312
 
313
+ disc_lr_scheduler.step()
314
+ train_dataset,speaker_id=datatrain
315
+ print(f" Num Epochs = {epoch}, speaker_id DB ={speaker_id}")
316
+ num_div_proc=int(len(train_dataset)/10)+1
317
+ print(' -process traning : [',end='')
318
+ full_generation_sample =obtrainer.DataSets['full_generation'][full_generation_sample_index]
319
+
320
 
321
+
322
+ for step, batch in enumerate(train_dataset):
323
+ loss_gen,loss_des,loss_durationsa,loss_mela,loss_kl=train_step(batch,
324
+ models=obtrainer.models,
325
+ optimizers=obtrainer.optimizers,
326
+ training_args=obtrainer.training_args,
327
+ tools=obtrainer.tools)
328
+ loss_genall+=loss_gen
329
+ loss_desall+=loss_des
330
+ loss_durationsall+=loss_durationsa
331
+ loss_melall+=loss_mela
332
+ loss_klall+=loss_kl
333
+
334
+ obtrainer.global_step +=1
335
+ if step%num_div_proc==0:
336
+ print('==',end='')
337
 
338
+ # validation
339
 
340
+ do_eval = obtrainer.training_args.do_eval and (obtrainer.global_step % obtrainer.training_args.eval_steps == 0)
341
+
342
 
343
+ if do_eval:
344
+ speaker_id_c=int(torch.randint(start_speeker,end_speeker,size=(1,))[0])
345
+ model=obtrainer.model[0]
346
 
347
+ with torch.no_grad():
348
+
349
+ full_generation =model.forward(
350
+ input_ids =full_generation_sample["input_ids"],
351
+ attention_mask=full_generation_sample["attention_mask"],
352
+ speaker_id=speaker_id_c
353
+ )
354
 
355
+ full_generation_waveform = full_generation.waveform.cpu().numpy()
356
 
357
+ wandb.log({
358
+ "full generations samples": [
359
+ wandb.Audio(w.reshape(-1), caption=f"Full generation sample {epoch}", sample_rate=16000)
360
+ for w in full_generation_waveform],})
361
+ step+=1
362
+ # wandb.log({"train_losses":loss_melall})
363
+ wandb.log({"loss_gen":loss_genall/step})
364
+ wandb.log({"loss_des":loss_desall/step})
365
+ wandb.log({"loss_duration":loss_durationsall/step})
366
+ wandb.log({"loss_mel":loss_melall/step})
367
+ wandb.log({f"loss_kl_db{speaker_id}":loss_klall/step})
368
+ print(']',end='')
369
+
370
 
371
+
 
 
372
 
 
373

374
 
375
+
376
 
 
377
 
 
378
 
379
+ def load_training_args(path):
380
+ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, VITSTrainingArguments))
381
+ json_file = os.path.abspath(path)
382
+ model_args, data_args, training_args = parser.parse_json_file(json_file = json_file)
383
+ return training_args
384
+ def load_tools():
385
+ feature_extractor = VitsFeatureExtractor()
386
+ dict_state_grad_loss=get_state_grad_loss()
387
+ return feature_extractor,monotonic_align.maximum_path,dict_state_grad_loss
388
+
389
+
390
+ class TrinerModelVITS:
391
+ def __init__(self,dir_model="",
392
+ path_training_args="",
393
+ train_dataset_dirs=[],
394
+ eval_dataset_dir="",
395
+ full_generation_dir="",
396
+ token="",
397
+
398
+
399
+ device="cpu"):
400
+ self.device=device
401
+ self.dir_model=dir_model
402
+ self.path_training_args=path_training_args
403
+ self.stute_mode=False
404
+ self.token=token
405
+
406
+
407
+ self.epoch_count=0
408
+ self.global_step=0
409
+
410
+
411
 
412
+ def init_Starting(self):
413
+ self.training_args=load_training_args(self.path_training_args)
414
+ self.stute_mode=False
415
+
416
+ self.load_dataset(train_dataset_dirs,eval_dataset_dir,full_generation_dir)
417
+ self.len_dataset=len(self.DataSets['train'])
418
+ def init_training(self):
419
+
420
+ self.load_model()
421
+ self.initialize_training_components()
422
+ self.epoch_count=0
423
+
424
+
425
+ def load_model(self):
426
+ self.model=VitsModel.from_pretrained(self.dir_model,token=self.token).to(self.device)
427
+ self.model.setMfA(monotonic_align.maximum_path)
428
+
429
+ def init_wandb(self):
430
+ wandb.login(key= "782b6a6e82bbb5a5348de0d3c7d40d1e76351e79")
431
+ wandb.init(project= 'HugfaceTraining',config = self.training_args.to_dict())
432
+
433
+
434
+ def load_dataset(self,train_dataset_dirs,eval_dataset_dir,full_generation_dir):
435
+ ctrain_datasets,eval_dataset,full_generation_dataset=get_data_loader(train_dataset_dirs = train_dataset_dirs,
436
+ eval_dataset_dir = os.path.join(dataset_dir,'eval'),
437
+ full_generation_dir = os.path.join(dataset_dir,'full_generation'),
438
+ device=self.device)
439
+ self.DataSets={'train':ctrain_datasets,'eval':eval_dataset,'full_generation':full_generation_dataset}
440
 
441
 
442
+
443
 
 
 
 
 
 
 
 
444
 
 
445
 
446
+
447
 
448
+ def initialize_training_components(self):
449
 
450
+
451
+
452
+ self.training_args=training_args
453
 
454
+ set_seed(training_args.seed)
455
+ scaler = GradScaler(enabled=training_args.fp16)
456
+
457
 
458
+ # Initialize optimizer, lr_scheduler
459
+ for disc in self.model.discriminator.discriminators:
460
+ disc.apply_weight_norm()
461
+ self.model.decoder.apply_weight_norm()
462
+ # torch.nn.utils.weight_norm(self.decoder.conv_pre)
463
+ # torch.nn.utils.weight_norm(self.decoder.conv_post)
464
+ for flow in self.model.flow.flows:
465
+ torch.nn.utils.weight_norm(flow.conv_pre)
466
+ torch.nn.utils.weight_norm(flow.conv_post)
467
+
468
+ discriminator = self.model.discriminator
469
+ self.model.discriminator = None
470
+
471
+ optimizer = torch.optim.AdamW(
472
+ self.model.parameters(),
473
+ training_args.learning_rate,
474
+ betas=[training_args.adam_beta1, training_args.adam_beta2],
475
+ eps=training_args.adam_epsilon,
476
+ )
477
 
478
+ # Hack to be able to train on multiple device
479
+ disc_optimizer = torch.optim.AdamW(
480
+ discriminator.parameters(),
481
+ training_args.d_learning_rate,
482
+ betas=[training_args.d_adam_beta1, training_args.d_adam_beta2],
483
+ eps=training_args.adam_epsilon,
484
+ )
485
+ lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(
486
+ optimizer, gamma=training_args.lr_decay, last_epoch=-1
487
+ )
488
+ disc_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(
489
+ disc_optimizer, gamma=training_args.lr_decay, last_epoch=-1
490
+ )
491
+ self.models=(self.model,discriminator)
492
+ self.optimizers=(optimizer,disc_optimizer,scaler)
493
+ self.lr_schedulers=(lr_scheduler,disc_lr_scheduler)
494
+ self.tools=load_tools()
495
+ self.stute_mode=True
496
+
497
 
498
+
499
+ def save_pretrained(self,path_save_model):
500
+
501
+ model,discriminator=self.models
502
+
503
+ model.discriminator=discriminator
504
+ for disc in model.discriminator.discriminators:
505
+ disc.remove_weight_norm()
506
+ model.decoder.remove_weight_norm()
507
+ # torch.nn.utils.remove_weight_norm(self.decoder.conv_pre)
508
+ # torch.nn.utils.remove_weight_norm(self.decoder.conv_post)
509
+ for flow in model.flow.flows:
510
+ torch.nn.utils.remove_weight_norm(flow.conv_pre)
511
+ torch.nn.utils.remove_weight_norm(flow.conv_post)
512
+
513
+ self.input_save_pretrained(path_save_model,token=self.token)
514
+
515
+
516
+ def run_train_epoch(self):
517
+ index_db=self.epoch_count%self.len_dataset
518
+ train_epoch(self,index_db=index_db,epoch=self.epoch_count,idspeakers=(0,1),full_generation_sample_index=-1)
519
+ self.epoch_count+=1
520
+ return f'epoch_count:{self.epoch_count},global_step:{self.global_step},index_db"{index_db}'
521
+
522
+
523
+
524
 
 
525
 
526
+
527
+ # return (self.model,discriminator),(optimizer, disc_optimizer), (lr_scheduler, disc_lr_scheduler)
528
 
529
+
530
 
531
 
532
  # logger.info("***** Training / Inference Done *****")
 
572
 
573
 
574
  dir_model='wasmdashai/vits-ar-huba-fine'
575
+ pro=TrinerModelVITS(dir_model=dir_model,
576
+ path_training_args='VitsModelSplit/finetune_config_ara.json',
577
+ train_dataset_dirs = train_dataset_dirs,
578
+ eval_dataset_dir = os.path.join(dataset_dir,'eval'),
579
+ full_generation_dir = os.path.join(dataset_dir,'full_generation'),
580
+ device=device
581
+ )
582

583
  @spaces.GPU
584
+ def run_train_epoch(num):
585
+ for i in range(10):
586
+ # model.train(True)
587
+ yield pro.run_train_epoch()
588
+
589
+ @spaces.GPU
590
+ def init_training():
591
+ pro.init_training()
592
+ return pro.dir_model,'init_training'
593
 
594
+ @spaces.GPU
595
+ def init_Starting():
596
+ pro.init_Starting()
597
+ return 'init_Starting'
598
+ @spaces.GPU
599
+ def init_wandb():
600
+ pro.init_wandb()
601
+ return 'init_wandb'
602
 
603
+ @spaces.GPU
604
+ def save_pretrained(path):
605
+ pro.save_pretrained(path)
606
+ pro.init_training()
607
+ return 'save_pretrained'
608
+
609
+ with gr.Blocks() as interface:
610
+ with gr.Accordion("init_Starting ", open=False):
611
+ btn_init = gr.Button("init start")
612
+ output_init = gr.Textbox(label="init")
613
+ btn_init.click(fn=init_Starting,inputs=[],outputs=[output_init])
614
+ with gr.Accordion("init_wandb ", open=False):
615
+ btn_init_wandb = gr.Button("nit_wandb")
616
+ output_initbtn_init_wandb = gr.Textbox(label="init")
617
+ btn_init_wandb.click(fn=init_training,inputs=[],outputs=[output_initbtn_init_wandb])
618
+
619
+ with gr.Accordion("init_training ", open=False):
620
+ btn_init_train = gr.Button("init init_train")
621
+ output_btn_init_train = gr.Textbox(label="init")
622
+ # btn_init_train.click(fn=init_training,inputs=[],outputs=[output_btn_init_train])
623
 
624
+ with gr.Accordion("run_train_epoch ", open=False):
625
+ btn_run_train_epoch = gr.Button("run_train_epoch")
626
+ input_run_train_epoch = gr.Number(label="number _train_epoch")
627
+ output_run_train_epoch = gr.Textbox(label="run_train_epoch")
628
+ btn_run_train_epoch.click(fn=run_train_epoch,inputs=[input_run_train_epoch],outputs=[output_run_train_epoch])
629
+
630
+ with gr.Accordion("save_pretrained ", open=False):
631
+ btn_save_pretrained = gr.Button("save_pretrained")
632
+ input_save_pretrained = gr.Textbox(label="save_pretrained")
633
+ output_save_pretrained = gr.Textbox(label="save_pretrained")
634
+ btn_save_pretrained.click(fn=save_pretrained,inputs=[input_save_pretrained],outputs=[output_save_pretrained])
635
+
636
+ btn_init_train.click(fn=init_training,inputs=[],outputs=[input_save_pretrained,output_btn_init_train])
637
 
638
+
639
 
640
+
641
 
642
+ interface.launch()
643
+ print('loadeed')
644