Update app.py

app.py CHANGED
@@ -1,952 +1,102 @@
-import
-
-import
-GK=0
-from transformers import AutoTokenizer
 import torch
 import os
-from VitsModelSplit.vits_model2 import VitsModel,get_state_grad_loss
-import VitsModelSplit.monotonic_align as monotonic_align
 token=os.environ.get("key_")
-
-import logging
-import copy
-import torch
-
-import numpy as np
-import torch
-from datasets import DatasetDict,Dataset
-import os
-from VitsModelSplit.vits_model2 import VitsModel,get_state_grad_loss
-#from VitsModelSplit.vits_model_only_d import Vits_models_only_decoder
-#from VitsModelSplit.vits_model import VitsModel
-from VitsModelSplit.PosteriorDecoderModel import PosteriorDecoderModel
-from VitsModelSplit.feature_extraction import VitsFeatureExtractor
-from transformers import AutoTokenizer, HfArgumentParser, set_seed
-from VitsModelSplit.Arguments import DataTrainingArguments, ModelArguments, VITSTrainingArguments
-from VitsModelSplit.dataset_features_collector import FeaturesCollectionDataset
-from torch.cuda.amp import autocast, GradScaler
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-# sgl=get_state_grad_loss(k1=True,#generator=False,
-#                         discriminator=False,
-#                         duration=False
-#                         )
-# class model_onxx:
-#     def __init__(self):
-#         self.model=None
-#         self.n_onxx=""
-#         self.storage_dir = "uploads"
-#         pass
-
-#     def download_file(self,file_path):
-#         ff= gr.File(value=file_path, visible=True)
-#         file_url = ff.value['url']
-#         return file_url
-#     def function_change(self,n_model,token,n_onxx,choice):
-#         if choice=="decoder":
-
-#             V=self.convert_to_onnx_only_decoder(n_model,token,n_onxx)
-#         elif choice=="all only decoder":
-#             V=self.convert_to_onnx_only_decoder(n_model,token,n_onxx)
-#         else:
-#             V=self.convert_to_onnx_only_decoder(n_model,token,n_onxx)
-#         return V
-
-#     def install_model(self,n_model,token,n_onxx):
-#         self.n_onxx=n_onxx
-#         self.model= VitsModel.from_pretrained(n_model,token=token)
-#         return self.model
-#     def convert_model_decoder_onxx(self,n_model,token,namemodelonxx):
-#         self.model= VitsModel.from_pretrained(n_model,token=token)
-#         x=f"/tmp/{namemodelonxx}.onnx"
-#         return x
-#     def convert_to_onnx_only_decoder(self,n_model,token,namemodelonxx):
-#         model=VitsModel.from_pretrained(n_model,token=token)
-#         x=f"/tmp/{namemodelonxx}.onnx"
-
-#         vocab_size = model.text_encoder.embed_tokens.weight.size(0)
-#         example_input = torch.randint(0, vocab_size, (1, 100), dtype=torch.long)
-#         torch.onnx.export(
-#             model,                   # The model to be exported
-#             example_input,           # Example input for the model
-#             x,                       # The filename for the exported ONNX model
-#             opset_version=11,        # Use an appropriate ONNX opset version
-#             input_names=['input'],   # Name of the input layer
-#             output_names=['output'], # Name of the output layer
-#             dynamic_axes={
-#                 'input': {0: 'batch_size', 1: 'sequence_length'},  # Dynamic axes for variable-length inputs
-#                 'output': {0: 'batch_size'}
-#             }
-#         )
-#         return x
-
-#     def convert_to_onnx_all(self,n_model,token ,namemodelonxx):
-
-#         model=VitsModel.from_pretrained(n_model,token=token)
-#         x=f"dowload_file/{namemodelonxx}.onnx"
-
-#         vocab_size = model.text_encoder.embed_tokens.weight.size(0)
-#         example_input = torch.randint(0, vocab_size, (1, 100), dtype=torch.long)
-#         torch.onnx.export(
-#             model,                   # The model to be exported
-#             example_input,           # Example input for the model
-#             x,                       # The filename for the exported ONNX model
-#             opset_version=11,        # Use an appropriate ONNX opset version
-#             input_names=['input'],   # Name of the input layer
-#             output_names=['output'], # Name of the output layer
-#             dynamic_axes={
-#                 'input': {0: 'batch_size', 1: 'sequence_length'},  # Dynamic axes for variable-length inputs
-#                 'output': {0: 'batch_size'}
-#             }
-#         )
-#         return x
-#     def starrt(self):
-#         #with gr.Blocks() as demo:
-#         with gr.Row():
-#             with gr.Column():
-#                 text_n_model=gr.Textbox(label="name model")
-#                 text_n_token=gr.Textbox(label="token")
-#                 text_n_onxx=gr.Textbox(label="name model onxx")
-#                 choice = gr.Dropdown(choices=["decoder", "all anoly decoder", "All"], label="My Dropdown")
-
-#             with gr.Column():
-
-#                 btn=gr.Button("convert")
-#                 label=gr.Label("return name model onxx")
-#                 btn.click(self.function_change,[text_n_model,text_n_token,text_n_onxx,choice],[gr.File(label="Download File")])
-#     #choice.change(fn=function_change, inputs=choice, outputs=label)
-#     #return demo
-# c=model_onxx()
-
-#3333333333333333333333333333
-class OnnxModelConverter:
-    def __init__(self):
-        self.model = None
-    def download_file(self,file_path):
-        ff= gr.File(value=file_path, visible=True)
-        file_url = ff.value['url']
-        return file_url
-
-    def convert(self, model_name, token, onnx_filename, conversion_type):
-        """
-        Main function to handle different types of model conversions.
-
-        Args:
-            model_name (str): Name of the model to convert.
-            token (str): Access token for loading the model.
-            onnx_filename (str): Desired filename for the ONNX output.
-            conversion_type (str): Type of conversion ('decoder', 'only_decoder', or 'full_model').
-
-        Returns:
-            str: The path to the generated ONNX file.
-        """
-        if conversion_type == "decoder":
-            return self.convert_decoder(model_name, token, onnx_filename)
-        elif conversion_type == "only_decoder":
-            return self.convert_only_decoder(model_name, token, onnx_filename)
-        elif conversion_type == "full_model":
-            return self.convert_full_model(model_name, token, onnx_filename)
-        else:
-            raise ValueError("Invalid conversion type. Choose from 'decoder', 'only_decoder', or 'full_model'.")
-
-    def convert_decoder(self, model_name, token, onnx_filename):
-        """
-        Converts only the decoder part of the Vits model to ONNX format.
-
-        Args:
-            model_name (str): Name of the model to convert.
-            token (str): Access token for loading the model.
-            onnx_filename (str): Desired filename for the ONNX output.
-
-        Returns:
-            str: The path to the generated ONNX file.
-        """
-        model = VitsModel.from_pretrained(model_name, token=token)
-        onnx_file = f"/tmp/{onnx_filename}.onnx"
-        vocab_size = model.text_encoder.embed_tokens.weight.size(0)
-        example_input = torch.randint(0, vocab_size, (1, 100), dtype=torch.long)
-
-        torch.onnx.export(
-            model,
-            example_input,
-            onnx_file,
-            opset_version=11,
-            input_names=['input'],
-            output_names=['output'],
-            dynamic_axes={'input': {0: 'batch_size', 1: 'sequence_length'}, 'output': {0: 'batch_size'}}
-        )
-
-        return onnx_file
-
-    def convert_only_decoder(self, model_name, token, onnx_filename):
-        """
-        Converts only the decoder part of the Vits model to ONNX format.
-
-        Args:
-            model_name (str): Name of the model to convert.
-            token (str): Access token for loading the model.
-            onnx_filename (str): Desired filename for the ONNX output.
-
-        Returns:
-            str: The path to the generated ONNX file.
-        """
-        model = Vits_models_only_decoder.from_pretrained(model_name, token=token)
-        onnx_file = f"/tmp/{onnx_filename}.onnx"
-
-        vocab_size = model.text_encoder.embed_tokens.weight.size(0)
-        example_input = torch.randint(0, vocab_size, (1, 100), dtype=torch.long)
-
-        torch.onnx.export(
-            model,
-            example_input,
-            onnx_file,
-            opset_version=11,
-            input_names=['input'],
-            output_names=['output'],
-            dynamic_axes={'input': {0: 'batch_size', 1: 'sequence_length'}, 'output': {0: 'batch_size'}}
-        )
-
-        return onnx_file
-
-    def convert_full_model(self, model_name, token, onnx_filename):
-        """
-        Converts the full Vits model (including encoder and decoder) to ONNX format.
-
-        Args:
-            model_name (str): Name of the model to convert.
-            token (str): Access token for loading the model.
-            onnx_filename (str): Desired filename for the ONNX output.
-
-        Returns:
-            str: The path to the generated ONNX file.
-        """
-        model = VitsModel.from_pretrained(model_name, token=token)
-        onnx_file = f"/tmp/{onnx_filename}.onnx"
-
-        vocab_size = model.text_encoder.embed_tokens.weight.size(0)
-        example_input = torch.randint(0, vocab_size, (1, 100), dtype=torch.long)
-
-        torch.onnx.export(
-            model,
-            example_input,
-            onnx_file,
-            opset_version=11,
-            input_names=['input'],
-            output_names=['output'],
-            dynamic_axes={'input': {0: 'batch_size', 1: 'sequence_length'}, 'output': {0: 'batch_size'}}
-        )
-
-        return onnx_file
-    def starrt(self):
-        with gr.Blocks() as demo:
-            with gr.Row():
-                with gr.Column():
-                    text_n_model=gr.Textbox(label="name model")
-                    text_n_token=gr.Textbox(label="token")
-                    text_n_onxx=gr.Textbox(label="name model onxx")
-                    choice = gr.Dropdown(choices=["decoder", "only_decoder", "full_model"], label="My Dropdown")
-
-                with gr.Column():
-
-                    btn=gr.Button("convert")
-                    label=gr.Label("return name model onxx")
-                    btn.click(self.convert,[text_n_model,text_n_token,text_n_onxx,choice],[gr.File(label="Download File")])
-            #choice.change(fn=function_change, inputs=choice, outputs=label)
-        return demo
-c=OnnxModelConverter()
-###############################################################
-Lst=['input_ids',
-     'attention_mask',
-     'waveform',
-     'labels',
-     'labels_attention_mask',
-     'mel_scaled_input_features']
-def covert_cuda_batch(d):
-    return d
-    for key in Lst:
-        d[key]=d[key].cuda(non_blocking=True)
-    # for key in d['text_encoder_output']:
-    #     d['text_encoder_output'][key]=d['text_encoder_output'][key].cuda(non_blocking=True)
-    # for key in d['posterior_encode_output']:
-    #     d['posterior_encode_output'][key]=d['posterior_encode_output'][key].cuda(non_blocking=True)
-
-    return d
-def generator_loss(disc_outputs):
-    total_loss = 0
-    gen_losses = []
-    for disc_output in disc_outputs:
-        disc_output = disc_output
-        loss = torch.mean((1 - disc_output) ** 2)
-        gen_losses.append(loss)
-        total_loss += loss
-
-    return total_loss, gen_losses
-
-def discriminator_loss(disc_real_outputs, disc_generated_outputs):
-    loss = 0
-    real_losses = 0
-    generated_losses = 0
-    for disc_real, disc_generated in zip(disc_real_outputs, disc_generated_outputs):
-        real_loss = torch.mean((1 - disc_real) ** 2)
-        generated_loss = torch.mean(disc_generated**2)
-        loss += real_loss + generated_loss
-        real_losses += real_loss
-        generated_losses += generated_loss
-
-    return loss, real_losses, generated_losses
-
-def feature_loss(feature_maps_real, feature_maps_generated):
-    loss = 0
-    for feature_map_real, feature_map_generated in zip(feature_maps_real, feature_maps_generated):
-        for real, generated in zip(feature_map_real, feature_map_generated):
-            real = real.detach()
-            loss += torch.mean(torch.abs(real - generated))
-
-    return loss * 2
-
-def kl_loss(z_p, logs_q, m_p, logs_p, z_mask):
-    """
-    z_p, logs_q: [b, h, t_t]
-    m_p, logs_p: [b, h, t_t]
-    """
-    z_p = z_p.float()
-    logs_q = logs_q.float()
-    m_p = m_p.float()
-    logs_p = logs_p.float()
-    z_mask = z_mask.float()
-
-    kl = logs_p - logs_q - 0.5
-    kl += 0.5 * ((z_p - m_p)**2) * torch.exp(-2. * logs_p)
-    kl = torch.sum(kl * z_mask)
-    l = kl / torch.sum(z_mask)
-    return l
-#.............................................
-# def kl_loss(prior_latents, posterior_log_variance, prior_means, prior_log_variance, labels_mask):
-
-#     kl = prior_log_variance - posterior_log_variance - 0.5
-#     kl += 0.5 * ((prior_latents - prior_means) ** 2) * torch.exp(-2.0 * prior_log_variance)
-#     kl = torch.sum(kl * labels_mask)
-#     loss = kl / torch.sum(labels_mask)
-#     return loss
-
-def get_state_grad_loss(k1=True,
-                        mel=True,
-                        duration=True,
-                        generator=True,
-                        discriminator=True):
-    return {'k1':k1,'mel':mel,'duration':duration,'generator':generator,'discriminator':discriminator}
-
-@spaces.GPU
-def clip_grad_value_(parameters, clip_value, norm_type=2):
-    if isinstance(parameters, torch.Tensor):
-        parameters = [parameters]
-    parameters = list(filter(lambda p: p.grad is not None, parameters))
-    norm_type = float(norm_type)
-    if clip_value is not None:
-        clip_value = float(clip_value)
-
-    total_norm = 0
-    for p in parameters:
-        param_norm = p.grad.data.norm(norm_type)
-        total_norm += param_norm.item() ** norm_type
-        if clip_value is not None:
-            p.grad.data.clamp_(min=-clip_value, max=clip_value)
-    total_norm = total_norm ** (1. / norm_type)
-    return total_norm
-
-@spaces.GPU
-def get_embed_speaker(self,speaker_id):
-    if self.config.num_speakers > 1 and speaker_id is not None:
-        if isinstance(speaker_id, int):
-            speaker_id = torch.full(size=(1,), fill_value=speaker_id, device=self.device)
-        elif isinstance(speaker_id, (list, tuple, np.ndarray)):
-            speaker_id = torch.tensor(speaker_id, device=self.device)
-
-        if not ((0 <= speaker_id).all() and (speaker_id < self.config.num_speakers).all()).item():
-            raise ValueError(f"Set `speaker_id` in the range 0-{self.config.num_speakers - 1}.")
-
-        return self.embed_speaker(speaker_id).unsqueeze(-1)
-    else:
-        return None
-
-def get_data_loader(train_dataset_dirs,eval_dataset_dir,full_generation_dir,device):
-    ctrain_datasets=[]
-    for dataset_dir ,id_sp in train_dataset_dirs:
-        train_dataset = FeaturesCollectionDataset(dataset_dir = os.path.join(dataset_dir,'train'),
-                                                  device = device
-                                                  )
-        ctrain_datasets.append((train_dataset,id_sp))
-
-    eval_dataset = None
-
-    eval_dataset = FeaturesCollectionDataset(dataset_dir = eval_dataset_dir,
-                                             device = device
-                                             )
-
-    full_generation_dataset = FeaturesCollectionDataset(dataset_dir = full_generation_dir,
-                                                        device = device)
-    return ctrain_datasets,eval_dataset,full_generation_dataset
-global_step=0
-
-def train_step(batch,models=[],optimizers=[], training_args=None,tools=[]):
-    self,discriminator=models
-    optimizer,disc_optimizer,scaler=optimizers
-    feature_extractor,maf,dict_state_grad_loss=tools
-
-    with autocast(enabled=training_args.fp16):
-        speaker_embeddings=get_embed_speaker(self,batch["speaker_id"])
-        waveform,ids_slice,log_duration,prior_latents,posterior_log_variances,prior_means,prior_log_variances,labels_padding_mask = self.forward_train(
-            input_ids=batch["input_ids"],
-            attention_mask=batch["attention_mask"],
-            labels=batch["labels"],
-            labels_attention_mask=batch["labels_attention_mask"],
-            text_encoder_output =None ,
-            posterior_encode_output=None ,
-            return_dict=True,
-            monotonic_alignment_function=maf,
-            speaker_embeddings=speaker_embeddings
-
-        )
-        mel_scaled_labels = batch["mel_scaled_input_features"]
-        mel_scaled_target = self.slice_segments(mel_scaled_labels, ids_slice,self.segment_size)
-        mel_scaled_generation = feature_extractor._torch_extract_fbank_features(waveform.squeeze(1))[1]
-
-        target_waveform = batch["waveform"].transpose(1, 2)
-        target_waveform = self.slice_segments(
-            target_waveform,
-            ids_slice * feature_extractor.hop_length,
-            self.config.segment_size
-        )
-
-        discriminator_target, fmaps_target = discriminator(target_waveform)
-        discriminator_candidate, fmaps_candidate = discriminator(waveform.detach())
-        with autocast(enabled=False):
-            if dict_state_grad_loss['discriminator']:
-
-                loss_disc, loss_real_disc, loss_fake_disc = discriminator_loss(
-                    discriminator_target, discriminator_candidate
-                )
-
-                loss_dd = loss_disc# + loss_real_disc + loss_fake_disc
-
-                # loss_dd.backward()
-
-                disc_optimizer.zero_grad()
-                scaler.scale(loss_dd).backward()
-                scaler.unscale_(disc_optimizer )
-                grad_norm_d = clip_grad_value_(discriminator.parameters(), None)
-                scaler.step(disc_optimizer)
-                loss_des=grad_norm_d
-
-    with autocast(enabled=training_args.fp16):
-
-        # backpropagate
-
-        discriminator_target, fmaps_target = discriminator(target_waveform)
-
-        discriminator_candidate, fmaps_candidate = discriminator(waveform.detach())
-    with autocast(enabled=False):
-        if dict_state_grad_loss['k1']:
-            loss_kl = kl_loss(
-                prior_latents,
-                posterior_log_variances,
-                prior_means,
-                prior_log_variances,
-                labels_padding_mask,
-            )
-            loss_kl=loss_kl*training_args.weight_kl
-            loss_klall=loss_kl.detach().item()
-            #if displayloss['loss_kl']>=0:
-            #    loss_kl.backward()
-
-        if dict_state_grad_loss['mel']:
-            loss_mel = torch.nn.functional.l1_loss(mel_scaled_target, mel_scaled_generation)*training_args.weight_mel
-            loss_melall= loss_mel.detach().item()
-            # train_losses_sum = train_losses_sum + displayloss['loss_mel']
-            # if displayloss['loss_mel']>=0:
-            #     loss_mel.backward()
-
-        if dict_state_grad_loss['duration']:
-            loss_duration=torch.sum(log_duration)*training_args.weight_duration
-            loss_durationsall=loss_duration.detach().item()
-            # if displayloss['loss_duration']>=0:
-            #     loss_duration.backward()
-        if dict_state_grad_loss['generator']:
-            loss_fmaps = feature_loss(fmaps_target, fmaps_candidate)
-            loss_gen, losses_gen = generator_loss(discriminator_candidate)
-            loss_gen=loss_gen * training_args.weight_gen
-
-            # loss_gen.backward(retain_graph=True)
-            loss_fmaps=loss_fmaps * training_args.weight_fmaps
-
-            # loss_fmaps.backward(retain_graph=True)
-            total_generator_loss = (
-                loss_duration
-                + loss_mel
-                + loss_kl
-                + loss_fmaps
-                + loss_gen
-            )
-            # total_generator_loss.backward()
-            optimizer.zero_grad()
-            scaler.scale(total_generator_loss).backward()
-            scaler.unscale_(optimizer)
-            grad_norm_g = clip_grad_value_(self.parameters(), None)
-            scaler.step(optimizer)
-            scaler.update()
-            loss_gen=grad_norm_g
-
-    return loss_gen,loss_des,loss_durationsall,loss_melall,loss_klall
-
 
-    loss_genall=0
-    loss_desall=0
-    loss_durationsall=0
-    loss_melall=0
-    loss_klall=0
-    loss_fmapsall=0
-    start_speeker,end_speeker=idspeakers
-
-    datatrain=obtrainer.DataSets['train'][index_db]
-    lr_scheduler,disc_lr_scheduler=obtrainer.lr_schedulers
-    lr_scheduler.step()
-
-    disc_lr_scheduler.step()
-    train_dataset,speaker_id=datatrain
-    print(f" Num Epochs = {epoch}, speaker_id DB ={speaker_id}")
-    num_div_proc=int(len(train_dataset)/10)+1
-    print(' -process traning : [',end='')
-    full_generation_sample =obtrainer.DataSets['full_generation'][full_generation_sample_index]
-
-    for step, batch in enumerate(train_dataset):
-        loss_gen,loss_des,loss_durationsa,loss_mela,loss_kl=train_step(batch,
-                                models=obtrainer.models,
-                                optimizers=obtrainer.optimizers,
-                                training_args=obtrainer.training_args,
-                                tools=obtrainer.tools)
-        loss_genall+=loss_gen
-        loss_desall+=loss_des
-        loss_durationsall+=loss_durationsa
-        loss_melall+=loss_mela
-        loss_klall+=loss_kl
-
-        obtrainer.global_step +=1
-        if step%num_div_proc==0:
-            print('==',end='')
-
-        # validation
-
-        do_eval = obtrainer.training_args.do_eval and (obtrainer.global_step % obtrainer.training_args.eval_steps == 0)
-
-        if do_eval:
-            speaker_id_c=int(torch.randint(start_speeker,end_speeker,size=(1,))[0])
-            model=obtrainer.models[0]
-
-            with torch.no_grad():
-
-                full_generation =model.forward(
-                    input_ids =full_generation_sample["input_ids"],
-                    attention_mask=full_generation_sample["attention_mask"],
-                    speaker_id=speaker_id_c
-                )
-
-                full_generation_waveform = full_generation.waveform.cpu().numpy()
-
-            wandb.log({
-                "full generations samples": [
-                    wandb.Audio(w.reshape(-1), caption=f"Full generation sample {epoch}", sample_rate=16000)
-                    for w in full_generation_waveform],})
-    step+=1
-    # wandb.log({"train_losses":loss_melall})
-    wandb.log({"loss_gen":loss_genall/step})
-    wandb.log({"loss_des":loss_desall/step})
-    wandb.log({"loss_duration":loss_durationsall/step})
-    wandb.log({"loss_mel":loss_melall/step})
-    wandb.log({f"loss_kl_db{speaker_id}":loss_klall/step})
-    print(']',end='')
-
-def load_training_args(path):
-    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, VITSTrainingArguments))
-    json_file = os.path.abspath(path)
-    model_args, data_args, training_args = parser.parse_json_file(json_file = json_file)
-    return training_args
-def load_tools():
-    feature_extractor = VitsFeatureExtractor()
-    dict_state_grad_loss=get_state_grad_loss()
-    return feature_extractor,monotonic_align.maximum_path,dict_state_grad_loss
-
-class TrinerModelVITS:
-    KC=0
-    def __init__(self,dir_model="",
-                 path_training_args="",
-                 train_dataset_dirs=[],
-                 eval_dataset_dir="",
-                 full_generation_dir="",
-                 token="",
-
-                 device="cpu"):
-        self.device=device
-        self.dir_model=dir_model
-        self.path_training_args=path_training_args
-        self.stute_mode=False
-        self.token=token
 
-        self.global_step=0
-        self.len_dataset=len(self.DataSets['train'])
-        #self.load_model()
-        #self.init_wandb()
-        # self.training_args=load_training_args(self.path_training_args)
-        # training_args= self.training_args
-        scaler = GradScaler(enabled=True)
-        # for disc in self.model.discriminator.discriminators:
-        #     disc.apply_weight_norm()
-        # self.model.decoder.apply_weight_norm()
-        # # torch.nn.utils.weight_norm(self.decoder.conv_pre)
-        # # torch.nn.utils.weight_norm(self.decoder.conv_post)
-        # for flow in self.model.flow.flows:
-        #     torch.nn.utils.weight_norm(flow.conv_pre)
-        #     torch.nn.utils.weight_norm(flow.conv_post)
-
-        discriminator = self.model.discriminator
-        self.model.discriminator = None
-        self.models=(self.model,discriminator)
-
-        optimizer = torch.optim.AdamW(
-            self.model.parameters(),
-            2e-4,
-            betas=[0.8, 0.99],
-            # eps=training_args.adam_epsilon,
-        )
-
-        # Hack to be able to train on multiple device
-        disc_optimizer = torch.optim.AdamW(
-            discriminator.parameters(),
-            2e-4,
-            betas=[0.8, 0.99],
-            # eps=training_args.adam_epsilon,
-        )
-        lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(
-            optimizer,gamma=0.999875, last_epoch=-1
-        )
-        disc_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(
-            disc_optimizer, gamma=0.999875,last_epoch=-1
-        )
-        # self.models=(self.model,discriminator)
-        self.optimizers=(optimizer,disc_optimizer,scaler)
-        self.lr_schedulers=(lr_scheduler,disc_lr_scheduler)
-        self.tools=load_tools()
-        self.stute_mode=True
-        print(self.lr_schedulers)
-
-    def init_Starting(self):
-        print('init_Starting')
-        #self.training_args=load_training_args(self.path_training_args)
-        #self.stute_mode=False
-        print('end training_args')
-
-    def init_training(self):
-
-        self.initialize_training_components()
-        # self.epoch_count=0
-
-    def load_model(self):
-        self.model=VitsModel.from_pretrained(self.dir_model,token=self.token).to(self.device)
-        self.model.setMfA(monotonic_align.maximum_path)
-
-    def init_wandb(self):
-        wandb.login(key= "782b6a6e82bbb5a5348de0d3c7d40d1e76351e79")
-        #config = self.training_args.to_dict()
-        wandb.init(project= 'HugfaceTraining')
-
-    def load_modell(self,namemodel):
-        self.model=VitsModel.from_pretrained(namemodel,token=self.token).to(self.device)
-        return "true"
-    def load_dataset(self,train_dataset_dirs,eval_dataset_dir,full_generation_dir):
-        ctrain_datasets,eval_dataset,full_generation_dataset=get_data_loader(train_dataset_dirs = train_dataset_dirs,
-                           eval_dataset_dir =eval_dataset_dir ,
-                           full_generation_dir =full_generation_dir ,
-                           device=self.device)
-        self.DataSets={'train':ctrain_datasets,'eval':eval_dataset,'full_generation':full_generation_dataset}
-
 
-        set_seed(training_args.seed)
-        # scaler = GradScaler(enabled=training_args.fp16)
 
-#
-
-# )
-
-# # Hack to be able to train on multiple device
-# disc_optimizer = torch.optim.AdamW(
-#     discriminator.parameters(),
-#     training_args.d_learning_rate,
-#     betas=[training_args.d_adam_beta1, training_args.d_adam_beta2],
-#     eps=training_args.adam_epsilon,
-# )
-# lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(
-#     optimizer, gamma=training_args.lr_decay, last_epoch=-1
-# )
-# disc_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(
-#     disc_optimizer, gamma=training_args.lr_decay, last_epoch=-1
-# )
-# # self.models=(self.model,discriminator)
-# self.optimizers=(optimizer,disc_optimizer,scaler)
-# self.lr_schedulers=(lr_scheduler,disc_lr_scheduler)
-# self.tools=load_tools()
-# self.stute_mode=True
-# print(self.lr_schedulers)
-
-    def save_pretrained(self,path_save_model):
-
-        model,discriminator=self.models
-
-        model.discriminator=discriminator
-        for disc in model.discriminator.discriminators:
-            disc.remove_weight_norm()
-        model.decoder.remove_weight_norm()
-        # torch.nn.utils.remove_weight_norm(self.decoder.conv_pre)
-        # torch.nn.utils.remove_weight_norm(self.decoder.conv_post)
-        for flow in model.flow.flows:
-            torch.nn.utils.remove_weight_norm(flow.conv_pre)
-            torch.nn.utils.remove_weight_norm(flow.conv_post)
-
-        model.push_to_hub(path_save_model,token=self.token)
-
-    def run_train_epoch(self):
-        index_db=self.epoch_count%self.len_dataset
-        train_epoch(self,index_db=index_db,epoch=self.epoch_count,idspeakers=(0,1),full_generation_sample_index=-1)
-        self.epoch_count+=1
-        return f'epoch_count:{self.epoch_count},global_step:{self.global_step},index_db"{index_db}'
-
-    # return (self.model,discriminator),(optimizer, disc_optimizer), (lr_scheduler, disc_lr_scheduler)
-
-# logger.info("***** Training / Inference Done *****")
-def modelspeech(texts):
-
-    inputs = tokenizer(texts, return_tensors="pt")#.cuda()
-
-    wav = model_vits(input_ids=inputs["input_ids"]).waveform#.detach()
-    # display(Audio(wav, rate=model.config.sampling_rate))
-    return model_vits.config.sampling_rate,wav#remove_noise_nr(wav)
-
-dataset_dir='ABThag-db'
-train_dataset_dirs=[
-    # ('/content/drive/MyDrive/vitsM/DATA/fahd_db',0),
-    # ('/content/drive/MyDrive/vitsM/DATA/fahd_db',0),
-    # ('/content/drive/MyDrive/vitsM/DB2KKKK',1),
-    # ('/content/drive/MyDrive/vitsM/DATA/Db_Amgd_50_Bitch10',0),
-    # ('/content/drive/MyDrive/vitsM/DB2KKKK',1), #
-    # ('/content/drive/MyDrive/vitsM/DATA/Db_Amgd_50_Bitch10',0),
-    # ('/content/drive/MyDrive/vitsM/DATA/DBWfaa-Bitch:8-Count:60',0),
-    # ('/content/drive/MyDrive/vitsM/DATA/Wafa/b10r',0),
-    # ('/content/drive/MyDrive/vitsM/DATA/Wafa/b16r',0),
-    # ('/content/drive/MyDrive/vitsM/DATA/Wafa/b4',0),
-
-    # ('/content/drive/MyDrive/vitsM/DATA/fahd_db',None),
-    # ('/content/drive/MyDrive/vitsM/DATA/wafa-db',None),
-    # ('/content/drive/MyDrive/vitsM/DATA/wafa-db',4),
-    # ('/content/drive/MyDrive/vitsM/DATA/DB-ABThag-Bitch:5-Count-37',4),
-    # ('/content/drive/MyDrive/vitsM/DB-300-k',6),
-    ('databatchs',0),
-    #('/content/drive/MyDrive/dataset_ljBatchs',0),
-
-]
-
-dir_model='wasmdashai/vits-ar-huba-fine'
-pro=TrinerModelVITS(dir_model=dir_model,
-                    path_training_args='VitsModelSplit/finetune_config_ara.json',
-                    train_dataset_dirs = train_dataset_dirs,
-                    eval_dataset_dir = os.path.join(dataset_dir,'eval'),
-                    full_generation_dir = os.path.join(dataset_dir,'full_generation'),
-                    token=token,
-                    device=device
-                    )
-def loadd_d():
-    token=os.environ.get("key_")
-    #model=VitsModel.from_pretrained(n_model,token=token)
-    return token
-@spaces.GPU(duration=30)
-def run_train_epoch(num):
-    TrinerModelVITS.KC+=1
-    if num >0:
-        pro.init_training()
-        for i in range(num):
-            # model.train(True)
-            return pro.run_train_epoch() +f'- kc={TrinerModelVITS.KC}'
-    else:
-        pro.save_pretrained(pro.dir_model)
-        pro.load_model()
-        return 'save model '
-
-@spaces.GPU
-def init_training():
-    pro.init_training()
-    return pro.dir_model,'init_training'
-
-@spaces.GPU
-def init_Starting():
-
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-    pro.init_wandb()
-    return 'init_wandb'
-
-def save_pretrained(path):
-    pro.save_pretrained(path)
 
-    btn_save_pretrained = gr.Button("save_pretrained")
-    input_save_pretrained = gr.Textbox(label="save_pretrained")
-    output_save_pretrained = gr.Textbox(label="save_pretrained")
-    btn_save_pretrained.click(fn=save_pretrained,inputs=[input_save_pretrained],outputs=[output_save_pretrained])
-
-    btn_init_train.click(fn=init_training,inputs=[],outputs=[input_save_pretrained,output_btn_init_train])
 
-
-    interface.launch()
-print('loadeed')
-
+from transformers import MllamaForConditionalGeneration, AutoProcessor, TextIteratorStreamer
+from PIL import Image
+import requests
 import torch
+from threading import Thread
+import gradio as gr
+from gradio import FileData
+import time
+import spaces
 import os
 token=os.environ.get("key_")
+ckpt = "meta-llama/Llama-3.2-11B-Vision-Instruct"
+model = MllamaForConditionalGeneration.from_pretrained(ckpt,token=token,
+                                                       torch_dtype=torch.bfloat16).to("cuda")
+processor = AutoProcessor.from_pretrained(ckpt,token=token)
 
+#@spaces.GPU
+def bot_streaming(message, history, max_new_tokens=250):
 
+    txt = message["text"]
+    ext_buffer = f"{txt}"
 
+    messages= []
+    images = []
 
+    for i, msg in enumerate(history):
+        if isinstance(msg[0], tuple):
+            messages.append({"role": "user", "content": [{"type": "text", "text": history[i+1][0]}, {"type": "image"}]})
+            messages.append({"role": "assistant", "content": [{"type": "text", "text": history[i+1][1]}]})
+            images.append(Image.open(msg[0][0]).convert("RGB"))
+        elif isinstance(history[i-1], tuple) and isinstance(msg[0], str):
+            # messages are already handled
+            pass
+        elif isinstance(history[i-1][0], str) and isinstance(msg[0], str): # text only turn
+            messages.append({"role": "user", "content": [{"type": "text", "text": msg[0]}]})
+            messages.append({"role": "assistant", "content": [{"type": "text", "text": msg[1]}]})
+
+    # add current message
+    if len(message["files"]) == 1:
 
+        if isinstance(message["files"][0], str): # examples
+            image = Image.open(message["files"][0]).convert("RGB")
+        else: # regular input
+            image = Image.open(message["files"][0]["path"]).convert("RGB")
+        images.append(image)
+        messages.append({"role": "user", "content": [{"type": "text", "text": txt}, {"type": "image"}]})
+    else:
+        messages.append({"role": "user", "content": [{"type": "text", "text": txt}]})
+
+    texts = processor.apply_chat_template(messages, add_generation_prompt=True)
+
+    if images == []:
+        inputs = processor(text=texts, return_tensors="pt").to("cuda")
+    else:
+        inputs = processor(text=texts, images=images, return_tensors="pt").to("cuda")
+    streamer = TextIteratorStreamer(processor, skip_special_tokens=True, skip_prompt=True)
+
+    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=max_new_tokens)
+    generated_text = ""
 
+    thread = Thread(target=model.generate, kwargs=generation_kwargs)
+    thread.start()
+    buffer = ""
 
+    for new_text in streamer:
+        buffer += new_text
+        generated_text_without_prompt = buffer
+        time.sleep(0.01)
+        yield buffer
+
+demo = gr.ChatInterface(fn=bot_streaming, title="Multimodal Llama", examples=[
+    [{"text": "Which era does this piece belong to? Give details about the era.", "files":["./examples/rococo.jpg"]},
+     200],
+    [{"text": "Where do the droughts happen according to this diagram?", "files":["./examples/weather_events.png"]},
+     250],
+    [{"text": "What happens when you take out white cat from this chain?", "files":["./examples/ai2d_test.jpg"]},
+     250],
+    [{"text": "How long does it take from invoice date to due date? Be short and concise.", "files":["./examples/invoice.png"]},
+     250],
+    [{"text": "Where to find this monument? Can you give me other recommendations around the area?", "files":["./examples/wat_arun.jpg"]},
+     250],
+    ],
+    textbox=gr.MultimodalTextbox(),
+    additional_inputs = [gr.Slider(
+        minimum=10,
+        maximum=500,
+        value=250,
+        step=10,
+        label="Maximum number of new tokens to generate",
+    )
+    ],
+    cache_examples=False,
+    description="Try Multimodal Llama by Meta with transformers in this demo. Upload an image, and start chatting about it, or simply try one of the examples below. To learn more about Llama Vision, visit [our blog post](https://huggingface.co/blog/llama32). ",
+    stop_btn="Stop Generation",
+    fill_height=True,
+    multimodal=True)
 
+demo.launch(debug=True)
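The core of the new app.py is the usual transformers streaming pattern: model.generate() runs in a background thread while the Gradio callback yields progressively longer strings read from a TextIteratorStreamer. The sketch below isolates that pattern with a small text-only checkpoint ("gpt2" is a stand-in assumption, not the Llama 3.2 vision model or processor used by this Space) so it can run on CPU; it is only meant to show why the thread plus streamer loop produces incremental output.

# Minimal sketch of the thread + streamer pattern used in bot_streaming above.
# "gpt2" is a hypothetical stand-in model so the sketch runs without a GPU.
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

def stream_reply(prompt, max_new_tokens=50):
    inputs = tokenizer(prompt, return_tensors="pt")
    # The streamer receives tokens as generate() produces them and decodes them
    # incrementally; skip_prompt drops the echoed input tokens from the output.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    kwargs = dict(inputs, streamer=streamer, max_new_tokens=max_new_tokens)
    # generate() blocks, so it runs in a worker thread while the caller consumes
    # partial text from the streamer and yields it to the UI.
    Thread(target=model.generate, kwargs=kwargs).start()
    buffer = ""
    for new_text in streamer:
        buffer += new_text
        yield buffer  # each yield re-renders the chat bubble with the growing text

if __name__ == "__main__":
    for partial in stream_reply("Streaming works by"):
        print(partial)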
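On the Gradio side, gr.ChatInterface with multimodal=True and a gr.MultimodalTextbox delivers each user turn to the callback as a dict with "text" and "files" keys, which is why bot_streaming indexes message["text"] and message["files"]. The echo bot below is a minimal sketch of that interface contract only; it is not part of the Space and the slider mirrors the max_new_tokens input wired up above.

# Sketch: how a multimodal ChatInterface hands messages to its callback.
import gradio as gr

def echo_bot(message, history, max_new_tokens=250):
    # message is a dict: {"text": str, "files": [uploaded file paths]}
    n_files = len(message["files"])
    return f"You said {message['text']!r} and attached {n_files} file(s); max_new_tokens={max_new_tokens}."

demo = gr.ChatInterface(
    fn=echo_bot,
    multimodal=True,                 # switches the message format to text + files
    textbox=gr.MultimodalTextbox(),
    additional_inputs=[gr.Slider(10, 500, value=250, step=10, label="Max new tokens")],
)

if __name__ == "__main__":
    demo.launch()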