not-lain committed
Commit dc95e97 · 1 Parent(s): b21d0d9

soft reset

Files changed (1):
  1. app.py +368 -232
app.py CHANGED
@@ -1,19 +1,21 @@
 import gradio as gr
 import spaces
 import torch
-from loadimg import load_img
+from loadimg import load_img  # Assuming loadimg.py exists with load_img function
 from torchvision import transforms
 from transformers import AutoModelForImageSegmentation, pipeline
 from diffusers import FluxFillPipeline
 from PIL import Image, ImageOps
-
-# from sam2.sam2_image_predictor import SAM2ImagePredictor
 import numpy as np
 from simple_lama_inpainting import SimpleLama
 from contextlib import contextmanager
-# import whisperx
 import gc
 
+# --- Add Translation Imports ---
+from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
+
+
+# --- Utility Functions ---
 @contextmanager
 def float32_high_matmul_precision():
     torch.set_float32_matmul_precision("high")
@@ -23,14 +25,33 @@ def float32_high_matmul_precision():
         torch.set_float32_matmul_precision("highest")
 
 
-pipe = FluxFillPipeline.from_pretrained(
-    "black-forest-labs/FLUX.1-Fill-dev", torch_dtype=torch.bfloat16
-).to("cuda")
+# --- Model Loading ---
+# Use context manager for precision during model loading if needed
+with float32_high_matmul_precision():
+    pipe = FluxFillPipeline.from_pretrained(
+        "black-forest-labs/FLUX.1-Fill-dev", torch_dtype=torch.bfloat16
+    ).to("cuda")
 
-birefnet = AutoModelForImageSegmentation.from_pretrained(
-    "ZhengPeng7/BiRefNet", trust_remote_code=True
-)
-birefnet.to("cuda")
+    birefnet = AutoModelForImageSegmentation.from_pretrained(
+        "ZhengPeng7/BiRefNet", trust_remote_code=True
+    ).to("cuda")
+
+simple_lama = SimpleLama()  # Initialize Lama globally if used often
+
+# --- Translation Model and Tokenizer Loading ---
+translation_model_name = "facebook/mbart-large-50-many-to-many-mmt"
+try:
+    translation_model = MBartForConditionalGeneration.from_pretrained(
+        translation_model_name
+    ).to("cuda")  # Move to GPU
+    translation_tokenizer = MBart50TokenizerFast.from_pretrained(translation_model_name)
+except Exception as e:
+    print(f"Error loading translation model/tokenizer: {e}")
+    # Consider exiting or disabling the translation tab if loading fails
+    translation_model = None
+    translation_tokenizer = None
+
+# --- Image Processing Functions ---
 
 transform_image = transforms.Compose(
     [
@@ -49,7 +70,6 @@ def prepare_image_and_mask(
     padding_right=0,
 ):
     image = load_img(image).convert("RGB")
-    # expand image (left,top,right,bottom)
     background = ImageOps.expand(
         image,
         border=(padding_left, padding_top, padding_right, padding_bottom),
@@ -77,19 +97,19 @@ def outpaint(
     background, mask = prepare_image_and_mask(
         image, padding_top, padding_bottom, padding_left, padding_right
     )
-
-    result = pipe(
-        prompt=prompt,
-        height=background.height,
-        width=background.width,
-        image=background,
-        mask_image=mask,
-        num_inference_steps=num_inference_steps,
-        guidance_scale=guidance_scale,
-    ).images[0]
-
+    with (
+        float32_high_matmul_precision()
+    ):  # Apply precision context if needed for inference
+        result = pipe(
+            prompt=prompt,
+            height=background.height,
+            width=background.width,
+            image=background,
+            mask_image=mask,
+            num_inference_steps=num_inference_steps,
+            guidance_scale=guidance_scale,
+        ).images[0]
     result = result.convert("RGBA")
-
     return result
 
 
@@ -102,275 +122,391 @@ def inpaint(
 ):
     background = image.convert("RGB")
     mask = mask.convert("L")
-
-    result = pipe(
-        prompt=prompt,
-        height=background.height,
-        width=background.width,
-        image=background,
-        mask_image=mask,
-        num_inference_steps=num_inference_steps,
-        guidance_scale=guidance_scale,
-    ).images[0]
-
+    with (
+        float32_high_matmul_precision()
+    ):  # Apply precision context if needed for inference
+        result = pipe(
+            prompt=prompt,
+            height=background.height,
+            width=background.width,
+            image=background,
+            mask_image=mask,
+            num_inference_steps=num_inference_steps,
+            guidance_scale=guidance_scale,
+        ).images[0]
     result = result.convert("RGBA")
-
     return result
 
 
 def rmbg(image=None, url=None):
-    if image is None:
-        image = url
-    image = load_img(image).convert("RGB")
-    image_size = image.size
-    input_images = transform_image(image).unsqueeze(0).to("cuda")
+    if image is None and url:
+        # Basic check for URL format, improve as needed
+        if not url.startswith(("http://", "https://")):
+            return "Invalid URL provided."
+        image = url  # load_img should handle URLs if configured correctly
+    elif image is None:
+        return "Please provide an image or a URL."
+
+    try:
+        image_pil = load_img(image).convert("RGB")
+    except Exception as e:
+        return f"Error loading image: {e}"
+
+    image_size = image_pil.size
+    input_images = transform_image(image_pil).unsqueeze(0).to("cuda")
     with float32_high_matmul_precision():
-        # Prediction
         with torch.no_grad():
             preds = birefnet(input_images)[-1].sigmoid().cpu()
         pred = preds[0].squeeze()
         pred_pil = transforms.ToPILImage()(pred)
        mask = pred_pil.resize(image_size)
-    image.putalpha(mask)
-    return image
-
-
-# def mask_generation(image=None, d=None):
-#     # use bfloat16 for the entire notebook
-#     # torch.autocast("cuda", dtype=torch.bfloat16).__enter__()
-#     # # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
-#     # if torch.cuda.get_device_properties(0).major >= 8:
-#     #     torch.backends.cuda.matmul.allow_tf32 = True
-#     #     torch.backends.cudnn.allow_tf32 = True
-#     d = eval(d)  # convert this to dictionary
-#     with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
-#         predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2.1-hiera-large")
-#         predictor.set_image(image)
-#         input_point = np.array(d["input_points"])
-#         input_label = np.array(d["input_labels"])
-#         masks, scores, logits = predictor.predict(
-#             point_coords=input_point,
-#             point_labels=input_label,
-#             multimask_output=True,
-#         )
-#         sorted_ind = np.argsort(scores)[::-1]
-#         masks = masks[sorted_ind]
-#         scores = scores[sorted_ind]
-#         logits = logits[sorted_ind]
-
-#         out = []
-#         for i in range(len(masks)):
-#             m = Image.fromarray(masks[i] * 255).convert("L")
-#             comp = Image.composite(image, m, m)
-#             out.append((comp, f"image {i}"))
-
-#         return out
+    image_pil.putalpha(mask)
+    # Clean up GPU memory if needed
+    del input_images, preds, pred
+    torch.cuda.empty_cache()
+    gc.collect()
+    return image_pil
 
 
 def erase(image=None, mask=None):
-    simple_lama = SimpleLama()
-    image = load_img(image)
-    mask = load_img(mask).convert("L")
-    return simple_lama(image, mask)
-
-
-# def transcribe(audio):
-#     if audio is None:
-#         raise gr.Error("No audio file submitted!")
-
-#     device = "cuda" if torch.cuda.is_available() else "cpu"
-#     compute_type = "float16"
-#     batch_size = 8  # reduced batch size to be conservative with memory
-
-#     try:
-#         # 1. Load model and transcribe
-#         model = whisperx.load_model("large-v2", device, compute_type=compute_type)
-#         audio_input = whisperx.load_audio(audio)
-#         result = model.transcribe(audio_input, batch_size=batch_size)
-
-#         # Clear GPU memory
-#         del model
-#         gc.collect()
-#         torch.cuda.empty_cache()
-
-#         # 2. Align whisper output
-#         model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=device)
-#         result = whisperx.align(result["segments"], model_a, metadata, audio_input, device, return_char_alignments=False)
-
-#         # Clear GPU memory
-#         del model_a
-#         gc.collect()
-#         torch.cuda.empty_cache()
-
-#         # 3. Assign speaker labels
-#         diarize_model = whisperx.DiarizationPipeline(device=device)
-#         diarize_segments = diarize_model(audio_input)
-
-#         # Combine transcription with speaker diarization
-#         result = whisperx.assign_word_speakers(diarize_segments, result)
-
-#         # Format output with speaker labels and timestamps
-#         formatted_text = []
-#         for segment in result["segments"]:
-#             if not isinstance(segment, dict):
-#                 continue
-
-#             speaker = f"[Speaker {segment.get('speaker', 'Unknown')}]"
-#             start_time = f"{float(segment.get('start', 0)):.2f}"
-#             end_time = f"{float(segment.get('end', 0)):.2f}"
-#             text = segment.get('text', '').strip()
-#             formatted_text.append(f"[{start_time}s - {end_time}s] {speaker}: {text}")
-
-#         return "\n".join(formatted_text)
-
-#     except Exception as e:
-#         raise gr.Error(f"Transcription failed: {str(e)}")
-#     finally:
-#         # Ensure GPU memory is cleared even if an error occurs
-#         gc.collect()
-#         torch.cuda.empty_cache()
-
-
-@spaces.GPU(duration=120)
+    if image is None or mask is None:
+        return "Please provide both an image and a mask."
+    try:
+        image_pil = load_img(image)
+        mask_pil = load_img(mask).convert("L")
+        result = simple_lama(image_pil, mask_pil)
+        # Clean up
+        gc.collect()
+        return result
+    except Exception as e:
+        return f"Error during erase operation: {e}"
+
+
+# --- Translation Functionality ---
+
+# Language Mapping
+lang_data = {
+    "Arabic": "ar_AR",
+    "Czech": "cs_CZ",
+    "German": "de_DE",
+    "English": "en_XX",
+    "Spanish": "es_XX",
+    "Estonian": "et_EE",
+    "Finnish": "fi_FI",
+    "French": "fr_XX",
+    "Gujarati": "gu_IN",
+    "Hindi": "hi_IN",
+    "Italian": "it_IT",
+    "Japanese": "ja_XX",
+    "Kazakh": "kk_KZ",
+    "Korean": "ko_KR",
+    "Lithuanian": "lt_LT",
+    "Latvian": "lv_LV",
+    "Burmese": "my_MM",
+    "Nepali": "ne_NP",
+    "Dutch": "nl_XX",
+    "Romanian": "ro_RO",
+    "Russian": "ru_RU",
+    "Sinhala": "si_LK",
+    "Turkish": "tr_TR",
+    "Vietnamese": "vi_VN",
+    "Chinese": "zh_CN",
+    "Afrikaans": "af_ZA",
+    "Azerbaijani": "az_AZ",
+    "Bengali": "bn_IN",
+    "Persian": "fa_IR",
+    "Hebrew": "he_IL",
+    "Croatian": "hr_HR",
+    "Indonesian": "id_ID",
+    "Georgian": "ka_GE",
+    "Khmer": "km_KH",
+    "Macedonian": "mk_MK",
+    "Malayalam": "ml_IN",
+    "Mongolian": "mn_MN",
+    "Marathi": "mr_IN",
+    "Polish": "pl_PL",
+    "Pashto": "ps_AF",
+    "Portuguese": "pt_XX",
+    "Swedish": "sv_SE",
+    "Swahili": "sw_KE",
+    "Tamil": "ta_IN",
+    "Telugu": "te_IN",
+    "Thai": "th_TH",
+    "Tagalog": "tl_XX",
+    "Ukrainian": "uk_UA",
+    "Urdu": "ur_PK",
+    "Xhosa": "xh_ZA",
+    "Galician": "gl_ES",
+    "Slovene": "sl_SI",
+}
+language_names = sorted(list(lang_data.keys()))
+
+
+def translate_text(text_to_translate, source_language_name, target_language_name):
+    """
+    Translates text using the loaded mBART model.
+    """
+    if translation_model is None or translation_tokenizer is None:
+        return "Translation model not loaded. Cannot perform translation."
+    if not text_to_translate:
+        return "Please enter text to translate."
+    if not source_language_name:
+        return "Please select a source language."
+    if not target_language_name:
+        return "Please select a target language."
+
+    try:
+        source_lang_code = lang_data[source_language_name]
+        target_lang_code = lang_data[target_language_name]
+
+        translation_tokenizer.src_lang = source_lang_code
+        encoded_text = translation_tokenizer(text_to_translate, return_tensors="pt").to(
+            "cuda"
+        )  # Move input to GPU
+        target_lang_id = translation_tokenizer.lang_code_to_id[target_lang_code]
+
+        # Generate translation on GPU
+        with torch.no_grad():  # Use no_grad for inference
+            generated_tokens = translation_model.generate(
+                **encoded_text, forced_bos_token_id=target_lang_id, max_length=200
+            )
+
+        translated_text = translation_tokenizer.batch_decode(
+            generated_tokens, skip_special_tokens=True
+        )
+
+        # Clean up GPU memory
+        del encoded_text, generated_tokens
+        torch.cuda.empty_cache()
+        gc.collect()
+
+        return translated_text[0]
+
+    except KeyError as e:
+        return f"Error: Language code not found for {e}. Check language mappings."
+    except Exception as e:
+        print(f"Translation error: {e}")
+        # Clean up GPU memory on error too
+        torch.cuda.empty_cache()
+        gc.collect()
+        return f"An error occurred during translation: {e}"
+
+
+# --- Main Function Router (for image tasks) ---
+# Note: Translation uses its own function directly
+@spaces.GPU(duration=120)  # Keep GPU decorator if needed for image tasks
 def main(*args):
     api_num = args[0]
     args = args[1:]
-    if api_num == 1:
-        return rmbg(*args)
-    elif api_num == 2:
-        return outpaint(*args)
-    elif api_num == 3:
-        return inpaint(*args)
-    # elif api_num == 4:
-    #     return mask_generation(*args)
-    elif api_num == 5:
-        return erase(*args)
-    # elif api_num == 6:
-    #     return transcribe(*args)
+    gc.collect()  # Try to collect garbage before starting task
+    torch.cuda.empty_cache()  # Clear cache before starting task
 
+    result = None
+    try:
+        if api_num == 1:
+            result = rmbg(*args)
+        elif api_num == 2:
+            result = outpaint(*args)
+        elif api_num == 3:
+            result = inpaint(*args)
+        # elif api_num == 4:  # Keep commented out as in original
+        #     return mask_generation(*args)
+        elif api_num == 5:
+            result = erase(*args)
+        else:
+            result = "Invalid API number."
+    except Exception as e:
+        print(f"Error in main task routing (api_num={api_num}): {e}")
+        result = f"An error occurred: {e}"
+    finally:
+        # Ensure memory cleanup happens even if there's an error
+        gc.collect()
+        torch.cuda.empty_cache()
+
+    return result
+
+
+# --- Define Gradio Interfaces for Each Tab ---
 
+# Image Task Tabs
 rmbg_tab = gr.Interface(
     fn=main,
     inputs=[
-        gr.Number(1, interactive=False),
-        "image",
-        gr.Text("", label="url"),
+        gr.Number(1, interactive=False, visible=False),  # Hide API number
+        gr.Image(label="Input Image", type="pil", sources=["upload", "clipboard"]),
+        gr.Text(label="Or Image URL (optional)"),
     ],
-    outputs=["image"],
+    outputs=gr.Image(label="Output Image", type="pil"),
+    title="Remove Background",
+    description="Upload an image or provide a URL to remove its background.",
     api_name="rmbg",
-    examples=[[1, "./assets/Inpainting mask.png", ""]],
+    # examples=[[1, "./assets/sample_rmbg.png", ""]],  # Update example path if needed
     cache_examples=False,
-    description="pass an image or a url of an image",
 )
 
 outpaint_tab = gr.Interface(
     fn=main,
     inputs=[
-        gr.Number(2, interactive=False),
-        gr.Image(label="image", type="pil"),
-        gr.Number(label="padding top"),
-        gr.Number(label="padding bottom"),
-        gr.Number(label="padding left"),
-        gr.Number(label="padding right"),
-        gr.Text(label="prompt"),
-        gr.Number(value=50, label="num_inference_steps"),
-        gr.Number(value=28, label="guidance_scale"),
+        gr.Number(2, interactive=False, visible=False),
+        gr.Image(label="Input Image", type="pil", sources=["upload", "clipboard"]),
+        gr.Number(value=0, label="Padding Top (pixels)"),
+        gr.Number(value=0, label="Padding Bottom (pixels)"),
+        gr.Number(value=0, label="Padding Left (pixels)"),
+        gr.Number(value=0, label="Padding Right (pixels)"),
+        gr.Text(
+            label="Prompt (optional)",
+            info="Describe what to fill the extended area with",
+        ),
+        gr.Slider(
+            minimum=10, maximum=100, step=1, value=28, label="Inference Steps"
+        ),  # Use slider for steps
+        gr.Slider(
+            minimum=1, maximum=100, step=1, value=50, label="Guidance Scale"
+        ),  # Use slider for guidance
     ],
-    outputs=["image"],
+    outputs=gr.Image(label="Outpainted Image", type="pil"),
+    title="Outpainting",
+    description="Extend an image by adding padding and filling the new area using a diffusion model.",
     api_name="outpainting",
-    examples=[[2, "./assets/rocket.png", 100, 0, 0, 0, "", 50, 28]],
+    # examples=[[2, "./assets/rocket.png", 100, 0, 0, 0, "", 28, 50]],  # Update example path
     cache_examples=False,
 )
 
-
 inpaint_tab = gr.Interface(
     fn=main,
     inputs=[
-        gr.Number(3, interactive=False),
-        gr.Image(label="image", type="pil"),
-        gr.Image(label="mask", type="pil"),
-        gr.Text(label="prompt"),
-        gr.Number(value=50, label="num_inference_steps"),
-        gr.Number(value=28, label="guidance_scale"),
+        gr.Number(3, interactive=False, visible=False),
+        gr.Image(label="Input Image", type="pil", sources=["upload", "clipboard"]),
+        gr.Image(
+            label="Mask Image (White=Inpaint Area)",
+            type="pil",
+            sources=["upload", "clipboard"],
+        ),
+        gr.Text(
+            label="Prompt (optional)", info="Describe what to fill the masked area with"
+        ),
+        gr.Slider(minimum=10, maximum=100, step=1, value=28, label="Inference Steps"),
+        gr.Slider(minimum=1, maximum=100, step=1, value=50, label="Guidance Scale"),
     ],
-    outputs=["image"],
+    outputs=gr.Image(label="Inpainted Image", type="pil"),
+    title="Inpainting",
+    description="Fill in the white areas of a mask applied to an image using a diffusion model.",
     api_name="inpaint",
-    examples=[[3, "./assets/rocket.png", "./assets/Inpainting mask.png"]],
+    # examples=[[3, "./assets/rocket.png", "./assets/Inpainting_mask.png", "", 28, 50]],  # Update example paths
     cache_examples=False,
-    description="it is recommended that you use https://github.com/la-voliere/react-mask-editor when creating an image mask in JS and then inverse it before sending it to this space",
 )
 
-
-# sam2_tab = gr.Interface(
-#     main,
-#     inputs=[
-#         gr.Number(4, interactive=False),
-#         gr.Image(type="pil"),
-#         gr.Text(),
-#     ],
-#     outputs=gr.Gallery(),
-#     examples=[
-#         [
-#             4,
-#             "./assets/truck.jpg",
-#             '{"input_points": [[500, 375], [1125, 625]], "input_labels": [1, 0]}',
-#         ]
-#     ],
-#     api_name="sam2",
-#     cache_examples=False,
-# )
-
 erase_tab = gr.Interface(
-    main,
+    fn=main,
     inputs=[
-        gr.Number(5, interactive=False),
-        gr.Image(type="pil"),
-        gr.Image(type="pil"),
-    ],
-    outputs=gr.Image(),
-    examples=[
-        [
-            5,
-            "./assets/rocket.png",
-            "./assets/Inpainting mask.png",
-        ]
+        gr.Number(5, interactive=False, visible=False),
+        gr.Image(label="Input Image", type="pil", sources=["upload", "clipboard"]),
+        gr.Image(
+            label="Mask Image (White=Erase Area)",
+            type="pil",
+            sources=["upload", "clipboard"],
+        ),
     ],
+    outputs=gr.Image(label="Result Image", type="pil"),
+    title="Erase Object (LAMA)",
+    description="Erase objects from an image based on a mask using the LaMa inpainting model.",
    api_name="erase",
+    # examples=[[5, "./assets/rocket.png", "./assets/Inpainting_mask.png"]],  # Update example paths
     cache_examples=False,
 )
 
-transcribe_tab = gr.Interface(
-    fn=main,
-    inputs=[
-        gr.Number(value=6, interactive=False),  # API number
-        gr.Audio(type="filepath", label="Audio File"),
-    ],
-    outputs=gr.Textbox(label="Transcription"),
-    title="Audio Transcription",
-    description="Upload an audio file to extract text using WhisperX with speaker diarization",
-    api_name="transcribe",
-    examples=[]
-)
 
+# --- Define Translation Tab using gr.Blocks ---
+with gr.Blocks() as translation_tab:
+    gr.Markdown(
+        """
+        ## Multilingual Translation (mBART-50)
+        Translate text between 50 different languages.
+        Select the source and target languages, enter your text, and click Translate.
+        """
+    )
+    with gr.Row():
+        with gr.Column(scale=1):
+            source_lang_dropdown = gr.Dropdown(
+                label="Source Language",
+                choices=language_names,
+                info="Select the language of your input text.",
+            )
+            target_lang_dropdown = gr.Dropdown(
+                label="Target Language",
+                choices=language_names,
+                info="Select the language you want to translate to.",
+            )
+        with gr.Column(scale=2):
+            input_textbox = gr.Textbox(
+                label="Text to Translate",
+                lines=6,  # Increased lines
+                placeholder="Enter text here...",
+            )
+            translate_button = gr.Button(
+                "Translate", variant="primary"
+            )  # Added variant
+            output_textbox = gr.Textbox(
+                label="Translated Text",
+                lines=6,  # Increased lines
+                interactive=False,  # Make output read-only
+            )
+
+    # Connect Components to the translation function directly
+    translate_button.click(
+        fn=translate_text,
+        inputs=[input_textbox, source_lang_dropdown, target_lang_dropdown],
+        outputs=output_textbox,
+        api_name="translate",  # Add API name for the translation endpoint
+    )
+
+    # Add Translation Examples
+    gr.Examples(
+        examples=[
+            [
+                "संयुक्त राष्ट्र के प्रमुख का कहना है कि सीरिया में कोई सैन्य समाधान नहीं है",
+                "Hindi",
+                "French",
+            ],
+            [
+                "الأمين العام للأمم المتحدة يقول إنه لا يوجد حل عسكري في سوريا.",
+                "Arabic",
+                "English",
+            ],
+            [
+                "Le chef de l'ONU affirme qu'il n'y a pas de solution militaire en Syrie.",
+                "French",
+                "German",
+            ],
+            ["Hello world! How are you today?", "English", "Spanish"],
+            ["Guten Tag!", "German", "Japanese"],
+            ["これはテストです", "Japanese", "English"],
+        ],
+        inputs=[input_textbox, source_lang_dropdown, target_lang_dropdown],
+        outputs=output_textbox,
+        fn=translate_text,
+        cache_examples=False,
+    )
+
+# --- Combine all tabs ---
 demo = gr.TabbedInterface(
     [
         rmbg_tab,
         outpaint_tab,
         inpaint_tab,
-        # sam2_tab,
         erase_tab,
-        transcribe_tab,
+        translation_tab,  # Add the translation tab
+        # sam2_tab,  # Keep commented out
     ],
     [
-        "remove background",
-        "outpainting",
-        "inpainting",
+        "Remove Background",  # Tab title
+        "Outpainting",  # Tab title
+        "Inpainting",  # Tab title
+        "Erase (LAMA)",  # Tab title
+        "Translate",  # Tab title for translation
         # "sam2",
-        "erase",
-        # "transcribe",
     ],
-    title="Utilities that require GPU",
+    title="Image & Text Utilities (GPU)",  # Updated title
 )
 
-
 demo.launch()
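
Since this commit wires api_name="translate" to the new Blocks click handler, the endpoint should be callable programmatically alongside the existing image endpoints. The following is a minimal client-side sketch, not part of the commit, assuming gradio_client is installed and the Space is reachable at the placeholder URL below; the positional arguments follow the handler's inputs list [input_textbox, source_lang_dropdown, target_lang_dropdown].

from gradio_client import Client

# Placeholder URL: replace with the actual Space URL or the local
# address printed by demo.launch().
client = Client("http://127.0.0.1:7860/")

# Argument order mirrors the click handler's inputs:
# text to translate, source language name, target language name.
result = client.predict(
    "Hello world! How are you today?",
    "English",
    "Spanish",
    api_name="/translate",
)
print(result)  # expected: the translated string returned by translate_text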