not-lain commited on
Commit
cce19ac
·
1 Parent(s): 959fbde

major fallback

Browse files
Files changed (1) hide show
  1. app.py +17 -277
app.py CHANGED
@@ -3,36 +3,11 @@ import spaces
3
  import torch
4
  from loadimg import load_img
5
  from torchvision import transforms
6
- from transformers import (
7
- AutoModelForImageSegmentation,
8
- pipeline,
9
- MBartForConditionalGeneration,
10
- MBart50TokenizerFast,
11
- )
12
  from diffusers import FluxFillPipeline
13
  from PIL import Image, ImageOps
14
 
15
- # from sam2.sam2_image_predictor import SAM2ImagePredictor
16
- import numpy as np
17
- from simple_lama_inpainting import SimpleLama
18
- from contextlib import contextmanager
19
-
20
- # import whisperx
21
- import gc
22
-
23
-
24
- @contextmanager
25
- def float32_high_matmul_precision():
26
- torch.set_float32_matmul_precision("high")
27
- try:
28
- yield
29
- finally:
30
- torch.set_float32_matmul_precision("highest")
31
-
32
-
33
- pipe = FluxFillPipeline.from_pretrained(
34
- "black-forest-labs/FLUX.1-Fill-dev", torch_dtype=torch.bfloat16
35
- ).to("cuda")
36
 
37
  birefnet = AutoModelForImageSegmentation.from_pretrained(
38
  "ZhengPeng7/BiRefNet", trust_remote_code=True
@@ -47,6 +22,10 @@ transform_image = transforms.Compose(
47
  ]
48
  )
49
 
 
 
 
 
50
 
51
  def prepare_image_and_mask(
52
  image,
@@ -131,10 +110,9 @@ def rmbg(image=None, url=None):
131
  image = load_img(image).convert("RGB")
132
  image_size = image.size
133
  input_images = transform_image(image).unsqueeze(0).to("cuda")
134
- with float32_high_matmul_precision():
135
- # Prediction
136
- with torch.no_grad():
137
- preds = birefnet(input_images)[-1].sigmoid().cpu()
138
  pred = preds[0].squeeze()
139
  pred_pil = transforms.ToPILImage()(pred)
140
  mask = pred_pil.resize(image_size)
@@ -142,129 +120,7 @@ def rmbg(image=None, url=None):
142
  return image
143
 
144
 
145
- # def mask_generation(image=None, d=None):
146
- # # use bfloat16 for the entire notebook
147
- # # torch.autocast("cuda", dtype=torch.bfloat16).__enter__()
148
- # # # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
149
- # # if torch.cuda.get_device_properties(0).major >= 8:
150
- # # torch.backends.cuda.matmul.allow_tf32 = True
151
- # # torch.backends.cudnn.allow_tf32 = True
152
- # d = eval(d) # convert this to dictionary
153
- # with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
154
- # predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2.1-hiera-large")
155
- # predictor.set_image(image)
156
- # input_point = np.array(d["input_points"])
157
- # input_label = np.array(d["input_labels"])
158
- # masks, scores, logits = predictor.predict(
159
- # point_coords=input_point,
160
- # point_labels=input_label,
161
- # multimask_output=True,
162
- # )
163
- # sorted_ind = np.argsort(scores)[::-1]
164
- # masks = masks[sorted_ind]
165
- # scores = scores[sorted_ind]
166
- # logits = logits[sorted_ind]
167
-
168
- # out = []
169
- # for i in range(len(masks)):
170
- # m = Image.fromarray(masks[i] * 255).convert("L")
171
- # comp = Image.composite(image, m, m)
172
- # out.append((comp, f"image {i}"))
173
-
174
- # return out
175
-
176
-
177
- def erase(image=None, mask=None):
178
- simple_lama = SimpleLama()
179
- image = load_img(image)
180
- mask = load_img(mask).convert("L")
181
- return simple_lama(image, mask)
182
-
183
-
184
- # def transcribe(audio):
185
- # if audio is None:
186
- # raise gr.Error("No audio file submitted!")
187
-
188
- # device = "cuda" if torch.cuda.is_available() else "cpu"
189
- # compute_type = "float16"
190
- # batch_size = 8 # reduced batch size to be conservative with memory
191
-
192
- # try:
193
- # # 1. Load model and transcribe
194
- # model = whisperx.load_model("large-v2", device, compute_type=compute_type)
195
- # audio_input = whisperx.load_audio(audio)
196
- # result = model.transcribe(audio_input, batch_size=batch_size)
197
-
198
- # # Clear GPU memory
199
- # del model
200
- # gc.collect()
201
- # torch.cuda.empty_cache()
202
-
203
- # # 2. Align whisper output
204
- # model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=device)
205
- # result = whisperx.align(result["segments"], model_a, metadata, audio_input, device, return_char_alignments=False)
206
-
207
- # # Clear GPU memory
208
- # del model_a
209
- # gc.collect()
210
- # torch.cuda.empty_cache()
211
-
212
- # # 3. Assign speaker labels
213
- # diarize_model = whisperx.DiarizationPipeline(device=device)
214
- # diarize_segments = diarize_model(audio_input)
215
-
216
- # # Combine transcription with speaker diarization
217
- # result = whisperx.assign_word_speakers(diarize_segments, result)
218
-
219
- # # Format output with speaker labels and timestamps
220
- # formatted_text = []
221
- # for segment in result["segments"]:
222
- # if not isinstance(segment, dict):
223
- # continue
224
-
225
- # speaker = f"[Speaker {segment.get('speaker', 'Unknown')}]"
226
- # start_time = f"{float(segment.get('start', 0)):.2f}"
227
- # end_time = f"{float(segment.get('end', 0)):.2f}"
228
- # text = segment.get('text', '').strip()
229
- # formatted_text.append(f"[{start_time}s - {end_time}s] {speaker}: {text}")
230
-
231
- # return "\n".join(formatted_text)
232
-
233
- # except Exception as e:
234
- # raise gr.Error(f"Transcription failed: {str(e)}")
235
- # finally:
236
- # # Ensure GPU memory is cleared even if an error occurs
237
- # gc.collect()
238
- # torch.cuda.empty_cache()
239
-
240
-
241
- def translate_text(text, source_lang, target_lang):
242
- model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
243
- tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
244
-
245
- # Set source language
246
- tokenizer.src_lang = source_lang
247
-
248
- # Encode the input text
249
- encoded_text = tokenizer(text, return_tensors="pt")
250
-
251
- # Generate translation
252
- generated_tokens = model.generate(
253
- **encoded_text,
254
- forced_bos_token_id=tokenizer.lang_code_to_id[target_lang]
255
- )
256
-
257
- # Decode the generated tokens
258
- translation = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
259
-
260
- # Clear GPU memory
261
- del model
262
- gc.collect()
263
- torch.cuda.empty_cache()
264
-
265
- return translation
266
-
267
- @spaces.GPU(duration=120)
268
  def main(*args):
269
  api_num = args[0]
270
  args = args[1:]
@@ -274,20 +130,12 @@ def main(*args):
274
  return outpaint(*args)
275
  elif api_num == 3:
276
  return inpaint(*args)
277
- # elif api_num == 4:
278
- # return mask_generation(*args)
279
- elif api_num == 5:
280
- return erase(*args)
281
- # elif api_num == 6:
282
- # return transcribe(*args)
283
- elif api_num == 7:
284
- return translate_text(*args)
285
 
286
 
287
  rmbg_tab = gr.Interface(
288
  fn=main,
289
  inputs=[
290
- gr.Number(1, interactive=False),
291
  "image",
292
  gr.Text("", label="url"),
293
  ],
@@ -301,7 +149,7 @@ rmbg_tab = gr.Interface(
301
  outpaint_tab = gr.Interface(
302
  fn=main,
303
  inputs=[
304
- gr.Number(2, interactive=False),
305
  gr.Image(label="image", type="pil"),
306
  gr.Number(label="padding top"),
307
  gr.Number(label="padding bottom"),
@@ -321,7 +169,7 @@ outpaint_tab = gr.Interface(
321
  inpaint_tab = gr.Interface(
322
  fn=main,
323
  inputs=[
324
- gr.Number(3, interactive=False),
325
  gr.Image(label="image", type="pil"),
326
  gr.Image(label="mask", type="pil"),
327
  gr.Text(label="prompt"),
@@ -335,119 +183,11 @@ inpaint_tab = gr.Interface(
335
  description="it is recommended that you use https://github.com/la-voliere/react-mask-editor when creating an image mask in JS and then inverse it before sending it to this space",
336
  )
337
 
338
-
339
- # sam2_tab = gr.Interface(
340
- # main,
341
- # inputs=[
342
- # gr.Number(4, interactive=False),
343
- # gr.Image(type="pil"),
344
- # gr.Text(),
345
- # ],
346
- # outputs=gr.Gallery(),
347
- # examples=[
348
- # [
349
- # 4,
350
- # "./assets/truck.jpg",
351
- # '{"input_points": [[500, 375], [1125, 625]], "input_labels": [1, 0]}',
352
- # ]
353
- # ],
354
- # api_name="sam2",
355
- # cache_examples=False,
356
- # )
357
-
358
- erase_tab = gr.Interface(
359
- main,
360
- inputs=[
361
- gr.Number(5, interactive=False),
362
- gr.Image(type="pil"),
363
- gr.Image(type="pil"),
364
- ],
365
- outputs=gr.Image(),
366
- examples=[
367
- [
368
- 5,
369
- "./assets/rocket.png",
370
- "./assets/Inpainting mask.png",
371
- ]
372
- ],
373
- api_name="erase",
374
- cache_examples=False,
375
- )
376
-
377
- transcribe_tab = gr.Interface(
378
- fn=main,
379
- inputs=[
380
- gr.Number(value=6, interactive=False), # API number
381
- gr.Audio(type="filepath", label="Audio File"),
382
- ],
383
- outputs=gr.Textbox(label="Transcription"),
384
- title="Audio Transcription",
385
- description="Upload an audio file to extract text using WhisperX with speaker diarization",
386
- api_name="transcribe",
387
- examples=[],
388
- )
389
-
390
- translate_tab = gr.Interface(
391
- fn=main,
392
- inputs=[
393
- gr.Number(value=7, interactive=False),
394
- gr.Textbox(label="Text to translate"),
395
- gr.Dropdown(
396
- choices=[
397
- "ar_AR", "cs_CZ", "de_DE", "en_XX", "es_XX", "et_EE", "fi_FI", "fr_XX",
398
- "gu_IN", "hi_IN", "it_IT", "ja_XX", "kk_KZ", "ko_KR", "lt_LT", "lv_LV",
399
- "my_MM", "ne_NP", "nl_XX", "ro_RO", "ru_RU", "si_LK", "tr_TR", "vi_VN",
400
- "zh_CN", "af_ZA", "az_AZ", "bn_IN", "fa_IR", "he_IL", "hr_HR", "id_ID",
401
- "ka_GE", "km_KH", "mk_MK", "ml_IN", "mn_MN", "mr_IN", "pl_PL", "ps_AF",
402
- "pt_XX", "sv_SE", "sw_KE", "ta_IN", "te_IN", "th_TH", "tl_XX", "uk_UA",
403
- "ur_PK", "xh_ZA", "gl_ES", "sl_SI"
404
- ],
405
- label="Source Language",
406
- value="en_XX"
407
- ),
408
- gr.Dropdown(
409
- choices=[
410
- "ar_AR", "cs_CZ", "de_DE", "en_XX", "es_XX", "et_EE", "fi_FI", "fr_XX",
411
- "gu_IN", "hi_IN", "it_IT", "ja_XX", "kk_KZ", "ko_KR", "lt_LT", "lv_LV",
412
- "my_MM", "ne_NP", "nl_XX", "ro_RO", "ru_RU", "si_LK", "tr_TR", "vi_VN",
413
- "zh_CN", "af_ZA", "az_AZ", "bn_IN", "fa_IR", "he_IL", "hr_HR", "id_ID",
414
- "ka_GE", "km_KH", "mk_MK", "ml_IN", "mn_MN", "mr_IN", "pl_PL", "ps_AF",
415
- "pt_XX", "sv_SE", "sw_KE", "ta_IN", "te_IN", "th_TH", "tl_XX", "uk_UA",
416
- "ur_PK", "xh_ZA", "gl_ES", "sl_SI"
417
- ],
418
- label="Target Language",
419
- value="fr_XX"
420
- ),
421
- ],
422
- outputs=gr.Textbox(label="Translated Text"),
423
- title="Text Translation",
424
- description="Translate text between multiple languages using mBART-50",
425
- api_name="translate",
426
- examples=[
427
- [7, "Hello, how are you?", "en_XX", "fr_XX"],
428
- [7, "Bonjour, comment allez-vous?", "fr_XX", "en_XX"]
429
- ],
430
- cache_examples=False,
431
- )
432
-
433
  demo = gr.TabbedInterface(
434
- [
435
- rmbg_tab,
436
- outpaint_tab,
437
- inpaint_tab,
438
- erase_tab,
439
- transcribe_tab,
440
- translate_tab
441
- ],
442
- [
443
- "remove background",
444
- "outpainting",
445
- "inpainting",
446
- "erase",
447
- "transcribe",
448
- "translate"
449
- ],
450
  title="Utilities that require GPU",
451
  )
452
 
453
- demo.launch()
 
 
3
  import torch
4
  from loadimg import load_img
5
  from torchvision import transforms
6
+ from transformers import AutoModelForImageSegmentation
 
 
 
 
 
7
  from diffusers import FluxFillPipeline
8
  from PIL import Image, ImageOps
9
 
10
+ torch.set_float32_matmul_precision(["high", "highest"][0])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
  birefnet = AutoModelForImageSegmentation.from_pretrained(
13
  "ZhengPeng7/BiRefNet", trust_remote_code=True
 
22
  ]
23
  )
24
 
25
+ pipe = FluxFillPipeline.from_pretrained(
26
+ "black-forest-labs/FLUX.1-Fill-dev", torch_dtype=torch.bfloat16
27
+ ).to("cuda")
28
+
29
 
30
  def prepare_image_and_mask(
31
  image,
 
110
  image = load_img(image).convert("RGB")
111
  image_size = image.size
112
  input_images = transform_image(image).unsqueeze(0).to("cuda")
113
+ # Prediction
114
+ with torch.no_grad():
115
+ preds = birefnet(input_images)[-1].sigmoid().cpu()
 
116
  pred = preds[0].squeeze()
117
  pred_pil = transforms.ToPILImage()(pred)
118
  mask = pred_pil.resize(image_size)
 
120
  return image
121
 
122
 
123
+ @spaces.GPU
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
  def main(*args):
125
  api_num = args[0]
126
  args = args[1:]
 
130
  return outpaint(*args)
131
  elif api_num == 3:
132
  return inpaint(*args)
 
 
 
 
 
 
 
 
133
 
134
 
135
  rmbg_tab = gr.Interface(
136
  fn=main,
137
  inputs=[
138
+ gr.Number(1, visible=False),
139
  "image",
140
  gr.Text("", label="url"),
141
  ],
 
149
  outpaint_tab = gr.Interface(
150
  fn=main,
151
  inputs=[
152
+ gr.Number(2, visible=False),
153
  gr.Image(label="image", type="pil"),
154
  gr.Number(label="padding top"),
155
  gr.Number(label="padding bottom"),
 
169
  inpaint_tab = gr.Interface(
170
  fn=main,
171
  inputs=[
172
+ gr.Number(3, visible=False),
173
  gr.Image(label="image", type="pil"),
174
  gr.Image(label="mask", type="pil"),
175
  gr.Text(label="prompt"),
 
183
  description="it is recommended that you use https://github.com/la-voliere/react-mask-editor when creating an image mask in JS and then inverse it before sending it to this space",
184
  )
185
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
186
  demo = gr.TabbedInterface(
187
+ [rmbg_tab, outpaint_tab, inpaint_tab],
188
+ ["remove background", "outpainting", "inpainting"],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
189
  title="Utilities that require GPU",
190
  )
191
 
192
+
193
+ demo.launch()