ghostsInTheMachine committed on
Commit
4316c61
·
verified ·
1 Parent(s): 509862d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +99 -234
app.py CHANGED
@@ -1,21 +1,15 @@
1
- import gradio as gr
 
 
2
  import torch
 
3
  import spaces
 
 
 
4
  from transformers import AutoModelForImageSegmentation
5
  from torchvision import transforms
6
- import moviepy.editor as mp
7
- from PIL import Image
8
- import numpy as np
9
- import tempfile
10
- import time
11
- import os
12
- import shutil
13
- import ffmpeg
14
- from concurrent.futures import ThreadPoolExecutor
15
- from gradio.themes.base import Base
16
- from gradio.themes.utils import colors, fonts
17
 
18
- # Custom Theme Definition
19
  class WhiteTheme(Base):
20
  def __init__(
21
  self,
@@ -65,180 +59,89 @@ class WhiteTheme(Base):
65
  shadow_drop="none"
66
  )
67
 
68
- # Set precision and device
69
- torch.set_float32_matmul_precision("medium")
 
70
  device = "cuda" if torch.cuda.is_available() else "cpu"
71
 
72
- # Load models
73
- print("Loading models...")
74
- birefnet = AutoModelForImageSegmentation.from_pretrained("ZhengPeng7/BiRefNet", trust_remote_code=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
  birefnet.to(device)
76
- birefnet_lite = AutoModelForImageSegmentation.from_pretrained("ZhengPeng7/BiRefNet_lite", trust_remote_code=True)
77
- birefnet_lite.to(device)
78
- print("Models loaded successfully!")
79
-
80
- # Image transformation
81
- transform_image = transforms.Compose([
82
- transforms.Resize((1024, 1024)),
83
- transforms.ToTensor(),
84
- transforms.Normalize([0.485, 0.456, 0.406],
85
- [0.229, 0.224, 0.225]),
86
- ])
87
-
88
- def process_frame(frame, fast_mode=True):
89
- """
90
- Process a single frame through the BiRefNet model.
91
- Maintains original resolution throughout processing.
92
- Returns a PIL Image with alpha channel.
93
- """
94
- try:
95
- # Preserve original resolution for final output
96
- image_ori = Image.fromarray(frame).convert('RGB')
97
- original_size = image_ori.size
98
-
99
- # Transform for model input while maintaining aspect ratio
100
- input_images = transform_image(image_ori).unsqueeze(0).to(device)
101
-
102
- # Select model based on mode
103
- model = birefnet_lite if fast_mode else birefnet
104
-
105
- with torch.no_grad():
106
- preds = model(input_images)[-1].sigmoid().cpu()
107
- pred = preds[0].squeeze()
108
-
109
- # Resize mask back to original resolution
110
- pred_pil = transforms.ToPILImage()(pred)
111
- pred_pil = pred_pil.resize(original_size, Image.BICUBIC)
112
-
113
- # Create foreground with transparency
114
- foreground = image_ori.copy()
115
- foreground.putalpha(pred_pil)
116
-
117
- return foreground
118
- except Exception as e:
119
- print(f"Error processing frame: {e}")
120
- return None
121
-
122
- @spaces.GPU(duration=300) # 5-minute duration for processing
123
- def process_video(video_path, fps=0, fast_mode=True, max_workers=6):
124
- """
125
- Process video to create transparent MOV file using ProRes 4444.
126
- Maintains original resolution and framerate if fps=0.
127
- """
128
- temp_dir = None
129
- try:
130
- start_time = time.time()
131
- video = mp.VideoFileClip(video_path)
132
-
133
- # Use original video FPS if not specified
134
- if fps == 0:
135
- fps = video.fps
136
-
137
- frames = list(video.iter_frames(fps=fps))
138
- total_frames = len(frames)
139
-
140
- print(f"Processing {total_frames} frames at {fps} FPS...")
141
-
142
- # Create temporary directory for PNG sequence
143
- temp_dir = tempfile.mkdtemp()
144
- png_dir = os.path.join(temp_dir, "frames")
145
- os.makedirs(png_dir, exist_ok=True)
146
-
147
- # Prepare to collect processed frames for live preview
148
- processed_frames = []
149
-
150
- # Process frames with parallel execution
151
- with ThreadPoolExecutor(max_workers=max_workers) as executor:
152
- futures = [executor.submit(process_frame, frame, fast_mode) for frame in frames]
153
- for i, future in enumerate(futures):
154
- try:
155
- result = future.result()
156
- if result:
157
- # Save frame as PNG with transparency
158
- frame_path = os.path.join(png_dir, f"frame_{i:06d}.png")
159
- result.save(frame_path, "PNG")
160
-
161
- # Collect processed frames for live preview
162
- processed_frames.append(np.array(result))
163
-
164
- # Update live preview
165
- elapsed_time = time.time() - start_time
166
- yield processed_frames[-1], None, None, None, f"Processing frame {i+1}/{total_frames}... Elapsed time: {elapsed_time:.2f} seconds"
167
-
168
- if (i + 1) % 10 == 0:
169
- print(f"Processed {i+1}/{total_frames} frames")
170
- except Exception as e:
171
- print(f"Error processing frame {i+1}: {e}")
172
-
173
- print("Creating output files...")
174
- # Create permanent output directory
175
- output_dir = os.path.join(os.path.dirname(video_path), "output")
176
- os.makedirs(output_dir, exist_ok=True)
177
-
178
- # Create ZIP file of PNG sequence
179
- zip_filename = f"frames_{int(time.time())}.zip"
180
- zip_path = os.path.join(output_dir, zip_filename)
181
- shutil.make_archive(zip_path[:-4], 'zip', png_dir)
182
-
183
- # Create MOV file with ProRes 4444
184
- print("Creating ProRes 4444 MOV...")
185
- mov_filename = f"video_{int(time.time())}.mov"
186
- mov_path = os.path.join(output_dir, mov_filename)
187
-
188
- try:
189
- # FFmpeg settings for high-quality ProRes 4444
190
- stream = ffmpeg.input(
191
- os.path.join(png_dir, 'frame_%06d.png'),
192
- pattern_type='sequence',
193
- framerate=fps
194
- )
195
-
196
- # ProRes 4444 settings for maximum quality with alpha
197
- stream = ffmpeg.output(
198
- stream,
199
- mov_path,
200
- vcodec='prores_ks', # ProRes codec
201
- pix_fmt='yuva444p10le', # 10-bit 4:4:4:4 pixel format with alpha
202
- profile='4444', # ProRes 4444 profile for alpha support
203
- alpha_bits=16, # Maximum alpha bit depth
204
- qscale=1, # Highest quality setting
205
- vendor='ap10', # Standard ProRes vendor tag
206
- bits_per_mb=8000, # High bitrate for quality
207
- threads=max_workers # Parallel processing
208
- )
209
-
210
- # Run FFmpeg command
211
- ffmpeg.run(stream, overwrite_output=True, capture_stdout=True, capture_stderr=True)
212
- print("MOV video created successfully!")
213
-
214
- except ffmpeg.Error as e:
215
- print(f"Error creating MOV video: {e.stderr.decode() if e.stderr else str(e)}")
216
- mov_path = None
217
-
218
- print("Processing complete!")
219
- # Yield the final outputs
220
- yield None, zip_path, mov_path, None, f"Processing complete! Total time: {time.time() - start_time:.2f} seconds"
221
-
222
- except Exception as e:
223
- print(f"Error: {e}")
224
- yield None, None, None, None, f"Error processing video: {e}"
225
- finally:
226
- # Clean up temporary directory
227
- if temp_dir and os.path.exists(temp_dir):
228
- try:
229
- shutil.rmtree(temp_dir)
230
- except Exception as e:
231
- print(f"Error cleaning up temp directory: {e}")
232
-
233
- @spaces.GPU(duration=300) # Match process_video duration
234
- def process_wrapper(video, fps=0, fast_mode=True, max_workers=6):
235
- if video is None:
236
- raise gr.Error("Please upload a video.")
237
- try:
238
- for outputs in process_video(video, fps, fast_mode, max_workers):
239
- yield outputs
240
- except Exception as e:
241
- raise gr.Error(f"Error processing video: {str(e)}")
242
 
243
  # Custom CSS for styling
244
  custom_css = """
@@ -301,7 +204,6 @@ custom_css = """
301
  }
302
  """
303
 
304
- # Gradio Interface
305
  with gr.Blocks(css=custom_css, theme=WhiteTheme()) as demo:
306
  gr.HTML('''
307
  <div class="title-container">
@@ -311,7 +213,7 @@ with gr.Blocks(css=custom_css, theme=WhiteTheme()) as demo:
311
  </div>
312
  <script>
313
  (function() {
314
- const text = "video";
315
  const typedTextSpan = document.getElementById("typed-text");
316
  let charIndex = 0;
317
 
@@ -328,57 +230,20 @@ with gr.Blocks(css=custom_css, theme=WhiteTheme()) as demo:
328
  </script>
329
  ''')
330
 
 
331
  with gr.Row():
332
  with gr.Column():
333
- video_input = gr.Video(
334
- label="Upload Video",
335
- interactive=True,
336
- show_label=True,
337
- height=360,
338
- width=640
339
- )
340
- with gr.Row():
341
- fps_slider = gr.Slider(
342
- minimum=0,
343
- maximum=60,
344
- step=1,
345
- value=0,
346
- label="Output FPS (0 will inherit the original fps value)",
347
- )
348
- fast_mode_checkbox = gr.Checkbox(
349
- label="Fast Mode (Use BiRefNet_lite)",
350
- value=True
351
- )
352
- max_workers_slider = gr.Slider(
353
- minimum=1,
354
- maximum=32,
355
- step=1,
356
- value=6,
357
- label="Max Workers",
358
- info="Determines how many frames to process in parallel"
359
- )
360
- btn = gr.Button("Process Video", elem_id="submit-button")
361
 
362
  with gr.Column():
363
- preview_image = gr.Image(label="Live Preview", show_label=True)
364
- output_foreground_zip = gr.File(label="Download PNG Sequence (ZIP)")
365
- output_foreground_video = gr.File(label="Download Video (ProRes 4444 MOV with transparency)")
366
- output_background = gr.Video(label="Background (Coming Soon)")
367
- time_textbox = gr.Textbox(label="Status", interactive=False)
368
-
369
- gr.Markdown("""
370
- ### Output Information
371
- - MOV file uses ProRes 4444 codec for professional-grade alpha channel
372
- - Original resolution and framerate are maintained
373
- - PNG sequence provided for maximum compatibility
374
- """)
375
-
376
- btn.click(
377
- fn=process_wrapper,
378
- inputs=[video_input, fps_slider, fast_mode_checkbox, max_workers_slider],
379
- outputs=[preview_image, output_foreground_zip, output_foreground_video,
380
- output_background, time_textbox]
381
- )
382
 
383
- if __name__ == "__main__":
384
  demo.launch(debug=True)
 
1
+ import os
2
+ import cv2
3
+ import numpy as np
4
  import torch
5
+ import gradio as gr
6
  import spaces
7
+ from gradio.themes.base import Base
8
+ from gradio.themes.utils import colors, fonts, sizes
9
+ from PIL import Image, ImageOps
10
  from transformers import AutoModelForImageSegmentation
11
  from torchvision import transforms
 
 
 
 
 
 
 
 
 
 
 
12
 
 
13
  class WhiteTheme(Base):
14
  def __init__(
15
  self,
 
59
  shadow_drop="none"
60
  )
61
 
62
# Select the 'high' float32 matmul precision mode for PyTorch.
torch.set_float32_matmul_precision('high')
# Replace torch.jit.script with an identity function so that anything passed
# to it is returned unchanged instead of being compiled.
# NOTE(review): presumably a workaround for TorchScript usage inside the
# remotely loaded model code — confirm before removing.
torch.jit.script = lambda f: f

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
66
 
67
def refine_foreground(image, mask, r=90):
    """Estimate the foreground colours of ``image`` under ``mask``.

    The mask is resized to the image size when the two differ. Both inputs
    are converted to arrays and divided by 255 before being handed to
    ``FB_blur_fusion_foreground_estimator_2``.
    NOTE(review): the /255 scaling assumes 8-bit image and mask data —
    confirm against callers.

    Args:
        image: PIL image to refine.
        mask: PIL image used as the alpha matte.
        r: Blur radius forwarded to the estimator.

    Returns:
        A PIL image built from the estimated foreground, rescaled by 255
        and cast to ``uint8``.
    """
    matte = mask if mask.size == image.size else mask.resize(image.size)
    image_scaled = np.array(image) / 255.0
    matte_scaled = np.array(matte) / 255.0
    foreground = FB_blur_fusion_foreground_estimator_2(
        image_scaled, matte_scaled, r=r)
    return Image.fromarray((foreground * 255.0).astype(np.uint8))
75
+
76
def FB_blur_fusion_foreground_estimator_2(image, alpha, r=90):
    """Run the blur-fusion foreground estimator twice and return the result.

    The first pass uses the image itself as both the foreground and the
    background guess with blur size ``r``. The second pass reuses the
    foreground and blurred background from the first pass with a fixed
    blur size of 6.

    Args:
        image: Image array passed through to the estimator.
        alpha: Two-dimensional alpha array; a trailing axis is appended
            before use.
        r: Blur size for the first pass.

    Returns:
        The foreground estimate produced by the second pass.
    """
    alpha_3d = alpha[:, :, None]
    coarse_fg, coarse_bg = FB_blur_fusion_foreground_estimator(
        image, image, image, alpha_3d, r)
    fine_fg, _ = FB_blur_fusion_foreground_estimator(
        image, coarse_fg, coarse_bg, alpha_3d, r=6)
    return fine_fg
81
+
82
def FB_blur_fusion_foreground_estimator(image, F, B, alpha, r=90):
    """Perform one blur-fusion step of foreground estimation.

    Alpha-weighted box blurs of the current foreground ``F`` and background
    ``B`` are normalised by the blurred alpha (with a 1e-5 term added to each
    denominator), then the foreground is corrected by the alpha-scaled
    difference between ``image`` and the blurred composite.
    NOTE(review): the ``[:, :, None]`` applied to the blurred alpha suggests
    ``alpha`` is expected as H x W x 1 — confirm.

    Args:
        image: Image array, or a PIL image (converted to an array and
            divided by 255).
        F: Current foreground estimate.
        B: Current background estimate.
        alpha: Alpha matte array.
        r: Side length of the square box-blur kernel.

    Returns:
        Tuple ``(F, blurred_B)``: the updated foreground clipped to
        ``[0, 1]`` and the blurred background estimate.
    """
    if isinstance(image, Image.Image):
        image = np.array(image) / 255.0

    kernel = (r, r)
    alpha_blur = cv2.blur(alpha, kernel)[:, :, None]

    fg_weighted_blur = cv2.blur(F * alpha, kernel)
    fg_blur = fg_weighted_blur / (alpha_blur + 1e-5)

    bg_weighted_blur = cv2.blur(B * (1 - alpha), kernel)
    bg_blur = bg_weighted_blur / ((1 - alpha_blur) + 1e-5)

    # Correct the blurred foreground with the alpha-scaled compositing residual.
    refined = fg_blur + alpha * (image - alpha * fg_blur - (1 - alpha) * bg_blur)
    return np.clip(refined, 0, 1), bg_blur
93
+
94
class ImagePreprocessor:
    """Resize, tensorise and normalise a PIL image for the model."""

    def __init__(self, resolution=(1024, 1024)) -> None:
        """Build the transform pipeline.

        Args:
            resolution: Target size handed to ``transforms.Resize``.
        """
        pipeline = [
            transforms.Resize(resolution),
            transforms.ToTensor(),
            # Per-channel mean and std used for normalisation.
            transforms.Normalize([0.485, 0.456, 0.406],
                                 [0.229, 0.224, 0.225]),
        ]
        self.transform_image = transforms.Compose(pipeline)

    def proc(self, image: Image.Image) -> torch.Tensor:
        """Apply the transform pipeline to ``image`` and return the result."""
        return self.transform_image(image)
106
+
107
# Load the model
# Fetches the 'zhengpeng7/BiRefNet-matting' checkpoint; trust_remote_code=True
# allows the repository's own model code to run while loading.
birefnet = AutoModelForImageSegmentation.from_pretrained(
    'zhengpeng7/BiRefNet-matting', trust_remote_code=True)
# Move the model to the device selected above.
birefnet.to(device)
# Put the model in evaluation mode.
birefnet.eval()
112
+
113
def remove_background_wrapper(image):
    """Validate the uploaded image and run background removal on it.

    Args:
        image: Array from the image input component, or ``None`` when
            nothing was uploaded.

    Returns:
        Tuple ``(foreground, background, pred_pil, reverse_mask)`` as
        returned by ``remove_background``.

    Raises:
        gr.Error: If ``image`` is ``None``.
    """
    if image is None:
        raise gr.Error("Please upload an image.")

    rgb_image = Image.fromarray(image).convert('RGB')
    fg, bg, fg_mask, bg_mask = remove_background(rgb_image)
    return fg, bg, fg_mask, bg_mask
119
+
120
@spaces.GPU
def remove_background(image_ori):
    """Split an image into foreground and background using the BiRefNet mask.

    The image is preprocessed at 1024x1024, run through ``birefnet`` without
    gradient tracking, and the sigmoid of the last model output is used as
    the mask. The mask is resized back to the input size and applied as the
    alpha channel of a copy of the input; its inverse is applied to a second
    copy.

    Args:
        image_ori: PIL image to process.

    Returns:
        Tuple ``(foreground, background, pred_pil, reverse_mask)``: the two
        copies of the input with alpha channels set, the mask, and the
        inverted mask.
    """
    target_size = image_ori.size
    preprocessor = ImagePreprocessor(resolution=(1024, 1024))
    batch = preprocessor.proc(image_ori).unsqueeze(0)

    with torch.no_grad():
        outputs = birefnet(batch.to(device))
        probabilities = outputs[-1].sigmoid().cpu()
    matte = probabilities[0].squeeze()

    # Convert the mask to a PIL image at the original size.
    to_pil = transforms.ToPILImage()
    pred_pil = to_pil(matte).resize(target_size, Image.BICUBIC)

    reverse_mask = ImageOps.invert(pred_pil)

    foreground = image_ori.copy()
    foreground.putalpha(pred_pil)

    background = image_ori.copy()
    background.putalpha(reverse_mask)

    torch.cuda.empty_cache()

    return foreground, background, pred_pil, reverse_mask
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
145
 
146
  # Custom CSS for styling
147
  custom_css = """
 
204
  }
205
  """
206
 
 
207
  with gr.Blocks(css=custom_css, theme=WhiteTheme()) as demo:
208
  gr.HTML('''
209
  <div class="title-container">
 
213
  </div>
214
  <script>
215
  (function() {
216
+ const text = "image";
217
  const typedTextSpan = document.getElementById("typed-text");
218
  let charIndex = 0;
219
 
 
230
  </script>
231
  ''')
232
 
233
+ # Interface setup with input and output
234
  with gr.Row():
235
  with gr.Column():
236
+ image_input = gr.Image(type="numpy", sources=['upload'], label="Upload Image")
237
+ btn = gr.Button("Process Image", elem_id="submit-button")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
238
 
239
  with gr.Column():
240
+ output_foreground = gr.Image(type="pil", label="Foreground")
241
+ output_background = gr.Image(type="pil", label="Background")
242
+ output_foreground_mask = gr.Image(type="pil", label="Foreground Mask")
243
+ output_background_mask = gr.Image(type="pil", label="Background Mask")
244
+
245
+ # Link the button to the processing function
246
+ btn.click(fn=remove_background_wrapper, inputs=image_input, outputs=[
247
+ output_foreground, output_background, output_foreground_mask, output_background_mask])
 
 
 
 
 
 
 
 
 
 
 
248
 
 
249
  demo.launch(debug=True)