Niki Zhang committed on
Commit 18ab0a5 · verified · 1 Parent(s): 1af18bd

Update tts.py

Files changed (1)
  1. tts.py +54 -808
tts.py CHANGED
@@ -1,811 +1,57 @@
1
  import os
2
- import json
 
3
  import gradio as gr
4
- import numpy as np
5
- from gradio import processing_utils
6
- import requests
7
- from packaging import version
8
- from PIL import Image, ImageDraw
9
- import functools
10
- from langchain.llms.openai import OpenAI
11
- from caption_anything.model import CaptionAnything
12
- from caption_anything.utils.image_editing_utils import create_bubble_frame
13
- from caption_anything.utils.utils import mask_painter, seg_model_map, prepare_segmenter, image_resize
14
- from caption_anything.utils.parser import parse_augment
15
- from caption_anything.captioner import build_captioner
16
- from caption_anything.text_refiner import build_text_refiner
17
- from caption_anything.segmenter import build_segmenter
18
- from caption_anything.utils.chatbot import ConversationBot, build_chatbot_tools, get_new_image_name
19
- from segment_anything import sam_model_registry
20
- import easyocr
21
- import tts
22
-
23
-
24
- gpt_state = 0
25
-
26
- article = """
27
- <div style='margin:20px auto;'>
28
- <p>By using this demo you agree to the terms of the Coqui Public Model License at https://coqui.ai/cpml</p>
29
- </div>
30
- """
31
-
32
- args = parse_augment()
33
- args.segmenter = "huge"
34
- args.segmenter_checkpoint = "sam_vit_h_4b8939.pth"
35
- args.clip_filter = True
36
- if args.segmenter_checkpoint is None:
37
- _, segmenter_checkpoint = prepare_segmenter(args.segmenter)
38
- else:
39
- segmenter_checkpoint = args.segmenter_checkpoint
40
-
41
- shared_captioner = build_captioner(args.captioner, args.device, args)
42
- shared_sam_model = sam_model_registry[seg_model_map[args.segmenter]](checkpoint=segmenter_checkpoint).to(args.device)
43
- ocr_lang = ["ch_tra", "en"]
44
- shared_ocr_reader = easyocr.Reader(ocr_lang)
45
- tools_dict = {e.split('_')[0].strip(): e.split('_')[1].strip() for e in args.chat_tools_dict.split(',')}
46
- shared_chatbot_tools = build_chatbot_tools(tools_dict)
47
-
48
-
49
- class ImageSketcher(gr.Image):
50
- """
51
- Fix the bug of gradio.Image that cannot upload with tool == 'sketch'.
52
- """
53
-
54
- is_template = True # Magic to make this work with gradio.Block, don't remove unless you know what you're doing.
55
-
56
- def __init__(self, **kwargs):
57
- super().__init__(tool="sketch", **kwargs)
58
-
59
- def preprocess(self, x):
60
- if self.tool == 'sketch' and self.source in ["upload", "webcam"]:
61
- assert isinstance(x, dict)
62
- if x['mask'] is None:
63
- decode_image = processing_utils.decode_base64_to_image(x['image'])
64
- width, height = decode_image.size
65
- mask = np.zeros((height, width, 4), dtype=np.uint8)
66
- mask[..., -1] = 255
67
- mask = self.postprocess(mask)
68
- x['mask'] = mask
69
- return super().preprocess(x)
70
-
71
-
72
- def build_caption_anything_with_models(args, api_key="", captioner=None, sam_model=None, ocr_reader=None, text_refiner=None,
73
- session_id=None):
74
- segmenter = build_segmenter(args.segmenter, args.device, args, model=sam_model)
75
- captioner = captioner
76
- if session_id is not None:
77
- print('Init caption anything for session {}'.format(session_id))
78
- return CaptionAnything(args, api_key, captioner=captioner, segmenter=segmenter, ocr_reader=ocr_reader, text_refiner=text_refiner)
79
-
80
-
81
- def validate_api_key(api_key):
82
- api_key = str(api_key).strip()
83
- print(api_key)
84
- try:
85
- test_llm = OpenAI(model_name="gpt-3.5-turbo", temperature=0, openai_api_key=api_key)
86
- response = test_llm("Test API call")
87
- print(response)
88
- return True
89
- except Exception as e:
90
- print(f"API key validation failed: {e}")
91
- return False
92
-
93
-
94
- def init_openai_api_key(api_key=""):
95
- text_refiner = None
96
- visual_chatgpt = None
97
- if api_key and len(api_key) > 30:
98
- print(api_key)
99
- if validate_api_key(api_key):
100
- try:
101
- text_refiner = build_text_refiner(args.text_refiner, args.device, args, api_key)
102
- assert len(text_refiner.llm('hi')) > 0 # test
103
- visual_chatgpt = ConversationBot(shared_chatbot_tools, api_key)
104
- except Exception as e:
105
- print(f"Error initializing TextRefiner or ConversationBot: {e}")
106
- text_refiner = None
107
- visual_chatgpt = None
108
- else:
109
- print("Invalid API key.")
110
- else:
111
- print("API key is too short.")
112
- print(text_refiner)
113
- openai_available = text_refiner is not None
114
- if openai_available:
115
-
116
- global gpt_state
117
- gpt_state=1
118
- return [gr.update(visible=True)]+[gr.update(visible=False)]+[gr.update(visible=True)]*3+[gr.update(visible=False)]+ [gr.update(visible=True)]+ [gr.update(visible=False)]*2 + [text_refiner, visual_chatgpt, None]
119
- else:
120
- return [gr.update(visible=False)]*7 + [gr.update(visible=True)]*2 + [text_refiner, visual_chatgpt, 'Your OpenAI API Key is not available']
121
-
122
- def init_wo_openai_api_key():
123
- return [gr.update(visible=False)]*4 + [gr.update(visible=True)]+ [gr.update(visible=False)]+[gr.update(visible=True)]+[gr.update(visible=False)]*2 + [None, None, None]
124
-
125
- def get_click_prompt(chat_input, click_state, click_mode):
126
- inputs = json.loads(chat_input)
127
- if click_mode == 'Continuous':
128
- points = click_state[0]
129
- labels = click_state[1]
130
- for input in inputs:
131
- points.append(input[:2])
132
- labels.append(input[2])
133
- elif click_mode == 'Single':
134
- points = []
135
- labels = []
136
- for input in inputs:
137
- points.append(input[:2])
138
- labels.append(input[2])
139
- click_state[0] = points
140
- click_state[1] = labels
141
- else:
142
- raise NotImplementedError
143
-
144
- prompt = {
145
- "prompt_type": ["click"],
146
- "input_point": click_state[0],
147
- "input_label": click_state[1],
148
- "multimask_output": "True",
149
- }
150
- return prompt
151
-
152
-
153
- def update_click_state(click_state, caption, click_mode):
154
- if click_mode == 'Continuous':
155
- click_state[2].append(caption)
156
- elif click_mode == 'Single':
157
- click_state[2] = [caption]
158
- else:
159
- raise NotImplementedError
160
-
161
- def chat_input_callback(*args):
162
- visual_chatgpt, chat_input, click_state, state, aux_state = args
163
- if visual_chatgpt is not None:
164
- return visual_chatgpt.run_text(chat_input, state, aux_state)
165
- else:
166
- response = "Text refiner is not initilzed, please input openai api key."
167
- state = state + [(chat_input, response)]
168
- return state, state
169
-
170
-
171
-
172
- def upload_callback(image_input, state, visual_chatgpt=None):
173
-
174
- if isinstance(image_input, dict): # if upload from sketcher_input, input contains image and mask
175
- image_input, mask = image_input['image'], image_input['mask']
176
-
177
- click_state = [[], [], []]
178
- image_input = image_resize(image_input, res=1024)
179
-
180
- model = build_caption_anything_with_models(
181
- args,
182
- api_key="",
183
- captioner=shared_captioner,
184
- sam_model=shared_sam_model,
185
- ocr_reader=shared_ocr_reader,
186
- session_id=iface.app_id
187
- )
188
- model.segmenter.set_image(image_input)
189
- image_embedding = model.image_embedding
190
- original_size = model.original_size
191
- input_size = model.input_size
192
-
193
- if visual_chatgpt is not None:
194
- print('upload_callback: add caption to chatGPT memory')
195
- new_image_path = get_new_image_name('chat_image', func_name='upload')
196
- image_input.save(new_image_path)
197
- visual_chatgpt.current_image = new_image_path
198
- img_caption = model.captioner.inference(image_input, filter=False, args={'text_prompt':''})['caption']
199
- Human_prompt = f'\nHuman: The description of the image with path {new_image_path} is: {img_caption}. This information helps you to understand this image, but you should use tools to finish following tasks, rather than directly imagine from my description. If you understand, say \"Received\". \n'
200
- AI_prompt = "Received."
201
- visual_chatgpt.global_prompt = Human_prompt + 'AI: ' + AI_prompt
202
- visual_chatgpt.agent.memory.buffer = visual_chatgpt.agent.memory.buffer + visual_chatgpt.global_prompt
203
- state = [(None, 'Received new image, resize it to width {} and height {}: '.format(image_input.size[0], image_input.size[1]))]
204
-
205
- return state, state, image_input, click_state, image_input, image_input, image_embedding, \
206
- original_size, input_size
207
-
208
-
209
-
210
- def inference_click(image_input, point_prompt, click_mode, enable_wiki, language, sentiment, factuality,
211
- length, image_embedding, state, click_state, original_size, input_size, text_refiner, visual_chatgpt,
212
- out_state, click_index_state, input_mask_state, input_points_state, input_labels_state, evt: gr.SelectData):
213
- click_index = evt.index
214
-
215
- if point_prompt == 'Positive':
216
- coordinate = "[[{}, {}, 1]]".format(str(click_index[0]), str(click_index[1]))
217
- else:
218
- coordinate = "[[{}, {}, 0]]".format(str(click_index[0]), str(click_index[1]))
219
-
220
- prompt = get_click_prompt(coordinate, click_state, click_mode)
221
- input_points = prompt['input_point']
222
- input_labels = prompt['input_label']
223
-
224
- controls = {'length': length,
225
- 'sentiment': sentiment,
226
- 'factuality': factuality,
227
- 'language': language}
228
-
229
- model = build_caption_anything_with_models(
230
- args,
231
- api_key="",
232
- captioner=shared_captioner,
233
- sam_model=shared_sam_model,
234
- ocr_reader=shared_ocr_reader,
235
- text_refiner=text_refiner,
236
- session_id=iface.app_id
237
- )
238
-
239
- model.setup(image_embedding, original_size, input_size, is_image_set=True)
240
-
241
- enable_wiki = True if enable_wiki in ['True', 'TRUE', 'true', True, 'Yes', 'YES', 'yes'] else False
242
- out = model.inference(image_input, prompt, controls, disable_gpt=True, enable_wiki=enable_wiki, verbose=True, args={'clip_filter': False})[0]
243
-
244
- state = state + [("Image point: {}, Input label: {}".format(prompt["input_point"], prompt["input_label"]), None)]
245
- update_click_state(click_state, out['generated_captions']['raw_caption'], click_mode)
246
- text = out['generated_captions']['raw_caption']
247
- input_mask = np.array(out['mask'].convert('P'))
248
- image_input = mask_painter(np.array(image_input), input_mask)
249
-
250
- click_index_state = click_index
251
- input_mask_state = input_mask
252
- input_points_state = input_points
253
- input_labels_state = input_labels
254
- out_state = out
255
-
256
- if visual_chatgpt is not None:
257
- print('inference_click: add caption to chatGPT memory')
258
- new_crop_save_path = get_new_image_name('chat_image', func_name='crop')
259
- Image.open(out["crop_save_path"]).save(new_crop_save_path)
260
- point_prompt = f'You should primarly use tools on the selected regional image (description: {text}, path: {new_crop_save_path}), which is a part of the whole image (path: {visual_chatgpt.current_image}). If human mentioned some objects not in the selected region, you can use tools on the whole image.'
261
- visual_chatgpt.point_prompt = point_prompt
262
-
263
- generated_caption = text
264
- print(generated_caption)
265
-
266
- yield state, state, click_state, image_input, generated_caption, click_index_state, input_mask_state, input_points_state, input_labels_state, out_state
267
-
268
-
269
-
270
-
271
- def submit_caption(image_input, state, generated_caption, text_refiner, visual_chatgpt, enable_wiki, length, sentiment, factuality, language,
272
- out_state, click_index_state, input_mask_state, input_points_state, input_labels_state,
273
- input_text, input_language, input_audio, input_mic, use_mic, agree):
274
- print("state",state)
275
-
276
- click_index = click_index_state
277
- input_mask = input_mask_state
278
- input_points = input_points_state
279
- input_labels = input_labels_state
280
- out = out_state
281
- print("click",click_index)
282
-
283
- origin_image_input = image_input
284
-
285
- controls = {
286
- 'length': length,
287
- 'sentiment': sentiment,
288
- 'factuality': factuality,
289
- 'language': language
290
- }
291
-
292
- image_input = create_bubble_frame(np.array(image_input), generated_caption, click_index, input_mask,
293
- input_points=input_points, input_labels=input_labels)
294
-
295
- if generated_caption:
296
- state = state + [(None, f"RAW_Caption: {generated_caption}")]
297
-
298
-
299
- if not args.disable_gpt and text_refiner:
300
- refined_caption = text_refiner.inference(query=generated_caption, controls=controls, context=out['context_captions'], enable_wiki=enable_wiki)
301
- new_cap = refined_caption['caption']
302
- if refined_caption.get('wiki'):
303
- state = state + [(None, "Wiki: {}".format(refined_caption['wiki']))]
304
- state = state + [(None, f"GPT_Caption: {new_cap}")]
305
- print("new_cap",new_cap)
306
- refined_image_input = create_bubble_frame(np.array(origin_image_input), new_cap, click_index, input_mask,
307
- input_points=input_points, input_labels=input_labels)
308
- try:
309
- waveform_visual, audio_output = tts.predict(new_cap, input_language, input_audio, input_mic, use_mic, agree)
310
- return state, state, refined_image_input, click_index_state, input_mask_state, input_points_state, input_labels_state, out_state, waveform_visual, audio_output
311
- except Exception as e:
312
- state = state + [(None, f"Error during TTS prediction: {str(e)}")]
313
- print(f"Error during TTS prediction: {str(e)}")
314
- return state, state, refined_image_input, click_index_state, input_mask_state, input_points_state, input_labels_state, out_state, None, None
315
-
316
- else:
317
- try:
318
- waveform_visual, audio_output = tts.predict(generated_caption, input_language, input_audio, input_mic, use_mic, agree)
319
- return state, state, image_input, click_index_state, input_mask_state, input_points_state, input_labels_state, out_state, waveform_visual, audio_output
320
- except Exception as e:
321
- state = state + [(None, f"Error during TTS prediction: {str(e)}")]
322
- print(f"Error during TTS prediction: {str(e)}")
323
- return state, state, image_input, click_index_state, input_mask_state, input_points_state, input_labels_state, out_state, None, None
324
-
325
-
326
-
327
-
328
- def txt2speech(text):
329
- print("Initializing text-to-speech conversion...")
330
- # API_URL = "https://api-inference.huggingface.co/models/espnet/kan-bayashi_ljspeech_vits"
331
- # headers = {"Authorization": f"Bearer {os.environ['HUGGINGFACEHUB_API_TOKEN']}"}
332
- # payloads = {'inputs': text}
333
- # response = requests.post(API_URL, headers=headers, json=payloads)
334
- # with open('audio_story.mp3', 'wb') as file:
335
- # file.write(response.content)
336
- print("Text-to-speech conversion completed.")
337
-
338
-
339
-
340
- def get_sketch_prompt(mask: Image.Image):
341
- """
342
- Get the prompt for the sketcher.
343
- TODO: This is a temporary solution. We should cluster the sketch and get the bounding box of each cluster.
344
- """
345
-
346
- mask = np.asarray(mask)[..., 0]
347
-
348
- # Get the bounding box of the sketch
349
- y, x = np.where(mask != 0)
350
- x1, y1 = np.min(x), np.min(y)
351
- x2, y2 = np.max(x), np.max(y)
352
-
353
- prompt = {
354
- 'prompt_type': ['box'],
355
- 'input_boxes': [
356
- [x1, y1, x2, y2]
357
- ]
358
- }
359
-
360
- return prompt
361
-
362
-
363
- def inference_traject(sketcher_image, enable_wiki, language, sentiment, factuality, length, image_embedding, state,
364
- original_size, input_size, text_refiner):
365
- image_input, mask = sketcher_image['image'], sketcher_image['mask']
366
-
367
- prompt = get_sketch_prompt(mask)
368
- boxes = prompt['input_boxes']
369
-
370
- controls = {'length': length,
371
- 'sentiment': sentiment,
372
- 'factuality': factuality,
373
- 'language': language}
374
-
375
- model = build_caption_anything_with_models(
376
- args,
377
- api_key="",
378
- captioner=shared_captioner,
379
- sam_model=shared_sam_model,
380
- ocr_reader=shared_ocr_reader,
381
- text_refiner=text_refiner,
382
- session_id=iface.app_id
383
- )
384
-
385
- model.setup(image_embedding, original_size, input_size, is_image_set=True)
386
-
387
- enable_wiki = True if enable_wiki in ['True', 'TRUE', 'true', True, 'Yes', 'YES', 'yes'] else False
388
- out = model.inference(image_input, prompt, controls, disable_gpt=True, enable_wiki=enable_wiki)[0]
389
-
390
- # Update components and states
391
- state.append((f'Box: {boxes}', None))
392
- state.append((None, f'raw_caption: {out["generated_captions"]["raw_caption"]}'))
393
- text = out['generated_captions']['raw_caption']
394
- input_mask = np.array(out['mask'].convert('P'))
395
- image_input = mask_painter(np.array(image_input), input_mask)
396
-
397
- origin_image_input = image_input
398
-
399
- fake_click_index = (int((boxes[0][0] + boxes[0][2]) / 2), int((boxes[0][1] + boxes[0][3]) / 2))
400
- image_input = create_bubble_frame(image_input, "", fake_click_index, input_mask)
401
-
402
- yield state, state, image_input
403
-
404
- if not args.disable_gpt and model.text_refiner:
405
- refined_caption = model.text_refiner.inference(query=text, controls=controls, context=out['context_captions'],
406
- enable_wiki=enable_wiki)
407
-
408
- new_cap = refined_caption['caption']
409
- if refined_caption['wiki']:
410
- state = state + [(None, "Wiki: {}".format(refined_caption['wiki']))]
411
- state = state + [(None, f"caption: {new_cap}")]
412
- refined_image_input = create_bubble_frame(origin_image_input, new_cap, fake_click_index, input_mask)
413
-
414
- yield state, state, refined_image_input
415
-
416
- def clear_chat_memory(visual_chatgpt, keep_global=False):
417
- if visual_chatgpt is not None:
418
- visual_chatgpt.memory.clear()
419
- visual_chatgpt.point_prompt = ""
420
- if keep_global:
421
- visual_chatgpt.agent.memory.buffer = visual_chatgpt.global_prompt
422
- else:
423
- visual_chatgpt.current_image = None
424
- visual_chatgpt.global_prompt = ""
425
-
426
- def cap_everything(image_input, visual_chatgpt, text_refiner,input_language, input_audio, input_mic, use_mic, agree):
427
-
428
- model = build_caption_anything_with_models(
429
- args,
430
- api_key="",
431
- captioner=shared_captioner,
432
- sam_model=shared_sam_model,
433
- ocr_reader=shared_ocr_reader,
434
- text_refiner=text_refiner,
435
- session_id=iface.app_id
436
- )
437
- paragraph = model.inference_cap_everything(image_input, verbose=True)
438
- # state = state + [(None, f"Caption Everything: {paragraph}")]
439
- Human_prompt = f'\nThe description of the image with path {visual_chatgpt.current_image} is:\n{paragraph}\nThis information helps you to understand this image, but you should use tools to finish following tasks, rather than directly imagine from my description. If you understand, say \"Received\". \n'
440
- AI_prompt = "Received."
441
- visual_chatgpt.global_prompt = Human_prompt + 'AI: ' + AI_prompt
442
- visual_chatgpt.agent.memory.buffer = visual_chatgpt.agent.memory.buffer + visual_chatgpt.global_prompt
443
- waveform_visual, audio_output=tts.predict(paragraph, input_language, input_audio, input_mic, use_mic, agree)
444
- return paragraph,waveform_visual, audio_output
445
-
446
-
447
- def get_style():
448
- current_version = version.parse(gr.__version__)
449
- if current_version <= version.parse('3.24.1'):
450
- style = '''
451
- #image_sketcher{min-height:500px}
452
- #image_sketcher [data-testid="image"], #image_sketcher [data-testid="image"] > div{min-height: 500px}
453
- #image_upload{min-height:500px}
454
- #image_upload [data-testid="image"], #image_upload [data-testid="image"] > div{min-height: 500px}
455
- '''
456
- elif current_version <= version.parse('3.27'):
457
- style = '''
458
- #image_sketcher{min-height:500px}
459
- #image_upload{min-height:500px}
460
- '''
461
  else:
462
- style = None
463
-
464
- return style
465
-
466
-
467
- def create_ui():
468
- title = """<p><h1 align="center">EyeSee Anything in Art</h1></p>
469
- """
470
- description = """<p>Gradio demo for EyeSee Anything in Art, image to dense captioning generation with various language styles. To use it, simply upload your image, or click one of the examples to load them. """
471
-
472
- examples = [
473
- ["test_images/img36.webp"],
474
- ["test_images/MUS.png"],
475
- ["test_images/图片2.png"],
476
- ["test_images/img5.jpg"],
477
- ["test_images/img14.jpg"],
478
- ["test_images/qingming3.jpeg"],
479
-
480
- ]
481
-
482
- with gr.Blocks(
483
- css=get_style()
484
- ) as iface:
485
- state = gr.State([])
486
- out_state = gr.State(None)
487
- click_state = gr.State([[], [], []])
488
- origin_image = gr.State(None)
489
- image_embedding = gr.State(None)
490
- text_refiner = gr.State(None)
491
- visual_chatgpt = gr.State(None)
492
- original_size = gr.State(None)
493
- input_size = gr.State(None)
494
- generated_caption = gr.State("")
495
- aux_state = gr.State([])
496
- click_index_state = gr.State((0, 0))
497
- input_mask_state = gr.State(np.zeros((1, 1)))
498
- input_points_state = gr.State([])
499
- input_labels_state = gr.State([])
500
-
501
-
502
- gr.Markdown(title)
503
- gr.Markdown(description)
504
-
505
- with gr.Row():
506
- with gr.Column(scale=1.0):
507
- with gr.Column(visible=False) as modules_not_need_gpt:
508
- with gr.Tab("Base(GPT Power)",visible=False) as base_tab:
509
- image_input_base = gr.Image(type="pil", interactive=True, elem_id="image_upload")
510
- example_image = gr.Image(type="pil", interactive=False, visible=False)
511
-
512
-
513
- with gr.Tab("Click") as click_tab:
514
- modules_not_need_gpt2=True
515
- image_input = gr.Image(type="pil", interactive=True, elem_id="image_upload")
516
- example_image = gr.Image(type="pil", interactive=False, visible=False)
517
- with gr.Row(scale=1.0):
518
- with gr.Row(scale=0.4):
519
- point_prompt = gr.Radio(
520
- choices=["Positive", "Negative"],
521
- value="Positive",
522
- label="Point Prompt",
523
- interactive=True)
524
- click_mode = gr.Radio(
525
- choices=["Continuous", "Single"],
526
- value="Continuous",
527
- label="Clicking Mode",
528
- interactive=True)
529
- with gr.Row(scale=0.4):
530
- clear_button_click = gr.Button(value="Clear Clicks", interactive=True)
531
- clear_button_image = gr.Button(value="Clear Image", interactive=True)
532
- submit_button_click=gr.Button(value="Submit", interactive=True)
533
- with gr.Tab("Trajectory (beta)"):
534
- sketcher_input = ImageSketcher(type="pil", interactive=True, brush_radius=20,
535
- elem_id="image_sketcher")
536
- with gr.Row():
537
- submit_button_sketcher = gr.Button(value="Submit", interactive=True)
538
-
539
- with gr.Column(visible=False) as modules_need_gpt1:
540
- with gr.Row(scale=1.0):
541
- language = gr.Dropdown(
542
- ['English', 'Chinese', 'French', "Spanish", "Arabic", "Portuguese", "Cantonese"],
543
- value="English", label="Language", interactive=True)
544
- sentiment = gr.Radio(
545
- choices=["Positive", "Natural", "Negative"],
546
- value="Natural",
547
- label="Sentiment",
548
- interactive=True,
549
- )
550
- with gr.Row(scale=1.0):
551
- factuality = gr.Radio(
552
- choices=["Factual", "Imagination"],
553
- value="Factual",
554
- label="Factuality",
555
- interactive=True,
556
- )
557
- length = gr.Slider(
558
- minimum=10,
559
- maximum=80,
560
- value=10,
561
- step=1,
562
- interactive=True,
563
- label="Generated Caption Length",
564
- )
565
- # 是否启用wiki内容整合到caption中
566
- enable_wiki = gr.Radio(
567
- choices=["Yes", "No"],
568
- value="No",
569
- label="Enable Wiki",
570
- interactive=True)
571
- # with gr.Column(visible=True) as modules_not_need_gpt3:
572
- gr.Examples(
573
- examples=examples,
574
- inputs=[example_image],
575
- )
576
-
577
- with gr.Column(scale=0.5):
578
- with gr.Column(visible=True) as module_key_input:
579
- openai_api_key = gr.Textbox(
580
- placeholder="Input openAI API key",
581
- show_label=False,
582
- label="OpenAI API Key",
583
- lines=1,
584
- type="password")
585
- with gr.Row(scale=0.5):
586
- enable_chatGPT_button = gr.Button(value="Run with ChatGPT", interactive=True, variant='primary')
587
- disable_chatGPT_button = gr.Button(value="Run without ChatGPT (Faster)", interactive=True,
588
- variant='primary')
589
- with gr.Column(visible=False) as module_notification_box:
590
- notification_box = gr.Textbox(lines=1, label="Notification", max_lines=5, show_label=False)
591
-
592
- with gr.Column():
593
- with gr.Column(visible=False) as modules_need_gpt2:
594
- paragraph_output = gr.Textbox(lines=7, label="Describe Everything", max_lines=7)
595
- with gr.Column(visible=False) as modules_need_gpt0:
596
- cap_everything_button = gr.Button(value="Caption Everything in a Paragraph", interactive=True)
597
-
598
- with gr.Column(visible=False) as modules_not_need_gpt2:
599
- chatbot = gr.Chatbot(label="Chatbox", ).style(height=550, scale=0.5)
600
- with gr.Column(visible=False) as modules_need_gpt3:
601
- chat_input = gr.Textbox(show_label=False, placeholder="Enter text and press Enter").style(
602
- container=False)
603
- with gr.Row():
604
- clear_button_text = gr.Button(value="Clear Text", interactive=True)
605
- submit_button_text = gr.Button(value="Submit", interactive=True, variant="primary")
606
-
607
- with gr.Column(scale=0.5):
608
- # TTS interface hidden initially
609
- with gr.Column(visible=False) as tts_interface:
610
- input_text = gr.Textbox(label="Text Prompt", value="Hello, World !, here is an example of light voice cloning. Try to upload your best audio samples quality")
611
- input_language = gr.Dropdown(label="Language", choices=["en", "es", "fr", "de", "it", "pt", "pl", "tr", "ru", "nl", "cs", "ar", "zh-cn"], value="en")
612
- input_audio = gr.Audio(label="Reference Audio", type="filepath", value="examples/female.wav")
613
- input_mic = gr.Audio(source="microphone", type="filepath", label="Use Microphone for Reference")
614
- use_mic = gr.Checkbox(label="Check to use Microphone as Reference", value=False)
615
- agree = gr.Checkbox(label="Agree", value=True)
616
- output_waveform = gr.Video(label="Waveform Visual")
617
- output_audio = gr.HTML(label="Synthesised Audio")
618
-
619
- with gr.Row():
620
- submit_tts = gr.Button(value="Submit", interactive=True)
621
- clear_tts = gr.Button(value="Clear", interactive=True)
622
-
623
-
624
- def clear_tts_fields():
625
- return [gr.update(value=""), gr.update(value=""), None, None, gr.update(value=False), gr.update(value=True), None, None]
626
-
627
- submit_tts.click(
628
- tts.predict,
629
- inputs=[input_text, input_language, input_audio, input_mic, use_mic, agree],
630
- outputs=[output_waveform, output_audio],
631
- queue=True
632
- )
633
-
634
- clear_tts.click(
635
- clear_tts_fields,
636
- inputs=None,
637
- outputs=[input_text, input_language, input_audio, input_mic, use_mic, agree, output_waveform, output_audio],
638
- queue=False
639
- )
640
-
641
-
642
- openai_api_key.submit(init_openai_api_key, inputs=[openai_api_key],
643
- outputs=[modules_need_gpt0, modules_need_gpt1, modules_need_gpt2, modules_need_gpt3, modules_not_need_gpt,
644
- modules_not_need_gpt2, tts_interface,module_key_input ,module_notification_box, text_refiner, visual_chatgpt, notification_box])
645
- enable_chatGPT_button.click(init_openai_api_key, inputs=[openai_api_key],
646
- outputs=[modules_need_gpt0, modules_need_gpt1, modules_need_gpt2, modules_need_gpt3,
647
- modules_not_need_gpt,
648
- modules_not_need_gpt2, tts_interface,module_key_input,module_notification_box, text_refiner, visual_chatgpt, notification_box])
649
- disable_chatGPT_button.click(init_wo_openai_api_key,
650
- outputs=[modules_need_gpt0, modules_need_gpt1, modules_need_gpt2, modules_need_gpt3,
651
- modules_not_need_gpt,
652
- modules_not_need_gpt2, tts_interface,module_key_input, module_notification_box, text_refiner, visual_chatgpt, notification_box])
653
-
654
- enable_chatGPT_button.click(
655
- lambda: (None, [], [], [[], [], []], "", "", ""),
656
- [],
657
- [image_input, chatbot, state, click_state, paragraph_output, origin_image],
658
- queue=False,
659
- show_progress=False
660
- )
661
- openai_api_key.submit(
662
- lambda: (None, [], [], [[], [], []], "", "", ""),
663
- [],
664
- [image_input, chatbot, state, click_state, paragraph_output, origin_image],
665
- queue=False,
666
- show_progress=False
667
- )
668
-
669
- cap_everything_button.click(cap_everything, [origin_image, visual_chatgpt, text_refiner,input_language, input_audio, input_mic, use_mic, agree],
670
- [paragraph_output,output_waveform, output_audio])
671
-
672
- clear_button_click.click(
673
- lambda x: ([[], [], []], x),
674
- [origin_image],
675
- [click_state, image_input],
676
- queue=False,
677
- show_progress=False
678
- )
679
- clear_button_click.click(functools.partial(clear_chat_memory, keep_global=True), inputs=[visual_chatgpt])
680
- clear_button_image.click(
681
- lambda: (None, [], [], [[], [], []], "", "", ""),
682
- [],
683
- [image_input, chatbot, state, click_state, paragraph_output, origin_image],
684
- queue=False,
685
- show_progress=False
686
- )
687
- clear_button_image.click(clear_chat_memory, inputs=[visual_chatgpt])
688
- clear_button_text.click(
689
- lambda: ([], [], [[], [], [], []]),
690
- [],
691
- [chatbot, state, click_state],
692
- queue=False,
693
- show_progress=False
694
- )
695
- clear_button_text.click(clear_chat_memory, inputs=[visual_chatgpt])
696
-
697
- image_input.clear(
698
- lambda: (None, [], [], [[], [], []], "", "", ""),
699
- [],
700
- [image_input, chatbot, state, click_state, paragraph_output, origin_image],
701
- queue=False,
702
- show_progress=False
703
- )
704
-
705
- image_input.clear(clear_chat_memory, inputs=[visual_chatgpt])
706
-
707
- image_input_base.upload(upload_callback, [image_input_base, state, visual_chatgpt],
708
- [chatbot, state, origin_image, click_state, image_input_base, sketcher_input,
709
- image_embedding, original_size, input_size])
710
-
711
-
712
- image_input.upload(upload_callback, [image_input, state, visual_chatgpt],
713
- [chatbot, state, origin_image, click_state, image_input, sketcher_input,
714
- image_embedding, original_size, input_size])
715
- sketcher_input.upload(upload_callback, [sketcher_input, state, visual_chatgpt],
716
- [chatbot, state, origin_image, click_state, image_input, sketcher_input,
717
- image_embedding, original_size, input_size])
718
- chat_input.submit(chat_input_callback, [visual_chatgpt, chat_input, click_state, state, aux_state],
719
- [chatbot, state, aux_state])
720
- chat_input.submit(lambda: "", None, chat_input)
721
- submit_button_text.click(chat_input_callback, [visual_chatgpt, chat_input, click_state, state, aux_state],
722
- [chatbot, state, aux_state])
723
- submit_button_text.click(lambda: "", None, chat_input)
724
- example_image.change(upload_callback, [example_image, state, visual_chatgpt],
725
- [chatbot, state, origin_image, click_state, image_input, sketcher_input,
726
- image_embedding, original_size, input_size])
727
- example_image.change(upload_callback, [example_image, state, visual_chatgpt],
728
- [chatbot, state, origin_image, click_state, image_input_base, sketcher_input,
729
- image_embedding, original_size, input_size])
730
- example_image.change(clear_chat_memory, inputs=[visual_chatgpt])
731
-
732
- def on_click_tab_selected():
733
- if gpt_state ==1:
734
- print(gpt_state)
735
- print("using gpt")
736
- return [gr.update(visible=True)]*2+[gr.update(visible=False)]*2
737
- else:
738
- print("no gpt")
739
- print("gpt_state",gpt_state)
740
- return [gr.update(visible=False)]+[gr.update(visible=True)]+[gr.update(visible=False)]*2
741
-
742
- def on_base_selected():
743
- if gpt_state ==1:
744
- print(gpt_state)
745
- print("using gpt")
746
- return [gr.update(visible=True)]*2+[gr.update(visible=False)]*2
747
- else:
748
- print("no gpt")
749
- return [gr.update(visible=False)]*4
750
-
751
-
752
- click_tab.select(on_click_tab_selected, outputs=[modules_need_gpt1,modules_not_need_gpt2,modules_need_gpt0,modules_need_gpt2])
753
- base_tab.select(on_base_selected, outputs=[modules_need_gpt0,modules_need_gpt2,modules_not_need_gpt2,modules_need_gpt1])
754
-
755
-
756
-
757
-
758
- image_input.select(
759
- inference_click,
760
- inputs=[
761
- origin_image, point_prompt, click_mode, enable_wiki, language, sentiment, factuality, length,
762
- image_embedding, state, click_state, original_size, input_size, text_refiner, visual_chatgpt,
763
- out_state, click_index_state, input_mask_state, input_points_state, input_labels_state
764
- ],
765
- outputs=[chatbot, state, click_state, image_input, generated_caption, click_index_state, input_mask_state, input_points_state, input_labels_state, out_state],
766
- show_progress=False, queue=True
767
- )
768
-
769
-
770
- submit_button_click.click(
771
- submit_caption,
772
- inputs=[
773
- image_input, state, generated_caption, text_refiner, visual_chatgpt, enable_wiki, length, sentiment, factuality, language,
774
- out_state, click_index_state, input_mask_state, input_points_state, input_labels_state,
775
- input_text, input_language, input_audio, input_mic, use_mic, agree
776
- ],
777
- outputs=[
778
- chatbot, state, image_input, click_index_state, input_mask_state, input_points_state, input_labels_state, out_state,
779
- output_waveform, output_audio
780
- ],
781
- show_progress=True,
782
- queue=True
783
- )
784
-
785
-
786
-
787
- submit_button_sketcher.click(
788
- inference_traject,
789
- inputs=[
790
- sketcher_input, enable_wiki, language, sentiment, factuality, length, image_embedding, state,
791
- original_size, input_size, text_refiner
792
- ],
793
- outputs=[chatbot, state, sketcher_input],
794
- show_progress=False, queue=True
795
- )
796
-
797
- def update_output_audio():
798
- return gr.update(autoplay=True)
799
-
800
- output_audio.change(update_output_audio,outputs=[output_audio])
801
-
802
-
803
-
804
-
805
- return iface
806
-
807
-
808
- if __name__ == '__main__':
809
- iface = create_ui()
810
- iface.queue(concurrency_count=5, api_open=False, max_size=10)
811
- iface.launch(server_name="0.0.0.0", enable_queue=True)
 
1
  import os
2
+ import sys
3
+ from fastapi import Request
4
  import gradio as gr
5
+ from TTS.api import TTS
6
+ from TTS.utils.manage import ModelManager
7
+ from io import BytesIO
8
+ import base64
9
+
10
+ model_names = TTS().list_models()
11
+ print(model_names.__dict__)
12
+ print(model_names.__dir__())
13
+
14
+ os.environ["COQUI_TOS_AGREED"] = "1"
15
+
16
+ model_name = "tts_models/multilingual/multi-dataset/xtts_v2"
17
+ tts = TTS(model_name, gpu=False)
18
+ tts.to("cuda")
19
+
20
+ def predict(prompt, language, audio_file_pth, mic_file_path, use_mic, agree):
21
+ if agree:
22
+ speaker_wav = mic_file_path if use_mic and mic_file_path else audio_file_pth
23
+
24
+ if not speaker_wav:
25
+ return None, "Please provide a reference audio."
26
+
27
+ if len(prompt) < 2:
28
+ return None, "Please provide a longer text prompt."
29
+
30
+ if len(prompt) > 10000:
31
+ return None, "Text length is limited to 10000 characters. Please try a shorter text."
32
+
33
+ try:
34
+ if language == "fr" and "your" in model_name:
35
+ language = "fr-fr"
36
+ if "/fr/" in model_name:
37
+ language = None
38
+
39
+ tts.tts_to_file(
40
+ text=prompt,
41
+ file_path="output.wav",
42
+ speaker_wav=speaker_wav,
43
+ language=language
44
+ )
45
+ except RuntimeError as e:
46
+ if "device-assert" in str(e):
47
+ return None, "Runtime error encountered. Please try again later."
48
+ else:
49
+ raise e
50
+
51
+ with open("output.wav", "rb") as audio_file:
52
+ audio_bytes = BytesIO(audio_file.read())
53
+ audio = base64.b64encode(audio_bytes.read()).decode("utf-8")
54
+ audio_player = f'<audio src="data:audio/wav;base64,{audio}" controls autoplay></audio>'
55
+ return gr.make_waveform(audio="output.wav"),audio_player
 
56
  else:
57
+ return None, "Please accept the Terms & Conditions."
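
For context, a minimal sketch of how this module's updated predict() entry point could be driven from a caller such as the Gradio app code removed above. The caption text and the examples/female.wav reference path are assumptions based on that removed app code, not part of this commit:

import tts  # this module, after the update above

# Hypothetical call mirroring how submit_caption()/cap_everything() used tts.predict()
# in the old app code: synthesize a caption with the bundled reference voice.
waveform_video, audio_html = tts.predict(
    prompt="A small dog runs across a sunlit meadow.",
    language="en",
    audio_file_pth="examples/female.wav",  # reference audio assumed to ship with the demo
    mic_file_path=None,
    use_mic=False,
    agree=True,  # with agree=False the function only returns an error message
)
# waveform_video feeds a gr.Video component; audio_html is an HTML <audio> snippet for gr.HTML.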