ghostsInTheMachine committed
Commit 509862d · verified · 1 Parent(s): 012b840

Update app.py

Files changed (1)
app.py  +234  -99
app.py CHANGED
@@ -1,15 +1,21 @@
- import os
- import cv2
- import numpy as np
- import torch
  import gradio as gr
  import spaces
- from gradio.themes.base import Base
- from gradio.themes.utils import colors, fonts, sizes
- from PIL import Image, ImageOps
  from transformers import AutoModelForImageSegmentation
  from torchvision import transforms

  class WhiteTheme(Base):
      def __init__(
          self,
@@ -59,89 +65,180 @@ class WhiteTheme(Base):
              shadow_drop="none"
          )

- torch.set_float32_matmul_precision('high')
- torch.jit.script = lambda f: f
-
  device = "cuda" if torch.cuda.is_available() else "cpu"

- def refine_foreground(image, mask, r=90):
-     if mask.size != image.size:
-         mask = mask.resize(image.size)
-     image = np.array(image) / 255.0
-     mask = np.array(mask) / 255.0
-     estimated_foreground = FB_blur_fusion_foreground_estimator_2(image, mask, r=r)
-     image_masked = Image.fromarray((estimated_foreground * 255.0).astype(np.uint8))
-     return image_masked
-
- def FB_blur_fusion_foreground_estimator_2(image, alpha, r=90):
-     alpha = alpha[:, :, None]
-     F, blur_B = FB_blur_fusion_foreground_estimator(
-         image, image, image, alpha, r)
-     return FB_blur_fusion_foreground_estimator(image, F, blur_B, alpha, r=6)[0]
-
- def FB_blur_fusion_foreground_estimator(image, F, B, alpha, r=90):
-     if isinstance(image, Image.Image):
-         image = np.array(image) / 255.0
-     blurred_alpha = cv2.blur(alpha, (r, r))[:, :, None]
-     blurred_FA = cv2.blur(F * alpha, (r, r))
-     blurred_F = blurred_FA / (blurred_alpha + 1e-5)
-     blurred_B1A = cv2.blur(B * (1 - alpha), (r, r))
-     blurred_B = blurred_B1A / ((1 - blurred_alpha) + 1e-5)
-     F = blurred_F + alpha * (image - alpha * blurred_F - (1 - alpha) * blurred_B)
-     F = np.clip(F, 0, 1)
-     return F, blurred_B
-
- class ImagePreprocessor():
-     def __init__(self, resolution=(1024, 1024)) -> None:
-         self.transform_image = transforms.Compose([
-             transforms.Resize(resolution),
-             transforms.ToTensor(),
-             transforms.Normalize([0.485, 0.456, 0.406],
-                                  [0.229, 0.224, 0.225]),
-         ])
-
-     def proc(self, image: Image.Image) -> torch.Tensor:
-         image = self.transform_image(image)
-         return image
-
- # Load the model
- birefnet = AutoModelForImageSegmentation.from_pretrained(
-     'zhengpeng7/BiRefNet-matting', trust_remote_code=True)
  birefnet.to(device)
- birefnet.eval()
-
- def remove_background_wrapper(image):
-     if image is None:
-         raise gr.Error("Please upload an image.")
-     image_ori = Image.fromarray(image).convert('RGB')
-     foreground, background, pred_pil, reverse_mask = remove_background(image_ori)
-     return foreground, background, pred_pil, reverse_mask
-
- @spaces.GPU
- def remove_background(image_ori):
-     original_size = image_ori.size
-     image_preprocessor = ImagePreprocessor(resolution=(1024, 1024))
-     image_proc = image_preprocessor.proc(image_ori)
-     image_proc = image_proc.unsqueeze(0)
-
-     with torch.no_grad():
-         preds = birefnet(image_proc.to(device))[-1].sigmoid().cpu()
-         pred = preds[0].squeeze()
-
-     pred_pil = transforms.ToPILImage()(pred)
-     pred_pil = pred_pil.resize(original_size, Image.BICUBIC)
-
-     reverse_mask = ImageOps.invert(pred_pil)
-
-     foreground = image_ori.copy()
-     foreground.putalpha(pred_pil)
-
-     background = image_ori.copy()
-     background.putalpha(reverse_mask)
-
-     torch.cuda.empty_cache()
-
-     return foreground, background, pred_pil, reverse_mask

  # Custom CSS for styling
  custom_css = """
@@ -204,6 +301,7 @@ custom_css = """
  }
  """

  with gr.Blocks(css=custom_css, theme=WhiteTheme()) as demo:
      gr.HTML('''
      <div class="title-container">
@@ -213,7 +311,7 @@ with gr.Blocks(css=custom_css, theme=WhiteTheme()) as demo:
      </div>
      <script>
      (function() {
-         const text = "image";
          const typedTextSpan = document.getElementById("typed-text");
          let charIndex = 0;

@@ -230,20 +328,57 @@ with gr.Blocks(css=custom_css, theme=WhiteTheme()) as demo:
      </script>
      ''')

-     # Interface setup with input and output
      with gr.Row():
          with gr.Column():
-             image_input = gr.Image(type="numpy", sources=['upload'], label="Upload Image")
-             btn = gr.Button("Process Image", elem_id="submit-button")

          with gr.Column():
-             output_foreground = gr.Image(type="pil", label="Foreground")
-             output_background = gr.Image(type="pil", label="Background")
-             output_foreground_mask = gr.Image(type="pil", label="Foreground Mask")
-             output_background_mask = gr.Image(type="pil", label="Background Mask")
-
-     # Link the button to the processing function
-     btn.click(fn=remove_background_wrapper, inputs=image_input, outputs=[
-         output_foreground, output_background, output_foreground_mask, output_background_mask])

  demo.launch(debug=True)
 
  import gradio as gr
+ import torch
  import spaces
  from transformers import AutoModelForImageSegmentation
  from torchvision import transforms
+ import moviepy.editor as mp
+ from PIL import Image
+ import numpy as np
+ import tempfile
+ import time
+ import os
+ import shutil
+ import ffmpeg
+ from concurrent.futures import ThreadPoolExecutor
+ from gradio.themes.base import Base
+ from gradio.themes.utils import colors, fonts

+ # Custom Theme Definition
  class WhiteTheme(Base):
      def __init__(
          self,
              shadow_drop="none"
          )

+ # Set precision and device
+ torch.set_float32_matmul_precision("medium")
  device = "cuda" if torch.cuda.is_available() else "cpu"

+ # Load models
+ print("Loading models...")
+ birefnet = AutoModelForImageSegmentation.from_pretrained("ZhengPeng7/BiRefNet", trust_remote_code=True)
  birefnet.to(device)
+ birefnet_lite = AutoModelForImageSegmentation.from_pretrained("ZhengPeng7/BiRefNet_lite", trust_remote_code=True)
+ birefnet_lite.to(device)
+ print("Models loaded successfully!")
+
+ # Image transformation
+ transform_image = transforms.Compose([
+     transforms.Resize((1024, 1024)),
+     transforms.ToTensor(),
+     transforms.Normalize([0.485, 0.456, 0.406],
+                          [0.229, 0.224, 0.225]),
+ ])
+
+ def process_frame(frame, fast_mode=True):
+     """
+     Process a single frame through the BiRefNet model.
+     Maintains original resolution throughout processing.
+     Returns a PIL Image with alpha channel.
+     """
+     try:
+         # Preserve original resolution for final output
+         image_ori = Image.fromarray(frame).convert('RGB')
+         original_size = image_ori.size
+
+         # Transform for model input while maintaining aspect ratio
+         input_images = transform_image(image_ori).unsqueeze(0).to(device)
+
+         # Select model based on mode
+         model = birefnet_lite if fast_mode else birefnet
+
+         with torch.no_grad():
+             preds = model(input_images)[-1].sigmoid().cpu()
+             pred = preds[0].squeeze()
+
+         # Resize mask back to original resolution
+         pred_pil = transforms.ToPILImage()(pred)
+         pred_pil = pred_pil.resize(original_size, Image.BICUBIC)
+
+         # Create foreground with transparency
+         foreground = image_ori.copy()
+         foreground.putalpha(pred_pil)
+
+         return foreground
+     except Exception as e:
+         print(f"Error processing frame: {e}")
+         return None
+
+ @spaces.GPU(duration=300)  # 5-minute duration for processing
+ def process_video(video_path, fps=0, fast_mode=True, max_workers=6):
+     """
+     Process video to create transparent MOV file using ProRes 4444.
+     Maintains original resolution and framerate if fps=0.
+     """
+     temp_dir = None
+     try:
+         start_time = time.time()
+         video = mp.VideoFileClip(video_path)
+
+         # Use original video FPS if not specified
+         if fps == 0:
+             fps = video.fps
+
+         frames = list(video.iter_frames(fps=fps))
+         total_frames = len(frames)
+
+         print(f"Processing {total_frames} frames at {fps} FPS...")
+
+         # Create temporary directory for PNG sequence
+         temp_dir = tempfile.mkdtemp()
+         png_dir = os.path.join(temp_dir, "frames")
+         os.makedirs(png_dir, exist_ok=True)
+
+         # Prepare to collect processed frames for live preview
+         processed_frames = []
+
+         # Process frames with parallel execution
+         with ThreadPoolExecutor(max_workers=max_workers) as executor:
+             futures = [executor.submit(process_frame, frame, fast_mode) for frame in frames]
+             for i, future in enumerate(futures):
+                 try:
+                     result = future.result()
+                     if result:
+                         # Save frame as PNG with transparency
+                         frame_path = os.path.join(png_dir, f"frame_{i:06d}.png")
+                         result.save(frame_path, "PNG")
+
+                         # Collect processed frames for live preview
+                         processed_frames.append(np.array(result))
+
+                         # Update live preview
+                         elapsed_time = time.time() - start_time
+                         yield processed_frames[-1], None, None, None, f"Processing frame {i+1}/{total_frames}... Elapsed time: {elapsed_time:.2f} seconds"
+
+                     if (i + 1) % 10 == 0:
+                         print(f"Processed {i+1}/{total_frames} frames")
+                 except Exception as e:
+                     print(f"Error processing frame {i+1}: {e}")
+
+         print("Creating output files...")
+         # Create permanent output directory
+         output_dir = os.path.join(os.path.dirname(video_path), "output")
+         os.makedirs(output_dir, exist_ok=True)
+
+         # Create ZIP file of PNG sequence
+         zip_filename = f"frames_{int(time.time())}.zip"
+         zip_path = os.path.join(output_dir, zip_filename)
+         shutil.make_archive(zip_path[:-4], 'zip', png_dir)
+
+         # Create MOV file with ProRes 4444
+         print("Creating ProRes 4444 MOV...")
+         mov_filename = f"video_{int(time.time())}.mov"
+         mov_path = os.path.join(output_dir, mov_filename)
+
+         try:
+             # FFmpeg settings for high-quality ProRes 4444
+             stream = ffmpeg.input(
+                 os.path.join(png_dir, 'frame_%06d.png'),
+                 pattern_type='sequence',
+                 framerate=fps
+             )
+
+             # ProRes 4444 settings for maximum quality with alpha
+             stream = ffmpeg.output(
+                 stream,
+                 mov_path,
+                 vcodec='prores_ks',       # ProRes codec
+                 pix_fmt='yuva444p10le',   # 10-bit 4:4:4:4 pixel format with alpha
+                 profile='4444',           # ProRes 4444 profile for alpha support
+                 alpha_bits=16,            # Maximum alpha bit depth
+                 qscale=1,                 # Highest quality setting
+                 vendor='ap10',            # Standard ProRes vendor tag
+                 bits_per_mb=8000,         # High bitrate for quality
+                 threads=max_workers       # Parallel processing
+             )
+
+             # Run FFmpeg command
+             ffmpeg.run(stream, overwrite_output=True, capture_stdout=True, capture_stderr=True)
+             print("MOV video created successfully!")
+
+         except ffmpeg.Error as e:
+             print(f"Error creating MOV video: {e.stderr.decode() if e.stderr else str(e)}")
+             mov_path = None
+
+         print("Processing complete!")
+         # Yield the final outputs
+         yield None, zip_path, mov_path, None, f"Processing complete! Total time: {time.time() - start_time:.2f} seconds"
+
+     except Exception as e:
+         print(f"Error: {e}")
+         yield None, None, None, None, f"Error processing video: {e}"
+     finally:
+         # Clean up temporary directory
+         if temp_dir and os.path.exists(temp_dir):
+             try:
+                 shutil.rmtree(temp_dir)
+             except Exception as e:
+                 print(f"Error cleaning up temp directory: {e}")
+
+ @spaces.GPU(duration=300)  # Match process_video duration
+ def process_wrapper(video, fps=0, fast_mode=True, max_workers=6):
+     if video is None:
+         raise gr.Error("Please upload a video.")
+     try:
+         for outputs in process_video(video, fps, fast_mode, max_workers):
+             yield outputs
+     except Exception as e:
+         raise gr.Error(f"Error processing video: {str(e)}")

  # Custom CSS for styling
  custom_css = """
 
  }
  """

+ # Gradio Interface
  with gr.Blocks(css=custom_css, theme=WhiteTheme()) as demo:
      gr.HTML('''
      <div class="title-container">

      </div>
      <script>
      (function() {
+         const text = "video";
          const typedTextSpan = document.getElementById("typed-text");
          let charIndex = 0;

      </script>
      ''')

 
331
  with gr.Row():
332
  with gr.Column():
333
+ video_input = gr.Video(
334
+ label="Upload Video",
335
+ interactive=True,
336
+ show_label=True,
337
+ height=360,
338
+ width=640
339
+ )
340
+ with gr.Row():
341
+ fps_slider = gr.Slider(
342
+ minimum=0,
343
+ maximum=60,
344
+ step=1,
345
+ value=0,
346
+ label="Output FPS (0 will inherit the original fps value)",
347
+ )
348
+ fast_mode_checkbox = gr.Checkbox(
349
+ label="Fast Mode (Use BiRefNet_lite)",
350
+ value=True
351
+ )
352
+ max_workers_slider = gr.Slider(
353
+ minimum=1,
354
+ maximum=32,
355
+ step=1,
356
+ value=6,
357
+ label="Max Workers",
358
+ info="Determines how many frames to process in parallel"
359
+ )
360
+ btn = gr.Button("Process Video", elem_id="submit-button")
361
 
362
  with gr.Column():
363
+ preview_image = gr.Image(label="Live Preview", show_label=True)
364
+ output_foreground_zip = gr.File(label="Download PNG Sequence (ZIP)")
365
+ output_foreground_video = gr.File(label="Download Video (ProRes 4444 MOV with transparency)")
366
+ output_background = gr.Video(label="Background (Coming Soon)")
367
+ time_textbox = gr.Textbox(label="Status", interactive=False)
368
+
369
+ gr.Markdown("""
370
+ ### Output Information
371
+ - MOV file uses ProRes 4444 codec for professional-grade alpha channel
372
+ - Original resolution and framerate are maintained
373
+ - PNG sequence provided for maximum compatibility
374
+ """)
375
+
376
+ btn.click(
377
+ fn=process_wrapper,
378
+ inputs=[video_input, fps_slider, fast_mode_checkbox, max_workers_slider],
379
+ outputs=[preview_image, output_foreground_zip, output_foreground_video,
380
+ output_background, time_textbox]
381
+ )
382
 
383
+ if __name__ == "__main__":
384
  demo.launch(debug=True)
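
For reference, a minimal sketch (not part of the commit) of driving the new process_video generator directly, without the Gradio UI. It assumes app.py is importable from the working directory, that the spaces.GPU decorator degrades to a no-op outside a GPU Space, and that "input.mp4" is a placeholder path; the five-item tuples mirror the yields in the diff (preview frame, PNG-sequence ZIP, ProRes 4444 MOV, background placeholder, status text).

# Hypothetical driver script; importing app also loads both BiRefNet models.
from app import process_video

zip_path = mov_path = None
for preview, zip_out, mov_out, _background, status in process_video(
        "input.mp4", fps=0, fast_mode=True, max_workers=4):
    print(status)                    # per-frame progress, then the completion message
    zip_path = zip_out or zip_path   # set only on the final yield
    mov_path = mov_out or mov_path   # set only on the final yield

print("PNG sequence ZIP:", zip_path)
print("ProRes 4444 MOV:", mov_path)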