Ffftdtd5dtft committed verified commit a5d12c2 · 1 Parent(s): 5196a69

Create app.py

Files changed (1)
  1. app.py +616 -0
app.py ADDED
@@ -0,0 +1,616 @@
+ import os
+ import pickle
+ import torch
+ import torchaudio
+ import numpy as np
+ from PIL import Image
+ from diffusers import (
+     StableDiffusionPipeline,
+     StableDiffusionImg2ImgPipeline,
+     FluxPipeline,
+     DiffusionPipeline,
+ )
+ from transformers import (
+     pipeline as transformers_pipeline,
+     AutoModel,
+     AutoModelForCausalLM,
+     AutoTokenizer,
+     GPT2Tokenizer,
+     GPT2LMHeadModel,
+ )
+ from audiocraft.models import musicgen
+ import gradio as gr
+ from huggingface_hub import HfFolder
+ import io
+ from tqdm import tqdm
+ from google.cloud import storage
+ import json
+
+ # Configuration comes from the environment: HF_TOKEN, GCS_CREDENTIALS
+ # (a service-account JSON string), and GCS_BUCKET_NAME.
+ hf_token = os.getenv("HF_TOKEN")
+ gcs_credentials = json.loads(os.getenv("GCS_CREDENTIALS"))
+ gcs_bucket_name = os.getenv("GCS_BUCKET_NAME")
+
+ HfFolder.save_token(hf_token)
+
+ storage_client = storage.Client.from_service_account_info(gcs_credentials)
+ bucket = storage_client.bucket(gcs_bucket_name)
+
+
+ def load_object_from_gcs(blob_name):
+     blob = bucket.blob(blob_name)
+     if blob.exists():
+         return pickle.loads(blob.download_as_bytes())
+     return None
+
+
+ def save_object_to_gcs(blob_name, obj):
+     blob = bucket.blob(blob_name)
+     blob.upload_from_string(pickle.dumps(obj))
+
+
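+ # A minimal usage sketch of the two cache helpers above (the blob name is
+ # hypothetical, not one the app uses): anything picklable round-trips, and
+ # missing blobs come back as None.
+ #
+ #     save_object_to_gcs("examples/greeting", {"text": "hello"})
+ #     load_object_from_gcs("examples/greeting")  # -> {"text": "hello"}
+
+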
+ def get_model_or_download(model_id, blob_name, loader_func):
+     # Return a cached model from GCS if present; otherwise download it,
+     # cache the pickled object, and return it.
+     model = load_object_from_gcs(blob_name)
+     if model:
+         return model
+     try:
+         with tqdm(total=1, desc=f"Downloading {model_id}") as pbar:
+             model = loader_func(model_id, torch_dtype=torch.float16)
+             pbar.update(1)
+         save_object_to_gcs(blob_name, model)
+         return model
+     except Exception as e:
+         print(f"Failed to load or save model: {e}")
+         return None
+
+
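+ # Note: caching whole pipelines through pickle assumes the loading process
+ # runs the same torch/diffusers/transformers versions that saved them;
+ # a version skew can make a cached blob unreadable.
+
+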
+ def generate_image(prompt):
+     blob_name = f"diffusers/generated_image:{prompt}"
+     image_bytes = load_object_from_gcs(blob_name)
+     if not image_bytes:
+         try:
+             with tqdm(total=1, desc="Generating image") as pbar:
+                 image = text_to_image_pipeline(prompt).images[0]
+                 pbar.update(1)
+             buffered = io.BytesIO()
+             image.save(buffered, format="JPEG")
+             image_bytes = buffered.getvalue()
+             save_object_to_gcs(blob_name, image_bytes)
+         except Exception as e:
+             print(f"Failed to generate image: {e}")
+             return None
+     # The Gradio output is gr.Image(type="pil"), so decode the cached bytes.
+     return Image.open(io.BytesIO(image_bytes))
+
+
+ def edit_image_with_prompt(image, prompt, strength=0.75):
+     # `image` arrives as a PIL image from gr.Image(type="pil").
+     blob_name = f"diffusers/edited_image:{prompt}:{strength}"
+     edited_image_bytes = load_object_from_gcs(blob_name)
+     if not edited_image_bytes:
+         try:
+             with tqdm(total=1, desc="Editing image") as pbar:
+                 edited_image = img2img_pipeline(
+                     prompt=prompt, image=image, strength=strength
+                 ).images[0]
+                 pbar.update(1)
+             buffered = io.BytesIO()
+             edited_image.save(buffered, format="JPEG")
+             edited_image_bytes = buffered.getvalue()
+             save_object_to_gcs(blob_name, edited_image_bytes)
+         except Exception as e:
+             print(f"Failed to edit image: {e}")
+             return None
+     return Image.open(io.BytesIO(edited_image_bytes))
+
+
+ def generate_song(prompt, duration=10):
+     blob_name = f"music/generated_song:{prompt}:{duration}"
+     song = load_object_from_gcs(blob_name)
+     if not song:
+         try:
+             with tqdm(total=1, desc="Generating song") as pbar:
+                 music_gen.set_generation_params(duration=duration)
+                 wav = music_gen.generate([prompt])
+                 pbar.update(1)
+             # gr.Audio(type="numpy") expects a (sample_rate, samples) tuple.
+             song = (music_gen.sample_rate, wav[0].cpu().numpy().T)
+             save_object_to_gcs(blob_name, song)
+         except Exception as e:
+             print(f"Failed to generate song: {e}")
+             return None
+     return song
+
+
+ def generate_text(prompt):
+     blob_name = f"transformers/generated_text:{prompt}"
+     text = load_object_from_gcs(blob_name)
+     if not text:
+         try:
+             with tqdm(total=1, desc="Generating text") as pbar:
+                 text = text_gen_pipeline(prompt, max_new_tokens=256)[0][
+                     "generated_text"
+                 ].strip()
+                 pbar.update(1)
+             save_object_to_gcs(blob_name, text)
+         except Exception as e:
+             print(f"Failed to generate text: {e}")
+             return None
+     return text
+
+
+ def generate_flux_image(prompt):
+     blob_name = f"diffusers/generated_flux_image:{prompt}"
+     flux_image_bytes = load_object_from_gcs(blob_name)
+     if not flux_image_bytes:
+         try:
+             with tqdm(total=1, desc="Generating FLUX image") as pbar:
+                 flux_image = flux_pipeline(
+                     prompt,
+                     guidance_scale=0.0,
+                     num_inference_steps=4,
+                     max_sequence_length=256,
+                     generator=torch.Generator("cpu").manual_seed(0),
+                 ).images[0]
+                 pbar.update(1)
+             buffered = io.BytesIO()
+             flux_image.save(buffered, format="JPEG")
+             flux_image_bytes = buffered.getvalue()
+             save_object_to_gcs(blob_name, flux_image_bytes)
+         except Exception as e:
+             print(f"Failed to generate FLUX image: {e}")
+             return None
+     return Image.open(io.BytesIO(flux_image_bytes))
+
+
+ def generate_code(prompt):
+     blob_name = f"transformers/generated_code:{prompt}"
+     code = load_object_from_gcs(blob_name)
+     if not code:
+         try:
+             with tqdm(total=1, desc="Generating code") as pbar:
+                 inputs = starcoder_tokenizer.encode(prompt, return_tensors="pt").to(
+                     starcoder_model.device
+                 )
+                 outputs = starcoder_model.generate(inputs, max_new_tokens=256)
+                 code = starcoder_tokenizer.decode(outputs[0])
+                 pbar.update(1)
+             save_object_to_gcs(blob_name, code)
+         except Exception as e:
+             print(f"Failed to generate code: {e}")
+             return None
+     return code
+
+
+ def test_model_meta_llama():
+     blob_name = "transformers/meta_llama_test_response"
+     response = load_object_from_gcs(blob_name)
+     if not response:
+         try:
+             messages = [
+                 {
+                     "role": "system",
+                     "content": "You are a pirate chatbot who always responds in pirate speak!",
+                 },
+                 {"role": "user", "content": "Who are you?"},
+             ]
+             with tqdm(total=1, desc="Testing Meta-Llama") as pbar:
+                 # Chat pipelines return the whole conversation; keep only the
+                 # content of the final (assistant) message.
+                 response = meta_llama_pipeline(messages, max_new_tokens=256)[0][
+                     "generated_text"
+                 ][-1]["content"].strip()
+                 pbar.update(1)
+             save_object_to_gcs(blob_name, response)
+         except Exception as e:
+             print(f"Failed to test Meta-Llama: {e}")
+             return None
+     return response
+
+
+ def generate_image_sdxl(prompt):
+     blob_name = f"diffusers/generated_image_sdxl:{prompt}"
+     image_bytes = load_object_from_gcs(blob_name)
+     if not image_bytes:
+         try:
+             with tqdm(total=1, desc="Generating SDXL image") as pbar:
+                 # The base model produces latents; the refiner finishes the
+                 # final 20% of the denoising schedule.
+                 image = base(
+                     prompt=prompt,
+                     num_inference_steps=40,
+                     denoising_end=0.8,
+                     output_type="latent",
+                 ).images
+                 image = refiner(
+                     prompt=prompt,
+                     num_inference_steps=40,
+                     denoising_start=0.8,
+                     image=image,
+                 ).images[0]
+                 pbar.update(1)
+             buffered = io.BytesIO()
+             image.save(buffered, format="JPEG")
+             image_bytes = buffered.getvalue()
+             save_object_to_gcs(blob_name, image_bytes)
+         except Exception as e:
+             print(f"Failed to generate SDXL image: {e}")
+             return None
+     return Image.open(io.BytesIO(image_bytes))
+
+
+ def generate_musicgen_melody(prompt):
+     blob_name = f"music/generated_musicgen_melody:{prompt}"
+     song = load_object_from_gcs(blob_name)
+     if not song:
+         try:
+             with tqdm(total=1, desc="Generating MusicGen melody") as pbar:
+                 melody, sr = torchaudio.load("./assets/bach.mp3")
+                 # One prompt, so condition on a batch of one melody.
+                 wav = music_gen_melody.generate_with_chroma(
+                     [prompt], melody[None], sr
+                 )
+                 pbar.update(1)
+             song = (music_gen_melody.sample_rate, wav[0].cpu().numpy().T)
+             save_object_to_gcs(blob_name, song)
+         except Exception as e:
+             print(f"Failed to generate MusicGen melody: {e}")
+             return None
+     return song
+
+
+ def generate_musicgen_large(prompt):
+     blob_name = f"music/generated_musicgen_large:{prompt}"
+     song = load_object_from_gcs(blob_name)
+     if not song:
+         try:
+             with tqdm(total=1, desc="Generating MusicGen large") as pbar:
+                 wav = music_gen_large.generate([prompt])
+                 pbar.update(1)
+             song = (music_gen_large.sample_rate, wav[0].cpu().numpy().T)
+             save_object_to_gcs(blob_name, song)
+         except Exception as e:
+             print(f"Failed to generate MusicGen large: {e}")
+             return None
+     return song
+
+
+ def transcribe_audio(audio_sample):
+     # gr.Audio(type="numpy") passes a (sample_rate, samples) tuple.
+     sampling_rate, data = audio_sample
+     blob_name = f"transformers/transcribed_audio:{hash(data.tobytes())}"
+     text = load_object_from_gcs(blob_name)
+     if not text:
+         try:
+             with tqdm(total=1, desc="Transcribing audio") as pbar:
+                 # Whisper expects mono float32; Gradio delivers int16.
+                 raw = data.astype(np.float32) / 32768.0
+                 if raw.ndim > 1:
+                     raw = raw.mean(axis=1)
+                 text = whisper_pipeline(
+                     {"sampling_rate": sampling_rate, "raw": raw}, batch_size=8
+                 )["text"]
+                 pbar.update(1)
+             save_object_to_gcs(blob_name, text)
+         except Exception as e:
+             print(f"Failed to transcribe audio: {e}")
+             return None
+     return text
+
+
+ def generate_mistral_instruct(prompt):
+     blob_name = f"transformers/generated_mistral_instruct:{prompt}"
+     response = load_object_from_gcs(blob_name)
+     if not response:
+         try:
+             conversation = [{"role": "user", "content": prompt}]
+             with tqdm(total=1, desc="Generating Mistral Instruct response") as pbar:
+                 inputs = mistral_instruct_tokenizer.apply_chat_template(
+                     conversation,
+                     tools=tools,
+                     add_generation_prompt=True,
+                     return_dict=True,
+                     return_tensors="pt",
+                 )
+                 inputs = inputs.to(mistral_instruct_model.device)
+                 outputs = mistral_instruct_model.generate(
+                     **inputs, max_new_tokens=1000
+                 )
+                 response = mistral_instruct_tokenizer.decode(
+                     outputs[0], skip_special_tokens=True
+                 )
+                 pbar.update(1)
+             save_object_to_gcs(blob_name, response)
+         except Exception as e:
+             print(f"Failed to generate Mistral Instruct response: {e}")
+             return None
+     return response
+
+
+ def generate_mistral_nemo(prompt):
+     blob_name = f"transformers/generated_mistral_nemo:{prompt}"
+     response = load_object_from_gcs(blob_name)
+     if not response:
+         try:
+             conversation = [{"role": "user", "content": prompt}]
+             with tqdm(total=1, desc="Generating Mistral Nemo response") as pbar:
+                 inputs = mistral_nemo_tokenizer.apply_chat_template(
+                     conversation,
+                     tools=tools,
+                     add_generation_prompt=True,
+                     return_dict=True,
+                     return_tensors="pt",
+                 )
+                 inputs = inputs.to(mistral_nemo_model.device)
+                 outputs = mistral_nemo_model.generate(**inputs, max_new_tokens=1000)
+                 response = mistral_nemo_tokenizer.decode(
+                     outputs[0], skip_special_tokens=True
+                 )
+                 pbar.update(1)
+             save_object_to_gcs(blob_name, response)
+         except Exception as e:
+             print(f"Failed to generate Mistral Nemo response: {e}")
+             return None
+     return response
+
+
+ def generate_gpt2_xl(prompt):
+     blob_name = f"transformers/generated_gpt2_xl:{prompt}"
+     response = load_object_from_gcs(blob_name)
+     if not response:
+         try:
+             with tqdm(total=1, desc="Generating GPT-2 XL response") as pbar:
+                 inputs = gpt2_xl_tokenizer(prompt, return_tensors="pt")
+                 # Generation needs the LM head, so gpt2_xl_model is a
+                 # GPT2LMHeadModel (see the model-loading section below).
+                 outputs = gpt2_xl_model.generate(**inputs, max_new_tokens=256)
+                 response = gpt2_xl_tokenizer.decode(
+                     outputs[0], skip_special_tokens=True
+                 )
+                 pbar.update(1)
+             save_object_to_gcs(blob_name, response)
+         except Exception as e:
+             print(f"Failed to generate GPT-2 XL response: {e}")
+             return None
+     return response
+
+
+ def answer_question_minicpm(image, question):
+     # `image` arrives as a PIL image from gr.Image(type="pil").
+     image = image.convert("RGB")
+     blob_name = f"transformers/minicpm_answer:{hash(image.tobytes())}:{question}"
+     answer = load_object_from_gcs(blob_name)
+     if not answer:
+         try:
+             with tqdm(total=1, desc="Answering question with MiniCPM") as pbar:
+                 msgs = [{"role": "user", "content": [image, question]}]
+                 answer = minicpm_model.chat(
+                     image=None, msgs=msgs, tokenizer=minicpm_tokenizer
+                 )
+                 pbar.update(1)
+             save_object_to_gcs(blob_name, answer)
+         except Exception as e:
+             print(f"Failed to answer question with MiniCPM: {e}")
+             return None
+     return answer
+
+
+ def store_user_question(question):
+     # Append the question to a shared log file in the bucket.
+     blob_name = "user_questions.txt"
+     blob = bucket.blob(blob_name)
+     if blob.exists():
+         blob.download_to_filename("user_questions.txt")
+     with open("user_questions.txt", "a") as f:
+         f.write(question + "\n")
+     blob.upload_from_filename("user_questions.txt")
+
+
+ def retrain_models():
+     # Placeholder: retraining is not implemented yet.
+     pass
+
+
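+ # All models below are loaded eagerly at startup, so the first run is slow
+ # and needs substantial GPU memory.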
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ text_to_image_pipeline = get_model_or_download(
+     "stabilityai/stable-diffusion-2",
+     "diffusers/text_to_image_model",
+     StableDiffusionPipeline.from_pretrained,
+ )
+ img2img_pipeline = get_model_or_download(
+     "CompVis/stable-diffusion-v1-4",
+     "diffusers/img2img_model",
+     StableDiffusionImg2ImgPipeline.from_pretrained,
+ )
+ flux_pipeline = get_model_or_download(
+     "black-forest-labs/FLUX.1-schnell",
+     "diffusers/flux_model",
+     FluxPipeline.from_pretrained,
+ )
+ text_gen_pipeline = transformers_pipeline(
+     "text-generation", model="google/gemma-2-9b", tokenizer="google/gemma-2-9b"
+ )
+ music_gen = (
+     load_object_from_gcs("music/music_gen")
+     or musicgen.MusicGen.get_pretrained("melody")
+ )
+ meta_llama_pipeline = get_model_or_download(
+     "meta-llama/Meta-Llama-3.1-8B-Instruct",
+     "transformers/meta_llama_model",
+     # pipeline() takes a task as its first argument, so adapt it to the
+     # loader_func(model_id, **kwargs) signature expected above.
+     lambda model_id, **kwargs: transformers_pipeline(
+         "text-generation", model=model_id, **kwargs
+     ),
+ )
+ starcoder_model = AutoModelForCausalLM.from_pretrained(
+     "bigcode/starcoder"
+ ).to(device)
+ starcoder_tokenizer = AutoTokenizer.from_pretrained("bigcode/starcoder")
+
+ base = DiffusionPipeline.from_pretrained(
+     "stabilityai/stable-diffusion-xl-base-1.0",
+     torch_dtype=torch.float16,
+     variant="fp16",
+     use_safetensors=True,
+ ).to(device)
+ refiner = DiffusionPipeline.from_pretrained(
+     "stabilityai/stable-diffusion-xl-refiner-1.0",
+     text_encoder_2=base.text_encoder_2,
+     vae=base.vae,
+     torch_dtype=torch.float16,
+     use_safetensors=True,
+     variant="fp16",
+ ).to(device)
+ music_gen_melody = musicgen.MusicGen.get_pretrained("melody")
+ music_gen_melody.set_generation_params(duration=8)
+ music_gen_large = musicgen.MusicGen.get_pretrained("large")
+ music_gen_large.set_generation_params(duration=8)
+ whisper_pipeline = transformers_pipeline(
+     "automatic-speech-recognition",
+     model="openai/whisper-small",
+     chunk_length_s=30,
+     device=device,
+ )
+ mistral_instruct_model = AutoModelForCausalLM.from_pretrained(
+     "mistralai/Mistral-Large-Instruct-2407",
+     torch_dtype=torch.bfloat16,
+     device_map="auto",
+ )
+ mistral_instruct_tokenizer = AutoTokenizer.from_pretrained(
+     "mistralai/Mistral-Large-Instruct-2407"
+ )
+ mistral_nemo_model = AutoModelForCausalLM.from_pretrained(
+     "mistralai/Mistral-Nemo-Instruct-2407",
+     torch_dtype=torch.bfloat16,
+     device_map="auto",
+ )
+ mistral_nemo_tokenizer = AutoTokenizer.from_pretrained(
+     "mistralai/Mistral-Nemo-Instruct-2407"
+ )
+ gpt2_xl_tokenizer = GPT2Tokenizer.from_pretrained("gpt2-xl")
+ # Use the LM-head variant so gpt2_xl_model.generate() is available.
+ gpt2_xl_model = GPT2LMHeadModel.from_pretrained("gpt2-xl")
+ minicpm_model = AutoModel.from_pretrained(
+     "openbmb/MiniCPM-V-2_6",
+     trust_remote_code=True,
+     attn_implementation="sdpa",
+     torch_dtype=torch.bfloat16,
+ ).eval().to(device)  # .to(device) keeps CPU-only environments working
+ minicpm_tokenizer = AutoTokenizer.from_pretrained(
+     "openbmb/MiniCPM-V-2_6", trust_remote_code=True
+ )
+
+ # No tool definitions yet; passed through to the chat templates.
+ tools = []
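+
+ # One gr.Interface per capability below; they are combined into a single
+ # TabbedInterface at the end of the file.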
+
+ gen_image_tab = gr.Interface(
+     fn=generate_image,
+     inputs=gr.Textbox(label="Prompt:"),
+     outputs=gr.Image(type="pil"),
+     title="Generate Image",
+ )
+ edit_image_tab = gr.Interface(
+     fn=edit_image_with_prompt,
+     inputs=[
+         gr.Image(type="pil", label="Image:"),
+         gr.Textbox(label="Prompt:"),
+         gr.Slider(0.1, 1.0, 0.75, step=0.05, label="Strength:"),
+     ],
+     outputs=gr.Image(type="pil"),
+     title="Edit Image",
+ )
+ generate_song_tab = gr.Interface(
+     fn=generate_song,
+     inputs=[
+         gr.Textbox(label="Prompt:"),
+         gr.Slider(5, 60, 10, step=1, label="Duration (s):"),
+     ],
+     outputs=gr.Audio(type="numpy"),
+     title="Generate Songs",
+ )
+ generate_text_tab = gr.Interface(
+     fn=generate_text,
+     inputs=gr.Textbox(label="Prompt:"),
+     outputs=gr.Textbox(label="Generated Text:"),
+     title="Generate Text",
+ )
+ generate_flux_image_tab = gr.Interface(
+     fn=generate_flux_image,
+     inputs=gr.Textbox(label="Prompt:"),
+     outputs=gr.Image(type="pil"),
+     title="Generate FLUX Images",
+ )
+ generate_code_tab = gr.Interface(
+     fn=generate_code,
+     inputs=gr.Textbox(label="Prompt:"),
+     outputs=gr.Textbox(label="Generated Code:"),
+     title="Generate Code",
+ )
+ model_meta_llama_test_tab = gr.Interface(
+     fn=test_model_meta_llama,
+     inputs=None,
+     outputs=gr.Textbox(label="Model Output:"),
+     title="Test Meta-Llama",
+ )
+ generate_image_sdxl_tab = gr.Interface(
+     fn=generate_image_sdxl,
+     inputs=gr.Textbox(label="Prompt:"),
+     outputs=gr.Image(type="pil"),
+     title="Generate SDXL Image",
+ )
+ generate_musicgen_melody_tab = gr.Interface(
+     fn=generate_musicgen_melody,
+     inputs=gr.Textbox(label="Prompt:"),
+     outputs=gr.Audio(type="numpy"),
+     title="Generate MusicGen Melody",
+ )
+ generate_musicgen_large_tab = gr.Interface(
+     fn=generate_musicgen_large,
+     inputs=gr.Textbox(label="Prompt:"),
+     outputs=gr.Audio(type="numpy"),
+     title="Generate MusicGen Large",
+ )
+ transcribe_audio_tab = gr.Interface(
+     fn=transcribe_audio,
+     inputs=gr.Audio(type="numpy", label="Audio Sample:"),
+     outputs=gr.Textbox(label="Transcribed Text:"),
+     title="Transcribe Audio",
+ )
+ generate_mistral_instruct_tab = gr.Interface(
+     fn=generate_mistral_instruct,
+     inputs=gr.Textbox(label="Prompt:"),
+     outputs=gr.Textbox(label="Mistral Instruct Response:"),
+     title="Generate Mistral Instruct Response",
+ )
+ generate_mistral_nemo_tab = gr.Interface(
+     fn=generate_mistral_nemo,
+     inputs=gr.Textbox(label="Prompt:"),
+     outputs=gr.Textbox(label="Mistral Nemo Response:"),
+     title="Generate Mistral Nemo Response",
+ )
+ generate_gpt2_xl_tab = gr.Interface(
+     fn=generate_gpt2_xl,
+     inputs=gr.Textbox(label="Prompt:"),
+     outputs=gr.Textbox(label="GPT-2 XL Response:"),
+     title="Generate GPT-2 XL Response",
+ )
+ answer_question_minicpm_tab = gr.Interface(
+     fn=answer_question_minicpm,
+     inputs=[
+         gr.Image(type="pil", label="Image:"),
+         gr.Textbox(label="Question:"),
+     ],
+     outputs=gr.Textbox(label="MiniCPM Answer:"),
+     title="Answer Question with MiniCPM",
+ )
+
+ app = gr.TabbedInterface(
+     [
+         gen_image_tab,
+         edit_image_tab,
+         generate_song_tab,
+         generate_text_tab,
+         generate_flux_image_tab,
+         generate_code_tab,
+         model_meta_llama_test_tab,
+         generate_image_sdxl_tab,
+         generate_musicgen_melody_tab,
+         generate_musicgen_large_tab,
+         transcribe_audio_tab,
+         generate_mistral_instruct_tab,
+         generate_mistral_nemo_tab,
+         generate_gpt2_xl_tab,
+         answer_question_minicpm_tab,
+     ],
+     [
+         "Generate Image",
+         "Edit Image",
+         "Generate Song",
+         "Generate Text",
+         "Generate FLUX Image",
+         "Generate Code",
+         "Test Meta-Llama",
+         "Generate SDXL Image",
+         "Generate MusicGen Melody",
+         "Generate MusicGen Large",
+         "Transcribe Audio",
+         "Generate Mistral Instruct Response",
+         "Generate Mistral Nemo Response",
+         "Generate GPT-2 XL Response",
+         "Answer Question with MiniCPM",
+     ],
+ )
+
+ app.launch(share=True)