Adeal1 committed
Commit
fb6dc64
1 Parent(s): 8e6e3e0

Add initial project configuration and dependencies

Files changed (3)
  1. LICENCE +7 -0
  2. app.py +538 -0
  3. requirements.txt +19 -0
LICENCE ADDED
@@ -0,0 +1,7 @@
+ Copyright 2024 Walid Opeyemi Adebayo
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
app.py ADDED
@@ -0,0 +1,538 @@
+ import gradio as gr
+ from loadimg import load_img
+ import spaces
+ from transformers import AutoModelForImageSegmentation
+ import torch
+ from torchvision import transforms
+ import moviepy.editor as mp
+ from pydub import AudioSegment
+ from PIL import Image, ImageSequence
+ import numpy as np
+ import os
+ import tempfile
+ import uuid
+ import time
+ from concurrent.futures import ThreadPoolExecutor
+ import base64
+ import io
+ from gradio_imageslider import ImageSlider
+
+ torch.set_float32_matmul_precision("high")
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ # Disable the Pillow pixel limit so very large images can be processed
+ Image.MAX_IMAGE_PIXELS = None
+
+ # Load both BiRefNet models
+ birefnet = AutoModelForImageSegmentation.from_pretrained(
+     "ZhengPeng7/BiRefNet", trust_remote_code=True
+ )
+ birefnet.to(device)
+ birefnet_lite = AutoModelForImageSegmentation.from_pretrained(
+     "ZhengPeng7/BiRefNet_lite", trust_remote_code=True
+ )
+ birefnet_lite.to(device)
+
+ transform_image = transforms.Compose(
+     [
+         transforms.Resize((1024, 1024)),
+         transforms.ToTensor(),
+         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
+     ]
+ )
+
+ # Video processing
+
+
+ # Function to process a single frame
+ def process_frame(
+     frame, bg_type, bg, fast_mode, bg_frame_index, background_frames, color
+ ):
+     try:
+         pil_image = Image.fromarray(frame)
+         if bg_type == "Color":
+             processed_image = process(pil_image, color, fast_mode)
+         elif bg_type == "Image":
+             processed_image = process(pil_image, bg, fast_mode)
+         elif bg_type == "Video":
+             # Wrap around so a background clip with slightly fewer frames
+             # (after speed-up or looping) cannot raise an IndexError
+             background_frame = background_frames[bg_frame_index % len(background_frames)]
+             bg_frame_index += 1
+             background_image = Image.fromarray(background_frame)
+             processed_image = process(pil_image, background_image, fast_mode)
+         else:
+             # Default to the original frame if no background type is selected
+             processed_image = pil_image
+         return np.array(processed_image), bg_frame_index
+     except Exception as e:
+         print(f"Error processing frame: {e}")
+         return frame, bg_frame_index
+
+
+ @spaces.GPU
+ def remove_bg_video(
+     vid,
+     bg_type="Color",
+     bg_image=None,
+     bg_video=None,
+     color="#00FF00",
+     fps=0,
+     video_handling="slow_down",
+     fast_mode=True,
+     max_workers=6,
+ ):
+     try:
+         start_time = time.time()  # Start the timer
+         video = mp.VideoFileClip(vid)
+         if fps == 0:
+             fps = video.fps
+
+         audio = video.audio
+         frames = list(video.iter_frames(fps=fps))
+
+         processed_frames = []
+         yield gr.update(visible=True), gr.update(visible=False), "Processing started... Elapsed time: 0 seconds"
+
+         if bg_type == "Video":
+             background_video = mp.VideoFileClip(bg_video)
+             if background_video.duration < video.duration:
+                 if video_handling == "slow_down":
+                     background_video = background_video.fx(
+                         mp.vfx.speedx, factor=video.duration / background_video.duration
+                     )
+                 else:  # video_handling == "loop"
+                     background_video = mp.concatenate_videoclips(
+                         [background_video]
+                         * int(video.duration / background_video.duration + 1)
+                     )
+             background_frames = list(background_video.iter_frames(fps=fps))
+         else:
+             background_frames = None
+
+         bg_frame_index = 0  # Initialize background frame index
+
+         with ThreadPoolExecutor(max_workers=max_workers) as executor:
+             # Each submitted frame gets its own background frame offset (bg_frame_index + i)
+             futures = [
+                 executor.submit(
+                     process_frame,
+                     frames[i],
+                     bg_type,
+                     bg_image,
+                     fast_mode,
+                     bg_frame_index + i,
+                     background_frames,
+                     color,
+                 )
+                 for i in range(len(frames))
+             ]
+             for i, future in enumerate(futures):
+                 result, _ = future.result()  # The returned index is not needed here
+                 processed_frames.append(result)
+                 elapsed_time = time.time() - start_time
+                 yield result, None, f"Processing frame {i+1}/{len(frames)}... Elapsed time: {elapsed_time:.2f} seconds"
+
+         processed_video = mp.ImageSequenceClip(processed_frames, fps=fps)
+         processed_video = processed_video.set_audio(audio)
+
+         with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as temp_file:
+             temp_filepath = temp_file.name
+             processed_video.write_videofile(temp_filepath, codec="libx264")
+
+         elapsed_time = time.time() - start_time
+         yield gr.update(visible=False), gr.update(visible=True), f"Processing complete! Elapsed time: {elapsed_time:.2f} seconds"
+         yield processed_frames[-1], temp_filepath, f"Processing complete! Elapsed time: {elapsed_time:.2f} seconds"
+
+     except Exception as e:
+         print(f"Error: {e}")
+         elapsed_time = time.time() - start_time
+         yield gr.update(visible=False), gr.update(visible=True), f"Error processing video: {e}. Elapsed time: {elapsed_time:.2f} seconds"
+         # The video output expects a filepath, so yield None for it on error
+         yield None, None, f"Error processing video: {e}. Elapsed time: {elapsed_time:.2f} seconds"
+
+
+ def process(image, bg, fast_mode=False):
+     image_size = image.size
+     input_images = transform_image(image).unsqueeze(0).to(device)
+     model = birefnet_lite if fast_mode else birefnet
+
+     with torch.no_grad():
+         preds = model(input_images)[-1].sigmoid().cpu()
+     pred = preds[0].squeeze()
+     pred_pil = transforms.ToPILImage()(pred)
+     mask = pred_pil.resize(image_size)
+
+     if isinstance(bg, str) and bg.startswith("#"):
+         color_rgb = tuple(int(bg[i : i + 2], 16) for i in (1, 3, 5))
+         background = Image.new("RGBA", image_size, color_rgb + (255,))
+     elif isinstance(bg, Image.Image):
+         background = bg.convert("RGBA").resize(image_size)
+     else:
+         background = Image.open(bg).convert("RGBA").resize(image_size)
+
+     image = Image.composite(image, background, mask)
+     return image
+
+
+ # Image processing
+
+ # Function to remove background from an image
+ def remove_bg_fn(image):
+     source = load_img(image, output_type="pil")
+     # .format is lost after convert(), so check it on the freshly loaded image
+     is_gif = source.format == "GIF"
+     im = source.convert("RGB")
+     origin = im.copy()
+
+     if is_gif:
+         frames = []
+         for frame in ImageSequence.Iterator(source):
+             frame = frame.convert("RGBA")
+             processed_frame = process_image(frame)
+             frames.append(processed_frame)
+         # Only the first processed frame can be shown in the image slider
+         processed_image = frames[0]
+         processed_image.save(
+             io.BytesIO(),
+             format="GIF",
+             save_all=True,
+             append_images=frames[1:],
+             loop=0,
+         )
+     else:
+         processed_image = process_image(im)
+
+     return (processed_image, origin)
+
+ @spaces.GPU
+ def process_image(image):
+     image_size = image.size
+     # Normalize expects three channels, so run the model on an RGB copy
+     input_images = transform_image(image.convert("RGB")).unsqueeze(0).to(device)
+
+     # Prediction
+     with torch.no_grad():
+         preds = birefnet(input_images)[-1].sigmoid().cpu()
+     pred = preds[0].squeeze()
+     pred_pil = transforms.ToPILImage()(pred)
+     mask = pred_pil.resize(image_size)
+     image.putalpha(mask)
+     return image
+
+
+
+ # Function to apply background to an image
+ @spaces.GPU
+ def apply_background(image, background):
+     if background.mode != "RGBA":
+         background = background.convert("RGBA")
+     image = image.convert("RGBA")
+     combined = Image.alpha_composite(background, image)
+     return combined
+
+
+ # Function to convert hex color to RGBA
+ def hex_to_rgba(hex_color):
+     hex_color = hex_color.lstrip("#")
+     lv = len(hex_color)
+     return tuple(int(hex_color[i : i + lv // 3], 16) for i in range(0, lv, lv // 3)) + (
+         255,
+     )
+
+
+ def apply_bg_image(image, background_file=None, background_color=None, bg_type="Color"):
+     try:
+         # The image component may hand over a filepath, a PIL image or a numpy array,
+         # so normalise the input to a PIL image first
+         if isinstance(image, Image.Image):
+             input_image = image
+         elif isinstance(image, str):
+             input_image = Image.open(image)
+         else:
+             input_image = Image.fromarray(image)
+         origin = input_image.copy()
+
+         color_profile = input_image.info.get("icc_profile")
+
+         if background_file is not None:
+             # The background component is declared with type="filepath", so open it directly
+             background_image = Image.open(background_file)
+         else:
+             background_image = None
+
+         if bg_type == "Color":
+             background_image = Image.new("RGBA", input_image.size, hex_to_rgba(background_color))
+         elif bg_type == "Image" and background_image is not None:
+             if background_image.size != input_image.size:
+                 background_image = background_image.resize(input_image.size)
+
+         if input_image.format == "GIF":
+             frames = []
+             for frame in ImageSequence.Iterator(input_image):
+                 frame = frame.convert("RGBA")
+                 output_frame = apply_background(frame, background_image)
+                 frames.append(output_frame)
+
+             output_image = io.BytesIO()
+             frames[0].save(
+                 output_image,
+                 format="GIF",
+                 save_all=True,
+                 append_images=frames[1:],
+                 loop=0,
+                 icc_profile=color_profile,
+             )
+             output_image_base64 = base64.b64encode(output_image.getvalue()).decode("utf-8")
+         else:
+             output_image = apply_background(input_image, background_image)
+             buffered = io.BytesIO()
+             output_image.save(buffered, format="PNG", optimize=True, icc_profile=color_profile)
+             output_image_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
+
+         output_image_data = base64.b64decode(output_image_base64)
+         return (Image.open(io.BytesIO(output_image_data)), origin)
+     except Exception as e:
+         # Surface the error in the UI instead of returning a string to the image slider
+         raise gr.Error(f"Error applying background: {e}")
+
+
+
+ # Gradio interface
+ with gr.Blocks(theme=gr.themes.Ocean()) as demo:
+     gr.Markdown("# Image and Video Background Remover & Changer\n\nRemove or apply background to images and videos.")
+
+     with gr.Tab("Remove Image Background"):
+         with gr.Row():
+             image_input = gr.Image(label="Upload Image", interactive=True)
+             slider = ImageSlider(label="Processed Image", type="pil")
+
+         remove_button = gr.Button("Remove Image Background", interactive=True)
+
+         examples = gr.Examples(
+             [
+                 load_img(
+                     "https://images.rawpixel.com/image_800/cHJpdmF0ZS9sci9pbWFnZXMvd2Vic2l0ZS8yMDIzLTA4L3Jhd3BpeGVsX29mZmljZV8yX3Bob3RvX29mX2FfbGlvbl9pc29sYXRlZF9vbl9jb2xvcl9iYWNrZ3JvdW5kXzJhNzgwMjM1LWRlYTgtNDMyOS04OWVjLTY3ZWMwNjcxZDhiMV8xLmpwZw.jpg",
+                     output_type="pil",
+                 )
+             ],
+             inputs=image_input,
+             fn=remove_bg_fn,
+             outputs=slider,
+             cache_examples=True,
+             cache_mode="eager",
+         )
+
+         remove_button.click(remove_bg_fn, inputs=image_input, outputs=slider)
+
+     with gr.Tab("Apply Background to Image"):
+
+         with gr.Row():
+             image_input = gr.Image(label="Upload Image", interactive=True)
+             slider = ImageSlider(label="Processed Image", type="pil")
+
+         apply_button = gr.Button("Apply Background", interactive=True)
+
+         with gr.Row():
+             bg_type = gr.Radio(
+                 ["Color", "Image"],
+                 label="Background Type",
+                 value="Color",
+                 interactive=True,
+             )
+             color_picker = gr.ColorPicker(
+                 label="Background Color",
+                 value="#00FF00",
+                 visible=True,
+                 interactive=True,
+             )
+             bg_image = gr.Image(
+                 label="Background Image",
+                 type="filepath",
+                 visible=False,
+                 interactive=True,
+             )
+
+         def update_visibility(bg_type):
+             if bg_type == "Color":
+                 return (
+                     gr.update(visible=True),
+                     gr.update(visible=False),
+                 )
+             elif bg_type == "Image":
+                 return (
+                     gr.update(visible=False),
+                     gr.update(visible=True),
+                 )
+             else:
+                 return (
+                     gr.update(visible=False),
+                     gr.update(visible=False),
+                 )
+
+         bg_type.change(
+             update_visibility,
+             inputs=bg_type,
+             outputs=[color_picker, bg_image],
+         )
+
+         examples = gr.Examples(
+             [
+                 ["https://pngimg.com/d/mario_PNG125.png", None, "#0cfa38", "Color"],
+                 [
+                     "https://pngimg.com/d/mario_PNG125.png",
+                     "https://cdn.photoroom.com/v2/image-cache?path=gs://background-7ef44.appspot.com/backgrounds_v3/black/47_-_black.jpg",
+                     None,
+                     "Image",
+                 ],
+             ],
+             inputs=[image_input, bg_image, color_picker, bg_type],
+             fn=apply_bg_image,
+             outputs=slider,
+             cache_examples=True,
+             cache_mode="eager",
+         )
+
+         apply_button.click(
+             apply_bg_image,
+             inputs=[image_input, bg_image, color_picker, bg_type],
+             outputs=slider,
+         )
+
+
+     with gr.Tab("Remove Video Background"):
+         with gr.Row():
+             in_video = gr.Video(label="Input Video", interactive=True)
+             stream_image = gr.Image(label="Streaming Output", visible=False)
+             out_video = gr.Video(label="Final Output Video")
+
+         submit_button = gr.Button("Change Background", interactive=True)
+
+         with gr.Row():
+             fps_slider = gr.Slider(
+                 minimum=0,
+                 maximum=60,
+                 step=1,
+                 value=0,
+                 label="Output FPS (0 will inherit the original fps value)",
+                 interactive=True,
+             )
+             bg_type = gr.Radio(
+                 ["Color", "Image", "Video"],
+                 label="Background Type",
+                 value="Color",
+                 interactive=True,
+             )
+             color_picker = gr.ColorPicker(
+                 label="Background Color",
+                 value="#00FF00",
+                 visible=True,
+                 interactive=True,
+             )
+             bg_image = gr.Image(
+                 label="Background Image",
+                 type="filepath",
+                 visible=False,
+                 interactive=True,
+             )
+             bg_video = gr.Video(
+                 label="Background Video", visible=False, interactive=True
+             )
+
+         with gr.Column(visible=False) as video_handling_options:
+             video_handling_radio = gr.Radio(
+                 ["slow_down", "loop"],
+                 label="Video Handling",
+                 value="slow_down",
+                 interactive=True,
+             )
+
+         fast_mode_checkbox = gr.Checkbox(
+             label="Fast Mode (Use BiRefNet_lite)", value=True, interactive=True
+         )
+         max_workers_slider = gr.Slider(
+             minimum=1,
+             maximum=32,
+             step=1,
+             value=6,
+             label="Max Workers",
+             info="Determines how many frames to process in parallel",
+             interactive=True,
+         )
+
+         time_textbox = gr.Textbox(label="Time Elapsed", interactive=False)
+
+         def update_visibility(bg_type):
+             # One update per output: color_picker, bg_image, bg_video, video_handling_options
+             if bg_type == "Color":
+                 return (
+                     gr.update(visible=True),
+                     gr.update(visible=False),
+                     gr.update(visible=False),
+                     gr.update(visible=False),
+                 )
+             elif bg_type == "Image":
+                 return (
+                     gr.update(visible=False),
+                     gr.update(visible=True),
+                     gr.update(visible=False),
+                     gr.update(visible=False),
+                 )
+             elif bg_type == "Video":
+                 return (
+                     gr.update(visible=False),
+                     gr.update(visible=False),
+                     gr.update(visible=True),
+                     gr.update(visible=True),
+                 )
+             else:
+                 return (
+                     gr.update(visible=False),
+                     gr.update(visible=False),
+                     gr.update(visible=False),
+                     gr.update(visible=False),
+                 )
+
+         bg_type.change(
+             update_visibility,
+             inputs=bg_type,
+             outputs=[color_picker, bg_image, bg_video, video_handling_options],
+         )
+
+         examples = gr.Examples(
+             [
+                 [
+                     "https://www.w3schools.com/html/mov_bbb.mp4",
+                     "Video",
+                     None,
+                     "https://www.w3schools.com/howto/rain.mp4",
+                 ],
+                 [
+                     "https://www.w3schools.com/html/mov_bbb.mp4",
+                     "Image",
+                     "https://cdn.photoroom.com/v2/image-cache?path=gs://background-7ef44.appspot.com/backgrounds_v3/black/47_-_black.jpg",
+                     None,
+                 ],
+                 ["https://www.w3schools.com/html/mov_bbb.mp4", "Color", None, None],
+             ],
+             inputs=[in_video, bg_type, bg_image, bg_video],
+             outputs=[stream_image, out_video, time_textbox],
+             fn=remove_bg_video,
+             cache_examples=True,
+             cache_mode="eager",
+         )
+
+         submit_button.click(
+             remove_bg_video,
+             inputs=[
+                 in_video,
+                 bg_type,
+                 bg_image,
+                 bg_video,
+                 color_picker,
+                 fps_slider,
+                 video_handling_radio,
+                 fast_mode_checkbox,
+                 max_workers_slider,
+             ],
+             outputs=[stream_image, out_video, time_textbox],
+         )
+
+ if __name__ == "__main__":
+     demo.launch(show_error=True)
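
Usage note: the compositing logic in app.py is concentrated in the process() helper, which takes a PIL image plus either a hex colour string, a PIL image, or an image path as the background. The sketch below is hypothetical, assuming the file is saved as app.py in the working directory, an input.jpg exists as a placeholder input, and the BiRefNet checkpoints can be downloaded; importing the module loads both models (and may pre-cache the Gradio examples), so it can take a while.

# Hypothetical standalone use of the process() helper defined in app.py above.
from PIL import Image

import app  # loads BiRefNet and BiRefNet_lite at import time

img = Image.open("input.jpg").convert("RGB")  # "input.jpg" is a placeholder path
# Composite the subject over a solid green background; fast_mode=True selects BiRefNet_lite.
result = app.process(img, "#00FF00", fast_mode=True)
result.convert("RGB").save("output.jpg")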
requirements.txt ADDED
@@ -0,0 +1,19 @@
+ torch
+ accelerate
+ opencv-python
+ spaces
+ pillow
+ numpy
+ timm
+ kornia
+ prettytable
+ typing
+ scikit-image
+ huggingface_hub
+ transformers>=4.39.1
+ gradio
+ schedule
+ loadimg>=0.1.1
+ moviepy==1.0.3
+ pydub
+ gradio_imageslider
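
Setup note: a minimal local-run sketch, assuming these pinned versions resolve in the target Python environment and that a GPU is optional (the code falls back to CPU when CUDA is unavailable):

pip install -r requirements.txt
python app.py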