wasmdashai committed
Commit 9f21f94 · verified · 1 Parent(s): 9cca908

Update app.py

Files changed (1):
  1. app.py +89 -939
app.py CHANGED
@@ -1,952 +1,102 @@
- import spaces
-
- import gradio as gr
- GK=0
- from transformers import AutoTokenizer
  import torch
  import os
- from VitsModelSplit.vits_model2 import VitsModel,get_state_grad_loss
- import VitsModelSplit.monotonic_align as monotonic_align
  token=os.environ.get("key_")
- # import VitsModelSplit.monotonic_align as monotonic_align
- from IPython.display import clear_output
- from transformers import set_seed
- import wandb
- import logging
- import copy
- import torch
-
- import numpy as np
- import torch
- from datasets import DatasetDict,Dataset
- import os
- from VitsModelSplit.vits_model2 import VitsModel,get_state_grad_loss
- #from VitsModelSplit.vits_model_only_d import Vits_models_only_decoder
- #from VitsModelSplit.vits_model import VitsModel
- from VitsModelSplit.PosteriorDecoderModel import PosteriorDecoderModel
- from VitsModelSplit.feature_extraction import VitsFeatureExtractor
- from transformers import AutoTokenizer, HfArgumentParser, set_seed
- from VitsModelSplit.Arguments import DataTrainingArguments, ModelArguments, VITSTrainingArguments
- from VitsModelSplit.dataset_features_collector import FeaturesCollectionDataset
- from torch.cuda.amp import autocast, GradScaler
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
- # sgl=get_state_grad_loss(k1=True,#generator=False,
- # discriminator=False,
- # duration=False
- # )
- # class model_onxx:
- # def __init__(self):
- # self.model=None
- # self.n_onxx=""
- # self.storage_dir = "uploads"
- # pass
-
- # def download_file(self,file_path):
- # ff= gr.File(value=file_path, visible=True)
- # file_url = ff.value['url']
- # return file_url
- # def function_change(self,n_model,token,n_onxx,choice):
- # if choice=="decoder":
- # V=self.convert_to_onnx_only_decoder(n_model,token,n_onxx)
- # elif choice=="all only decoder":
- # V=self.convert_to_onnx_only_decoder(n_model,token,n_onxx)
- # else:
- # V=self.convert_to_onnx_only_decoder(n_model,token,n_onxx)
- # return V
-
- # def install_model(self,n_model,token,n_onxx):
- # self.n_onxx=n_onxx
- # self.model= VitsModel.from_pretrained(n_model,token=token)
- # return self.model
- # def convert_model_decoder_onxx(self,n_model,token,namemodelonxx):
- # self.model= VitsModel.from_pretrained(n_model,token=token)
- # x=f"/tmp/{namemodelonxx}.onnx"
- # return x
- # def convert_to_onnx_only_decoder(self,n_model,token,namemodelonxx):
- # model=VitsModel.from_pretrained(n_model,token=token)
- # x=f"/tmp/{namemodelonxx}.onnx"
-
- # vocab_size = model.text_encoder.embed_tokens.weight.size(0)
- # example_input = torch.randint(0, vocab_size, (1, 100), dtype=torch.long)
- # torch.onnx.export(
- # model, # The model to be exported
- # example_input, # Example input for the model
- # x,# The filename for the exported ONNX model
- # opset_version=11, # Use an appropriate ONNX opset version
- # input_names=['input'], # Name of the input layer
- # output_names=['output'], # Name of the output layer
- # dynamic_axes={
- # 'input': {0: 'batch_size', 1: 'sequence_length'}, # Dynamic axes for variable-length inputs
- # 'output': {0: 'batch_size'}
- # }
- # )
- # return x
-
- # def convert_to_onnx_all(self,n_model,token ,namemodelonxx):
- # model=VitsModel.from_pretrained(n_model,token=token)
- # x=f"dowload_file/{namemodelonxx}.onnx"
-
- # vocab_size = model.text_encoder.embed_tokens.weight.size(0)
- # example_input = torch.randint(0, vocab_size, (1, 100), dtype=torch.long)
- # torch.onnx.export(
- # model, # The model to be exported
- # example_input, # Example input for the model
- # x, # The filename for the exported ONNX model
- # opset_version=11, # Use an appropriate ONNX opset version
- # input_names=['input'], # Name of the input layer
- # output_names=['output'], # Name of the output layer
- # dynamic_axes={
- # 'input': {0: 'batch_size', 1: 'sequence_length'}, # Dynamic axes for variable-length inputs
- # 'output': {0: 'batch_size'}
- # }
- # )
- # return x
- # def starrt(self):
- # #with gr.Blocks() as demo:
- # with gr.Row():
- # with gr.Column():
- # text_n_model=gr.Textbox(label="name model")
- # text_n_token=gr.Textbox(label="token")
- # text_n_onxx=gr.Textbox(label="name model onxx")
- # choice = gr.Dropdown(choices=["decoder", "all anoly decoder", "All"], label="My Dropdown")
-
- # with gr.Column():
- # btn=gr.Button("convert")
- # label=gr.Label("return name model onxx")
- # btn.click(self.function_change,[text_n_model,text_n_token,text_n_onxx,choice],[gr.File(label="Download File")])
- # #choice.change(fn=function_change, inputs=choice, outputs=label)
- # #return demo
- # c=model_onxx()
-
- #3333333333333333333333333333
- class OnnxModelConverter:
-     def __init__(self):
-         self.model = None
-     def download_file(self,file_path):
-         ff= gr.File(value=file_path, visible=True)
-         file_url = ff.value['url']
-         return file_url
-
-     def convert(self, model_name, token, onnx_filename, conversion_type):
-         """
-         Main function to handle different types of model conversions.
-
-         Args:
-             model_name (str): Name of the model to convert.
-             token (str): Access token for loading the model.
-             onnx_filename (str): Desired filename for the ONNX output.
-             conversion_type (str): Type of conversion ('decoder', 'only_decoder', or 'full_model').
-
-         Returns:
-             str: The path to the generated ONNX file.
-         """
-         if conversion_type == "decoder":
-             return self.convert_decoder(model_name, token, onnx_filename)
-         elif conversion_type == "only_decoder":
-             return self.convert_only_decoder(model_name, token, onnx_filename)
-         elif conversion_type == "full_model":
-             return self.convert_full_model(model_name, token, onnx_filename)
-         else:
-             raise ValueError("Invalid conversion type. Choose from 'decoder', 'only_decoder', or 'full_model'.")
-
-     def convert_decoder(self, model_name, token, onnx_filename):
-         """
-         Converts only the decoder part of the Vits model to ONNX format.
-
-         Args:
-             model_name (str): Name of the model to convert.
-             token (str): Access token for loading the model.
-             onnx_filename (str): Desired filename for the ONNX output.
-
-         Returns:
-             str: The path to the generated ONNX file.
-         """
-         model = VitsModel.from_pretrained(model_name, token=token)
-         onnx_file = f"/tmp/{onnx_filename}.onnx"
-         vocab_size = model.text_encoder.embed_tokens.weight.size(0)
-         example_input = torch.randint(0, vocab_size, (1, 100), dtype=torch.long)
-
-         torch.onnx.export(
-             model,
-             example_input,
-             onnx_file,
-             opset_version=11,
-             input_names=['input'],
-             output_names=['output'],
-             dynamic_axes={'input': {0: 'batch_size', 1: 'sequence_length'}, 'output': {0: 'batch_size'}}
-         )
-
-         return onnx_file
-
-     def convert_only_decoder(self, model_name, token, onnx_filename):
-         """
-         Converts only the decoder part of the Vits model to ONNX format.
-
-         Args:
-             model_name (str): Name of the model to convert.
-             token (str): Access token for loading the model.
-             onnx_filename (str): Desired filename for the ONNX output.
-
-         Returns:
-             str: The path to the generated ONNX file.
-         """
-         model = Vits_models_only_decoder.from_pretrained(model_name, token=token)
-         onnx_file = f"/tmp/{onnx_filename}.onnx"
-
-         vocab_size = model.text_encoder.embed_tokens.weight.size(0)
-         example_input = torch.randint(0, vocab_size, (1, 100), dtype=torch.long)
-
-         torch.onnx.export(
-             model,
-             example_input,
-             onnx_file,
-             opset_version=11,
-             input_names=['input'],
-             output_names=['output'],
-             dynamic_axes={'input': {0: 'batch_size', 1: 'sequence_length'}, 'output': {0: 'batch_size'}}
-         )
-
-         return onnx_file
-
-     def convert_full_model(self, model_name, token, onnx_filename):
-         """
-         Converts the full Vits model (including encoder and decoder) to ONNX format.
-
-         Args:
-             model_name (str): Name of the model to convert.
-             token (str): Access token for loading the model.
-             onnx_filename (str): Desired filename for the ONNX output.
-
-         Returns:
-             str: The path to the generated ONNX file.
-         """
-         model = VitsModel.from_pretrained(model_name, token=token)
-         onnx_file = f"/tmp/{onnx_filename}.onnx"
-
-         vocab_size = model.text_encoder.embed_tokens.weight.size(0)
-         example_input = torch.randint(0, vocab_size, (1, 100), dtype=torch.long)
-
-         torch.onnx.export(
-             model,
-             example_input,
-             onnx_file,
-             opset_version=11,
-             input_names=['input'],
-             output_names=['output'],
-             dynamic_axes={'input': {0: 'batch_size', 1: 'sequence_length'}, 'output': {0: 'batch_size'}}
-         )
-
-         return onnx_file
-     def starrt(self):
-         with gr.Blocks() as demo:
-             with gr.Row():
-                 with gr.Column():
-                     text_n_model=gr.Textbox(label="name model")
-                     text_n_token=gr.Textbox(label="token")
-                     text_n_onxx=gr.Textbox(label="name model onxx")
-                     choice = gr.Dropdown(choices=["decoder", "only_decoder", "full_model"], label="My Dropdown")
-
-                 with gr.Column():
-                     btn=gr.Button("convert")
-                     label=gr.Label("return name model onxx")
-                     btn.click(self.convert,[text_n_model,text_n_token,text_n_onxx,choice],[gr.File(label="Download File")])
-                     #choice.change(fn=function_change, inputs=choice, outputs=label)
-         return demo
- c=OnnxModelConverter()
- ###############################################################
- Lst=['input_ids',
-      'attention_mask',
-      'waveform',
-      'labels',
-      'labels_attention_mask',
-      'mel_scaled_input_features']
- def covert_cuda_batch(d):
-     return d
-     for key in Lst:
-         d[key]=d[key].cuda(non_blocking=True)
-     # for key in d['text_encoder_output']:
-     #     d['text_encoder_output'][key]=d['text_encoder_output'][key].cuda(non_blocking=True)
-     # for key in d['posterior_encode_output']:
-     #     d['posterior_encode_output'][key]=d['posterior_encode_output'][key].cuda(non_blocking=True)
-
-     return d
- def generator_loss(disc_outputs):
-     total_loss = 0
-     gen_losses = []
-     for disc_output in disc_outputs:
-         disc_output = disc_output
-         loss = torch.mean((1 - disc_output) ** 2)
-         gen_losses.append(loss)
-         total_loss += loss
-
-     return total_loss, gen_losses
-
- def discriminator_loss(disc_real_outputs, disc_generated_outputs):
-     loss = 0
-     real_losses = 0
-     generated_losses = 0
-     for disc_real, disc_generated in zip(disc_real_outputs, disc_generated_outputs):
-         real_loss = torch.mean((1 - disc_real) ** 2)
-         generated_loss = torch.mean(disc_generated**2)
-         loss += real_loss + generated_loss
-         real_losses += real_loss
-         generated_losses += generated_loss
-
-     return loss, real_losses, generated_losses
-
- def feature_loss(feature_maps_real, feature_maps_generated):
-     loss = 0
-     for feature_map_real, feature_map_generated in zip(feature_maps_real, feature_maps_generated):
-         for real, generated in zip(feature_map_real, feature_map_generated):
-             real = real.detach()
-             loss += torch.mean(torch.abs(real - generated))
-
-     return loss * 2
-
- def kl_loss(z_p, logs_q, m_p, logs_p, z_mask):
-     """
-     z_p, logs_q: [b, h, t_t]
-     m_p, logs_p: [b, h, t_t]
-     """
-     z_p = z_p.float()
-     logs_q = logs_q.float()
-     m_p = m_p.float()
-     logs_p = logs_p.float()
-     z_mask = z_mask.float()
-
-     kl = logs_p - logs_q - 0.5
-     kl += 0.5 * ((z_p - m_p)**2) * torch.exp(-2. * logs_p)
-     kl = torch.sum(kl * z_mask)
-     l = kl / torch.sum(z_mask)
-     return l
- #.............................................
- # def kl_loss(prior_latents, posterior_log_variance, prior_means, prior_log_variance, labels_mask):
-
- #     kl = prior_log_variance - posterior_log_variance - 0.5
- #     kl += 0.5 * ((prior_latents - prior_means) ** 2) * torch.exp(-2.0 * prior_log_variance)
- #     kl = torch.sum(kl * labels_mask)
- #     loss = kl / torch.sum(labels_mask)
- #     return loss
-
- def get_state_grad_loss(k1=True,
-                         mel=True,
-                         duration=True,
-                         generator=True,
-                         discriminator=True):
-     return {'k1':k1,'mel':mel,'duration':duration,'generator':generator,'discriminator':discriminator}
-
- @spaces.GPU
- def clip_grad_value_(parameters, clip_value, norm_type=2):
-     if isinstance(parameters, torch.Tensor):
-         parameters = [parameters]
-     parameters = list(filter(lambda p: p.grad is not None, parameters))
-     norm_type = float(norm_type)
-     if clip_value is not None:
-         clip_value = float(clip_value)
-
-     total_norm = 0
-     for p in parameters:
-         param_norm = p.grad.data.norm(norm_type)
-         total_norm += param_norm.item() ** norm_type
-         if clip_value is not None:
-             p.grad.data.clamp_(min=-clip_value, max=clip_value)
-     total_norm = total_norm ** (1. / norm_type)
-     return total_norm
-
- @spaces.GPU
- def get_embed_speaker(self,speaker_id):
-     if self.config.num_speakers > 1 and speaker_id is not None:
-         if isinstance(speaker_id, int):
-             speaker_id = torch.full(size=(1,), fill_value=speaker_id, device=self.device)
-         elif isinstance(speaker_id, (list, tuple, np.ndarray)):
-             speaker_id = torch.tensor(speaker_id, device=self.device)
-
-         if not ((0 <= speaker_id).all() and (speaker_id < self.config.num_speakers).all()).item():
-             raise ValueError(f"Set `speaker_id` in the range 0-{self.config.num_speakers - 1}.")
-
-         return self.embed_speaker(speaker_id).unsqueeze(-1)
-     else:
-         return None
-
- def get_data_loader(train_dataset_dirs,eval_dataset_dir,full_generation_dir,device):
-     ctrain_datasets=[]
-     for dataset_dir ,id_sp in train_dataset_dirs:
-         train_dataset = FeaturesCollectionDataset(dataset_dir = os.path.join(dataset_dir,'train'),
-                                                   device = device
-                                                   )
-         ctrain_datasets.append((train_dataset,id_sp))
-
-     eval_dataset = None
-
-     eval_dataset = FeaturesCollectionDataset(dataset_dir = eval_dataset_dir,
-                                              device = device
-                                              )
-
-     full_generation_dataset = FeaturesCollectionDataset(dataset_dir = full_generation_dir,
-                                                         device = device)
-     return ctrain_datasets,eval_dataset,full_generation_dataset
- global_step=0
-
- def train_step(batch,models=[],optimizers=[], training_args=None,tools=[]):
-     self,discriminator=models
-     optimizer,disc_optimizer,scaler=optimizers
-     feature_extractor,maf,dict_state_grad_loss=tools
-
-     with autocast(enabled=training_args.fp16):
-         speaker_embeddings=get_embed_speaker(self,batch["speaker_id"])
-         waveform,ids_slice,log_duration,prior_latents,posterior_log_variances,prior_means,prior_log_variances,labels_padding_mask = self.forward_train(
-             input_ids=batch["input_ids"],
-             attention_mask=batch["attention_mask"],
-             labels=batch["labels"],
-             labels_attention_mask=batch["labels_attention_mask"],
-             text_encoder_output =None ,
-             posterior_encode_output=None ,
-             return_dict=True,
-             monotonic_alignment_function=maf,
-             speaker_embeddings=speaker_embeddings
-         )
-         mel_scaled_labels = batch["mel_scaled_input_features"]
-         mel_scaled_target = self.slice_segments(mel_scaled_labels, ids_slice,self.segment_size)
-         mel_scaled_generation = feature_extractor._torch_extract_fbank_features(waveform.squeeze(1))[1]
-
-         target_waveform = batch["waveform"].transpose(1, 2)
-         target_waveform = self.slice_segments(
-             target_waveform,
-             ids_slice * feature_extractor.hop_length,
-             self.config.segment_size
-         )
-
-         discriminator_target, fmaps_target = discriminator(target_waveform)
-         discriminator_candidate, fmaps_candidate = discriminator(waveform.detach())
-     with autocast(enabled=False):
-         if dict_state_grad_loss['discriminator']:
-             loss_disc, loss_real_disc, loss_fake_disc = discriminator_loss(
-                 discriminator_target, discriminator_candidate
-             )
-
-             loss_dd = loss_disc# + loss_real_disc + loss_fake_disc
-             # loss_dd.backward()
-
-             disc_optimizer.zero_grad()
-             scaler.scale(loss_dd).backward()
-             scaler.unscale_(disc_optimizer )
-             grad_norm_d = clip_grad_value_(discriminator.parameters(), None)
-             scaler.step(disc_optimizer)
-             loss_des=grad_norm_d
-
-     with autocast(enabled=training_args.fp16):
-         # backpropagate
-         discriminator_target, fmaps_target = discriminator(target_waveform)
-         discriminator_candidate, fmaps_candidate = discriminator(waveform.detach())
-     with autocast(enabled=False):
-         if dict_state_grad_loss['k1']:
-             loss_kl = kl_loss(
-                 prior_latents,
-                 posterior_log_variances,
-                 prior_means,
-                 prior_log_variances,
-                 labels_padding_mask,
-             )
-             loss_kl=loss_kl*training_args.weight_kl
-             loss_klall=loss_kl.detach().item()
-             #if displayloss['loss_kl']>=0:
-             #    loss_kl.backward()
-
-         if dict_state_grad_loss['mel']:
-             loss_mel = torch.nn.functional.l1_loss(mel_scaled_target, mel_scaled_generation)*training_args.weight_mel
-             loss_melall= loss_mel.detach().item()
-             # train_losses_sum = train_losses_sum + displayloss['loss_mel']
-             # if displayloss['loss_mel']>=0:
-             #     loss_mel.backward()
-
-         if dict_state_grad_loss['duration']:
-             loss_duration=torch.sum(log_duration)*training_args.weight_duration
-             loss_durationsall=loss_duration.detach().item()
-             # if displayloss['loss_duration']>=0:
-             #     loss_duration.backward()
-         if dict_state_grad_loss['generator']:
-             loss_fmaps = feature_loss(fmaps_target, fmaps_candidate)
-             loss_gen, losses_gen = generator_loss(discriminator_candidate)
-             loss_gen=loss_gen * training_args.weight_gen
-             # loss_gen.backward(retain_graph=True)
-             loss_fmaps=loss_fmaps * training_args.weight_fmaps
-             # loss_fmaps.backward(retain_graph=True)
-             total_generator_loss = (
-                 loss_duration
-                 + loss_mel
-                 + loss_kl
-                 + loss_fmaps
-                 + loss_gen
-             )
-             # total_generator_loss.backward()
-             optimizer.zero_grad()
-             scaler.scale(total_generator_loss).backward()
-             scaler.unscale_(optimizer)
-             grad_norm_g = clip_grad_value_(self.parameters(), None)
-             scaler.step(optimizer)
-             scaler.update()
-             loss_gen=grad_norm_g
-
-     return loss_gen,loss_des,loss_durationsall,loss_melall,loss_klall
-
- def train_epoch(obtrainer,index_db=0,epoch=0,idspeakers=[],full_generation_sample_index=-1):
-     train_losses_sum = 0
-     loss_genall=0
-     loss_desall=0
-     loss_durationsall=0
-     loss_melall=0
-     loss_klall=0
-     loss_fmapsall=0
-     start_speeker,end_speeker=idspeakers
-
-     datatrain=obtrainer.DataSets['train'][index_db]
-     lr_scheduler,disc_lr_scheduler=obtrainer.lr_schedulers
-     lr_scheduler.step()
-     disc_lr_scheduler.step()
-     train_dataset,speaker_id=datatrain
-     print(f" Num Epochs = {epoch}, speaker_id DB ={speaker_id}")
-     num_div_proc=int(len(train_dataset)/10)+1
-     print(' -process traning : [',end='')
-     full_generation_sample =obtrainer.DataSets['full_generation'][full_generation_sample_index]
-
-     for step, batch in enumerate(train_dataset):
-         loss_gen,loss_des,loss_durationsa,loss_mela,loss_kl=train_step(batch,
-                                                                        models=obtrainer.models,
-                                                                        optimizers=obtrainer.optimizers,
-                                                                        training_args=obtrainer.training_args,
-                                                                        tools=obtrainer.tools)
-         loss_genall+=loss_gen
-         loss_desall+=loss_des
-         loss_durationsall+=loss_durationsa
-         loss_melall+=loss_mela
-         loss_klall+=loss_kl
-
-         obtrainer.global_step +=1
-         if step%num_div_proc==0:
-             print('==',end='')
-
-         # validation
-         do_eval = obtrainer.training_args.do_eval and (obtrainer.global_step % obtrainer.training_args.eval_steps == 0)
-
-         if do_eval:
-             speaker_id_c=int(torch.randint(start_speeker,end_speeker,size=(1,))[0])
-             model=obtrainer.models[0]
-
-             with torch.no_grad():
-                 full_generation =model.forward(
-                     input_ids =full_generation_sample["input_ids"],
-                     attention_mask=full_generation_sample["attention_mask"],
-                     speaker_id=speaker_id_c
-                 )
-
-             full_generation_waveform = full_generation.waveform.cpu().numpy()
-
-             wandb.log({
-                 "full generations samples": [
-                     wandb.Audio(w.reshape(-1), caption=f"Full generation sample {epoch}", sample_rate=16000)
-                     for w in full_generation_waveform],})
-     step+=1
-     # wandb.log({"train_losses":loss_melall})
-     wandb.log({"loss_gen":loss_genall/step})
-     wandb.log({"loss_des":loss_desall/step})
-     wandb.log({"loss_duration":loss_durationsall/step})
-     wandb.log({"loss_mel":loss_melall/step})
-     wandb.log({f"loss_kl_db{speaker_id}":loss_klall/step})
-     print(']',end='')
-
- def load_training_args(path):
-     parser = HfArgumentParser((ModelArguments, DataTrainingArguments, VITSTrainingArguments))
-     json_file = os.path.abspath(path)
-     model_args, data_args, training_args = parser.parse_json_file(json_file = json_file)
-     return training_args
- def load_tools():
-     feature_extractor = VitsFeatureExtractor()
-     dict_state_grad_loss=get_state_grad_loss()
-     return feature_extractor,monotonic_align.maximum_path,dict_state_grad_loss
-
- class TrinerModelVITS:
-     KC=0
-     def __init__(self,dir_model="",
-                  path_training_args="",
-                  train_dataset_dirs=[],
-                  eval_dataset_dir="",
-                  full_generation_dir="",
-                  token="",
-                  device="cpu"):
-         self.device=device
-         self.dir_model=dir_model
-         self.path_training_args=path_training_args
-         self.stute_mode=False
-         self.token=token
-
-         self.load_dataset(train_dataset_dirs,eval_dataset_dir,full_generation_dir)
-         self.epoch_count=0
-         self.global_step=0
-         self.len_dataset=len(self.DataSets['train'])
-         #self.load_model()
-         #self.init_wandb()
-         # self.training_args=load_training_args(self.path_training_args)
-         # training_args= self.training_args
-         scaler = GradScaler(enabled=True)
-         # for disc in self.model.discriminator.discriminators:
-         #     disc.apply_weight_norm()
-         # self.model.decoder.apply_weight_norm()
-         # # torch.nn.utils.weight_norm(self.decoder.conv_pre)
-         # # torch.nn.utils.weight_norm(self.decoder.conv_post)
-         # for flow in self.model.flow.flows:
-         #     torch.nn.utils.weight_norm(flow.conv_pre)
-         #     torch.nn.utils.weight_norm(flow.conv_post)
-
-         discriminator = self.model.discriminator
-         self.model.discriminator = None
-         self.models=(self.model,discriminator)
-
-         optimizer = torch.optim.AdamW(
-             self.model.parameters(),
-             2e-4,
-             betas=[0.8, 0.99],
-             # eps=training_args.adam_epsilon,
-         )
-
-         # Hack to be able to train on multiple device
-         disc_optimizer = torch.optim.AdamW(
-             discriminator.parameters(),
-             2e-4,
-             betas=[0.8, 0.99],
-             # eps=training_args.adam_epsilon,
-         )
-         lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(
-             optimizer,gamma=0.999875, last_epoch=-1
-         )
-         disc_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(
-             disc_optimizer, gamma=0.999875,last_epoch=-1
-         )
-         # self.models=(self.model,discriminator)
-         self.optimizers=(optimizer,disc_optimizer,scaler)
-         self.lr_schedulers=(lr_scheduler,disc_lr_scheduler)
-         self.tools=load_tools()
-         self.stute_mode=True
-         print(self.lr_schedulers)
-
-     def init_Starting(self):
-         print('init_Starting')
-         #self.training_args=load_training_args(self.path_training_args)
-         #self.stute_mode=False
-         print('end training_args')
-
-     def init_training(self):
-         self.initialize_training_components()
-         # self.epoch_count=0
-
-     def load_model(self):
-         self.model=VitsModel.from_pretrained(self.dir_model,token=self.token).to(self.device)
-         self.model.setMfA(monotonic_align.maximum_path)
-
-     def init_wandb(self):
-         wandb.login(key= "782b6a6e82bbb5a5348de0d3c7d40d1e76351e79")
-         #config = self.training_args.to_dict()
-         wandb.init(project= 'HugfaceTraining')
-
-     def load_modell(self,namemodel):
-         self.model=VitsModel.from_pretrained(namemodel,token=self.token).to(self.device)
-         return "true"
-     def load_dataset(self,train_dataset_dirs,eval_dataset_dir,full_generation_dir):
-         ctrain_datasets,eval_dataset,full_generation_dataset=get_data_loader(train_dataset_dirs = train_dataset_dirs,
-                                                                              eval_dataset_dir =eval_dataset_dir ,
-                                                                              full_generation_dir =full_generation_dir ,
-                                                                              device=self.device)
-         self.DataSets={'train':ctrain_datasets,'eval':eval_dataset,'full_generation':full_generation_dataset}
-
-     def initialize_training_components(self):
-         self.training_args=load_training_args(self.path_training_args)
-         training_args= self.training_args
-         training_args.weight_kl=1
-         training_args.d_learning_rate=2e-4
-         training_args.learning_rate=2e-4
-         training_args.weight_mel=45
-         training_args.num_train_epochs=4
-         training_args.eval_steps=1000
-         training_args.fp16=True
-
-         set_seed(training_args.seed)
-         # scaler = GradScaler(enabled=training_args.fp16)
-
-         # # Initialize optimizer, lr_scheduler
-         # for disc in self.model.discriminator.discriminators:
-         #     disc.apply_weight_norm()
-         # self.model.decoder.apply_weight_norm()
-         # # torch.nn.utils.weight_norm(self.decoder.conv_pre)
-         # # torch.nn.utils.weight_norm(self.decoder.conv_post)
-         # for flow in self.model.flow.flows:
-         #     torch.nn.utils.weight_norm(flow.conv_pre)
-         #     torch.nn.utils.weight_norm(flow.conv_post)
-
-         # discriminator = self.model.discriminator
-         # self.model.discriminator = None
-         # model,discriminator=self.models
-
-         # optimizer = torch.optim.AdamW(
-         #     model.parameters(),
-         #     training_args.learning_rate,
-         #     betas=[training_args.adam_beta1, training_args.adam_beta2],
-         #     eps=training_args.adam_epsilon,
-         # )
-
-         # # Hack to be able to train on multiple device
-         # disc_optimizer = torch.optim.AdamW(
-         #     discriminator.parameters(),
-         #     training_args.d_learning_rate,
-         #     betas=[training_args.d_adam_beta1, training_args.d_adam_beta2],
-         #     eps=training_args.adam_epsilon,
-         # )
-         # lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(
-         #     optimizer, gamma=training_args.lr_decay, last_epoch=-1
-         # )
-         # disc_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(
-         #     disc_optimizer, gamma=training_args.lr_decay, last_epoch=-1
-         # )
-         # # self.models=(self.model,discriminator)
-         # self.optimizers=(optimizer,disc_optimizer,scaler)
-         # self.lr_schedulers=(lr_scheduler,disc_lr_scheduler)
-         # self.tools=load_tools()
-         # self.stute_mode=True
-         # print(self.lr_schedulers)
-
-     def save_pretrained(self,path_save_model):
-         model,discriminator=self.models
-
-         model.discriminator=discriminator
-         for disc in model.discriminator.discriminators:
-             disc.remove_weight_norm()
-         model.decoder.remove_weight_norm()
-         # torch.nn.utils.remove_weight_norm(self.decoder.conv_pre)
-         # torch.nn.utils.remove_weight_norm(self.decoder.conv_post)
-         for flow in model.flow.flows:
-             torch.nn.utils.remove_weight_norm(flow.conv_pre)
-             torch.nn.utils.remove_weight_norm(flow.conv_post)
-
-         model.push_to_hub(path_save_model,token=self.token)
-
-     def run_train_epoch(self):
-         index_db=self.epoch_count%self.len_dataset
-         train_epoch(self,index_db=index_db,epoch=self.epoch_count,idspeakers=(0,1),full_generation_sample_index=-1)
-         self.epoch_count+=1
-         return f'epoch_count:{self.epoch_count},global_step:{self.global_step},index_db"{index_db}'
-
-     # return (self.model,discriminator),(optimizer, disc_optimizer), (lr_scheduler, disc_lr_scheduler)
-
-     # logger.info("***** Training / Inference Done *****")
- def modelspeech(texts):
-     inputs = tokenizer(texts, return_tensors="pt")#.cuda()
-
-     wav = model_vits(input_ids=inputs["input_ids"]).waveform#.detach()
-     # display(Audio(wav, rate=model.config.sampling_rate))
-     return model_vits.config.sampling_rate,wav#remove_noise_nr(wav)
-
- dataset_dir='ABThag-db'
- train_dataset_dirs=[
-     # ('/content/drive/MyDrive/vitsM/DATA/fahd_db',0),
-     # ('/content/drive/MyDrive/vitsM/DATA/fahd_db',0),
-     # ('/content/drive/MyDrive/vitsM/DB2KKKK',1),
-     # ('/content/drive/MyDrive/vitsM/DATA/Db_Amgd_50_Bitch10',0),
-     # ('/content/drive/MyDrive/vitsM/DB2KKKK',1), #
-     # ('/content/drive/MyDrive/vitsM/DATA/Db_Amgd_50_Bitch10',0),
-     # ('/content/drive/MyDrive/vitsM/DATA/DBWfaa-Bitch:8-Count:60',0),
-     # ('/content/drive/MyDrive/vitsM/DATA/Wafa/b10r',0),
-     # ('/content/drive/MyDrive/vitsM/DATA/Wafa/b16r',0),
-     # ('/content/drive/MyDrive/vitsM/DATA/Wafa/b4',0),
-
-     # ('/content/drive/MyDrive/vitsM/DATA/fahd_db',None),
-     # ('/content/drive/MyDrive/vitsM/DATA/wafa-db',None),
-     # ('/content/drive/MyDrive/vitsM/DATA/wafa-db',4),
-     # ('/content/drive/MyDrive/vitsM/DATA/DB-ABThag-Bitch:5-Count-37',4),
-     # ('/content/drive/MyDrive/vitsM/DB-300-k',6),
-     ('databatchs',0),
-     #('/content/drive/MyDrive/dataset_ljBatchs',0),
- ]
-
- dir_model='wasmdashai/vits-ar-huba-fine'
- pro=TrinerModelVITS(dir_model=dir_model,
-                     path_training_args='VitsModelSplit/finetune_config_ara.json',
-                     train_dataset_dirs = train_dataset_dirs,
-                     eval_dataset_dir = os.path.join(dataset_dir,'eval'),
-                     full_generation_dir = os.path.join(dataset_dir,'full_generation'),
-                     token=token,
-                     device=device
-                     )
- def loadd_d():
-     token=os.environ.get("key_")
-     #model=VitsModel.from_pretrained(n_model,token=token)
-     return token
- @spaces.GPU(duration=30)
- def run_train_epoch(num):
-     TrinerModelVITS.KC+=1
-     if num >0:
-         pro.init_training()
-         for i in range(num):
-             # model.train(True)
-             return pro.run_train_epoch() +f'- kc={TrinerModelVITS.KC}'
-     else:
-         pro.save_pretrained(pro.dir_model)
-         pro.load_model()
-         return 'save model '
-
- @spaces.GPU
- def init_training():
-     pro.init_training()
-     return pro.dir_model,'init_training'
-
- @spaces.GPU
- def init_Starting():
-     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-     return 'init_Starting'
- @spaces.GPU
- def init_wandb():
-     pro.init_wandb()
-     return 'init_wandb'
-
- def save_pretrained(path):
-     pro.save_pretrained(path)
-     pro.load_model()
-     return 'save_pretrained'
- def read_modell(n_model):
-     #model22=Vits_models_only_decoder.from_pretrained(n_model,token)#.to("cuda")
-     return token
- with gr.Blocks() as interface:
-     with gr.Accordion("get token", open=False):
-         btn_init = gr.Button("run")
-         label=gr.Label("hhh")
-         btn_init.click(loadd_d,[],[label])
-     with gr.Accordion("onxx ", open=False):
-         c.starrt()
-     with gr.Accordion("init_Starting ", open=False):
-         btn_init = gr.Button("init start")
-         output_init = gr.Textbox(label="init")
-         btn_init.click(fn=init_Starting,inputs=[],outputs=[output_init])
-     with gr.Accordion("init_wandb ", open=False):
-         btn_init_wandb = gr.Button("nit_wandb")
-         output_initbtn_init_wandb = gr.Textbox(label="init")
-         btn_init_wandb.click(fn=init_wandb,inputs=[],outputs=[output_initbtn_init_wandb])
-
-     with gr.Accordion("init_training ", open=False):
-         btn_init_train = gr.Button("init init_train")
-         output_btn_init_train = gr.Textbox(label="init")
-         # btn_init_train.click(fn=init_training,inputs=[],outputs=[output_btn_init_train])
-
-     with gr.Accordion("run_train_epoch ", open=False):
-         btn_run_train_epoch = gr.Button("run_train_epoch")
-         input_run_train_epoch = gr.Number(label="number _train_epoch")
-         output_run_train_epoch = gr.Textbox(label="run_train_epoch")
-         btn_run_train_epoch.click(fn=run_train_epoch,inputs=[input_run_train_epoch],outputs=[output_run_train_epoch])
-
-     with gr.Accordion("save_pretrained ", open=False):
-         btn_save_pretrained = gr.Button("save_pretrained")
-         input_save_pretrained = gr.Textbox(label="save_pretrained")
-         output_save_pretrained = gr.Textbox(label="save_pretrained")
-         btn_save_pretrained.click(fn=save_pretrained,inputs=[input_save_pretrained],outputs=[output_save_pretrained])
-
-     btn_init_train.click(fn=init_training,inputs=[],outputs=[input_save_pretrained,output_btn_init_train])
-
- interface.launch()
- print('loadeed')
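All three conversion paths in the removed OnnxModelConverter come down to the same torch.onnx.export call with dynamic batch and sequence axes; only the checkpoint being exported differs. A minimal, self-contained sketch of that pattern, using a stand-in embedding module instead of the VITS checkpoint (the TinyTextModel below is an assumption for illustration only), looks like this:

import torch

class TinyTextModel(torch.nn.Module):
    # Hypothetical stand-in for the VITS model; only the export call mirrors the removed code.
    def __init__(self, vocab_size=64, hidden=16):
        super().__init__()
        self.embed_tokens = torch.nn.Embedding(vocab_size, hidden)
        self.proj = torch.nn.Linear(hidden, 1)

    def forward(self, input_ids):
        return self.proj(self.embed_tokens(input_ids)).squeeze(-1)

model = TinyTextModel()
vocab_size = model.embed_tokens.weight.size(0)
example_input = torch.randint(0, vocab_size, (1, 100), dtype=torch.long)

torch.onnx.export(
    model,
    example_input,
    "/tmp/example.onnx",
    opset_version=11,
    input_names=["input"],
    output_names=["output"],
    # batch size and sequence length stay dynamic, as in the removed convert_* methods
    dynamic_axes={"input": {0: "batch_size", 1: "sequence_length"},
                  "output": {0: "batch_size"}},
)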
 
+ from transformers import MllamaForConditionalGeneration, AutoProcessor, TextIteratorStreamer
+ from PIL import Image
+ import requests
  import torch
+ from threading import Thread
+ import gradio as gr
+ from gradio import FileData
+ import time
+ import spaces
  import os
  token=os.environ.get("key_")
+ ckpt = "meta-llama/Llama-3.2-11B-Vision-Instruct"
+ model = MllamaForConditionalGeneration.from_pretrained(ckpt,token=token,
+                                                        torch_dtype=torch.bfloat16).to("cuda")
+ processor = AutoProcessor.from_pretrained(ckpt,token=token)
+ #@spaces.GPU
+ def bot_streaming(message, history, max_new_tokens=250):
+     txt = message["text"]
+     ext_buffer = f"{txt}"
+
+     messages= []
+     images = []
+
+     for i, msg in enumerate(history):
+         if isinstance(msg[0], tuple):
+             messages.append({"role": "user", "content": [{"type": "text", "text": history[i+1][0]}, {"type": "image"}]})
+             messages.append({"role": "assistant", "content": [{"type": "text", "text": history[i+1][1]}]})
+             images.append(Image.open(msg[0][0]).convert("RGB"))
+         elif isinstance(history[i-1], tuple) and isinstance(msg[0], str):
+             # messages are already handled
+             pass
+         elif isinstance(history[i-1][0], str) and isinstance(msg[0], str): # text only turn
+             messages.append({"role": "user", "content": [{"type": "text", "text": msg[0]}]})
+             messages.append({"role": "assistant", "content": [{"type": "text", "text": msg[1]}]})
+
+     # add current message
+     if len(message["files"]) == 1:
+         if isinstance(message["files"][0], str): # examples
+             image = Image.open(message["files"][0]).convert("RGB")
+         else: # regular input
+             image = Image.open(message["files"][0]["path"]).convert("RGB")
+         images.append(image)
+         messages.append({"role": "user", "content": [{"type": "text", "text": txt}, {"type": "image"}]})
+     else:
+         messages.append({"role": "user", "content": [{"type": "text", "text": txt}]})
+
+     texts = processor.apply_chat_template(messages, add_generation_prompt=True)
+
+     if images == []:
+         inputs = processor(text=texts, return_tensors="pt").to("cuda")
+     else:
+         inputs = processor(text=texts, images=images, return_tensors="pt").to("cuda")
+     streamer = TextIteratorStreamer(processor, skip_special_tokens=True, skip_prompt=True)
+
+     generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=max_new_tokens)
+     generated_text = ""
+
+     thread = Thread(target=model.generate, kwargs=generation_kwargs)
+     thread.start()
+     buffer = ""
+
+     for new_text in streamer:
+         buffer += new_text
+         generated_text_without_prompt = buffer
+         time.sleep(0.01)
+         yield buffer
+
+
+ demo = gr.ChatInterface(fn=bot_streaming, title="Multimodal Llama", examples=[
+     [{"text": "Which era does this piece belong to? Give details about the era.", "files":["./examples/rococo.jpg"]},
+      200],
+     [{"text": "Where do the droughts happen according to this diagram?", "files":["./examples/weather_events.png"]},
+      250],
+     [{"text": "What happens when you take out white cat from this chain?", "files":["./examples/ai2d_test.jpg"]},
+      250],
+     [{"text": "How long does it take from invoice date to due date? Be short and concise.", "files":["./examples/invoice.png"]},
+      250],
+     [{"text": "Where to find this monument? Can you give me other recommendations around the area?", "files":["./examples/wat_arun.jpg"]},
+      250],
+     ],
+     textbox=gr.MultimodalTextbox(),
+     additional_inputs = [gr.Slider(
+         minimum=10,
+         maximum=500,
+         value=250,
+         step=10,
+         label="Maximum number of new tokens to generate",
+     )
+     ],
+     cache_examples=False,
+     description="Try Multimodal Llama by Meta with transformers in this demo. Upload an image, and start chatting about it, or simply try one of the examples below. To learn more about Llama Vision, visit [our blog post](https://huggingface.co/blog/llama32). ",
+     stop_btn="Stop Generation",
+     fill_height=True,
+     multimodal=True)
+
+ demo.launch(debug=True)
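The new app drives all generation through bot_streaming and gr.ChatInterface, but the same model and processor can be exercised without Gradio. The sketch below is a minimal smoke test under that assumption: generate_once is a hypothetical helper, not part of the commit, and it reuses one of the example image paths referenced in the demo's examples list (any local RGB image works the same way).

from PIL import Image

def generate_once(prompt, image_path=None, max_new_tokens=100):
    # Build one user turn in the same message format bot_streaming constructs.
    content = [{"type": "text", "text": prompt}]
    images = None
    if image_path is not None:
        content.append({"type": "image"})
        images = [Image.open(image_path).convert("RGB")]
    messages = [{"role": "user", "content": content}]
    text = processor.apply_chat_template(messages, add_generation_prompt=True)
    inputs = processor(text=text, images=images, return_tensors="pt").to(model.device)
    output_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
    # Drop the prompt tokens before decoding, mirroring skip_prompt=True in the streamer.
    new_tokens = output_ids[:, inputs["input_ids"].shape[1]:]
    return processor.decode(new_tokens[0], skip_special_tokens=True)

print(generate_once("Describe this image.", "./examples/rococo.jpg"))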