Spaces:
Sleeping
Sleeping
ghostsInTheMachine
commited on
Update app.py
Browse files
app.py
CHANGED
@@ -1,21 +1,15 @@
|
|
1 |
-
import
|
|
|
|
|
2 |
import torch
|
|
|
3 |
import spaces
|
|
|
|
|
|
|
4 |
from transformers import AutoModelForImageSegmentation
|
5 |
from torchvision import transforms
|
6 |
-
import moviepy.editor as mp
|
7 |
-
from PIL import Image
|
8 |
-
import numpy as np
|
9 |
-
import tempfile
|
10 |
-
import time
|
11 |
-
import os
|
12 |
-
import shutil
|
13 |
-
import ffmpeg
|
14 |
-
from concurrent.futures import ThreadPoolExecutor
|
15 |
-
from gradio.themes.base import Base
|
16 |
-
from gradio.themes.utils import colors, fonts
|
17 |
|
18 |
-
# Custom Theme Definition
|
19 |
class WhiteTheme(Base):
|
20 |
def __init__(
|
21 |
self,
|
@@ -65,180 +59,89 @@ class WhiteTheme(Base):
|
|
65 |
shadow_drop="none"
|
66 |
)
|
67 |
|
68 |
-
|
69 |
-
torch.
|
|
|
70 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
71 |
|
72 |
-
|
73 |
-
|
74 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
75 |
birefnet.to(device)
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
pred_pil = transforms.ToPILImage()(pred)
|
111 |
-
pred_pil = pred_pil.resize(original_size, Image.BICUBIC)
|
112 |
-
|
113 |
-
# Create foreground with transparency
|
114 |
-
foreground = image_ori.copy()
|
115 |
-
foreground.putalpha(pred_pil)
|
116 |
-
|
117 |
-
return foreground
|
118 |
-
except Exception as e:
|
119 |
-
print(f"Error processing frame: {e}")
|
120 |
-
return None
|
121 |
-
|
122 |
-
@spaces.GPU(duration=300) # 5-minute duration for processing
|
123 |
-
def process_video(video_path, fps=0, fast_mode=True, max_workers=6):
|
124 |
-
"""
|
125 |
-
Process video to create transparent MOV file using ProRes 4444.
|
126 |
-
Maintains original resolution and framerate if fps=0.
|
127 |
-
"""
|
128 |
-
temp_dir = None
|
129 |
-
try:
|
130 |
-
start_time = time.time()
|
131 |
-
video = mp.VideoFileClip(video_path)
|
132 |
-
|
133 |
-
# Use original video FPS if not specified
|
134 |
-
if fps == 0:
|
135 |
-
fps = video.fps
|
136 |
-
|
137 |
-
frames = list(video.iter_frames(fps=fps))
|
138 |
-
total_frames = len(frames)
|
139 |
-
|
140 |
-
print(f"Processing {total_frames} frames at {fps} FPS...")
|
141 |
-
|
142 |
-
# Create temporary directory for PNG sequence
|
143 |
-
temp_dir = tempfile.mkdtemp()
|
144 |
-
png_dir = os.path.join(temp_dir, "frames")
|
145 |
-
os.makedirs(png_dir, exist_ok=True)
|
146 |
-
|
147 |
-
# Prepare to collect processed frames for live preview
|
148 |
-
processed_frames = []
|
149 |
-
|
150 |
-
# Process frames with parallel execution
|
151 |
-
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
152 |
-
futures = [executor.submit(process_frame, frame, fast_mode) for frame in frames]
|
153 |
-
for i, future in enumerate(futures):
|
154 |
-
try:
|
155 |
-
result = future.result()
|
156 |
-
if result:
|
157 |
-
# Save frame as PNG with transparency
|
158 |
-
frame_path = os.path.join(png_dir, f"frame_{i:06d}.png")
|
159 |
-
result.save(frame_path, "PNG")
|
160 |
-
|
161 |
-
# Collect processed frames for live preview
|
162 |
-
processed_frames.append(np.array(result))
|
163 |
-
|
164 |
-
# Update live preview
|
165 |
-
elapsed_time = time.time() - start_time
|
166 |
-
yield processed_frames[-1], None, None, None, f"Processing frame {i+1}/{total_frames}... Elapsed time: {elapsed_time:.2f} seconds"
|
167 |
-
|
168 |
-
if (i + 1) % 10 == 0:
|
169 |
-
print(f"Processed {i+1}/{total_frames} frames")
|
170 |
-
except Exception as e:
|
171 |
-
print(f"Error processing frame {i+1}: {e}")
|
172 |
-
|
173 |
-
print("Creating output files...")
|
174 |
-
# Create permanent output directory
|
175 |
-
output_dir = os.path.join(os.path.dirname(video_path), "output")
|
176 |
-
os.makedirs(output_dir, exist_ok=True)
|
177 |
-
|
178 |
-
# Create ZIP file of PNG sequence
|
179 |
-
zip_filename = f"frames_{int(time.time())}.zip"
|
180 |
-
zip_path = os.path.join(output_dir, zip_filename)
|
181 |
-
shutil.make_archive(zip_path[:-4], 'zip', png_dir)
|
182 |
-
|
183 |
-
# Create MOV file with ProRes 4444
|
184 |
-
print("Creating ProRes 4444 MOV...")
|
185 |
-
mov_filename = f"video_{int(time.time())}.mov"
|
186 |
-
mov_path = os.path.join(output_dir, mov_filename)
|
187 |
-
|
188 |
-
try:
|
189 |
-
# FFmpeg settings for high-quality ProRes 4444
|
190 |
-
stream = ffmpeg.input(
|
191 |
-
os.path.join(png_dir, 'frame_%06d.png'),
|
192 |
-
pattern_type='sequence',
|
193 |
-
framerate=fps
|
194 |
-
)
|
195 |
-
|
196 |
-
# ProRes 4444 settings for maximum quality with alpha
|
197 |
-
stream = ffmpeg.output(
|
198 |
-
stream,
|
199 |
-
mov_path,
|
200 |
-
vcodec='prores_ks', # ProRes codec
|
201 |
-
pix_fmt='yuva444p10le', # 10-bit 4:4:4:4 pixel format with alpha
|
202 |
-
profile='4444', # ProRes 4444 profile for alpha support
|
203 |
-
alpha_bits=16, # Maximum alpha bit depth
|
204 |
-
qscale=1, # Highest quality setting
|
205 |
-
vendor='ap10', # Standard ProRes vendor tag
|
206 |
-
bits_per_mb=8000, # High bitrate for quality
|
207 |
-
threads=max_workers # Parallel processing
|
208 |
-
)
|
209 |
-
|
210 |
-
# Run FFmpeg command
|
211 |
-
ffmpeg.run(stream, overwrite_output=True, capture_stdout=True, capture_stderr=True)
|
212 |
-
print("MOV video created successfully!")
|
213 |
-
|
214 |
-
except ffmpeg.Error as e:
|
215 |
-
print(f"Error creating MOV video: {e.stderr.decode() if e.stderr else str(e)}")
|
216 |
-
mov_path = None
|
217 |
-
|
218 |
-
print("Processing complete!")
|
219 |
-
# Yield the final outputs
|
220 |
-
yield None, zip_path, mov_path, None, f"Processing complete! Total time: {time.time() - start_time:.2f} seconds"
|
221 |
-
|
222 |
-
except Exception as e:
|
223 |
-
print(f"Error: {e}")
|
224 |
-
yield None, None, None, None, f"Error processing video: {e}"
|
225 |
-
finally:
|
226 |
-
# Clean up temporary directory
|
227 |
-
if temp_dir and os.path.exists(temp_dir):
|
228 |
-
try:
|
229 |
-
shutil.rmtree(temp_dir)
|
230 |
-
except Exception as e:
|
231 |
-
print(f"Error cleaning up temp directory: {e}")
|
232 |
-
|
233 |
-
@spaces.GPU(duration=300) # Match process_video duration
|
234 |
-
def process_wrapper(video, fps=0, fast_mode=True, max_workers=6):
|
235 |
-
if video is None:
|
236 |
-
raise gr.Error("Please upload a video.")
|
237 |
-
try:
|
238 |
-
for outputs in process_video(video, fps, fast_mode, max_workers):
|
239 |
-
yield outputs
|
240 |
-
except Exception as e:
|
241 |
-
raise gr.Error(f"Error processing video: {str(e)}")
|
242 |
|
243 |
# Custom CSS for styling
|
244 |
custom_css = """
|
@@ -301,7 +204,6 @@ custom_css = """
|
|
301 |
}
|
302 |
"""
|
303 |
|
304 |
-
# Gradio Interface
|
305 |
with gr.Blocks(css=custom_css, theme=WhiteTheme()) as demo:
|
306 |
gr.HTML('''
|
307 |
<div class="title-container">
|
@@ -311,7 +213,7 @@ with gr.Blocks(css=custom_css, theme=WhiteTheme()) as demo:
|
|
311 |
</div>
|
312 |
<script>
|
313 |
(function() {
|
314 |
-
const text = "
|
315 |
const typedTextSpan = document.getElementById("typed-text");
|
316 |
let charIndex = 0;
|
317 |
|
@@ -328,57 +230,20 @@ with gr.Blocks(css=custom_css, theme=WhiteTheme()) as demo:
|
|
328 |
</script>
|
329 |
''')
|
330 |
|
|
|
331 |
with gr.Row():
|
332 |
with gr.Column():
|
333 |
-
|
334 |
-
|
335 |
-
interactive=True,
|
336 |
-
show_label=True,
|
337 |
-
height=360,
|
338 |
-
width=640
|
339 |
-
)
|
340 |
-
with gr.Row():
|
341 |
-
fps_slider = gr.Slider(
|
342 |
-
minimum=0,
|
343 |
-
maximum=60,
|
344 |
-
step=1,
|
345 |
-
value=0,
|
346 |
-
label="Output FPS (0 will inherit the original fps value)",
|
347 |
-
)
|
348 |
-
fast_mode_checkbox = gr.Checkbox(
|
349 |
-
label="Fast Mode (Use BiRefNet_lite)",
|
350 |
-
value=True
|
351 |
-
)
|
352 |
-
max_workers_slider = gr.Slider(
|
353 |
-
minimum=1,
|
354 |
-
maximum=32,
|
355 |
-
step=1,
|
356 |
-
value=6,
|
357 |
-
label="Max Workers",
|
358 |
-
info="Determines how many frames to process in parallel"
|
359 |
-
)
|
360 |
-
btn = gr.Button("Process Video", elem_id="submit-button")
|
361 |
|
362 |
with gr.Column():
|
363 |
-
|
364 |
-
|
365 |
-
|
366 |
-
|
367 |
-
|
368 |
-
|
369 |
-
|
370 |
-
|
371 |
-
- MOV file uses ProRes 4444 codec for professional-grade alpha channel
|
372 |
-
- Original resolution and framerate are maintained
|
373 |
-
- PNG sequence provided for maximum compatibility
|
374 |
-
""")
|
375 |
-
|
376 |
-
btn.click(
|
377 |
-
fn=process_wrapper,
|
378 |
-
inputs=[video_input, fps_slider, fast_mode_checkbox, max_workers_slider],
|
379 |
-
outputs=[preview_image, output_foreground_zip, output_foreground_video,
|
380 |
-
output_background, time_textbox]
|
381 |
-
)
|
382 |
|
383 |
-
if __name__ == "__main__":
|
384 |
demo.launch(debug=True)
|
|
|
1 |
+
import os
|
2 |
+
import cv2
|
3 |
+
import numpy as np
|
4 |
import torch
|
5 |
+
import gradio as gr
|
6 |
import spaces
|
7 |
+
from gradio.themes.base import Base
|
8 |
+
from gradio.themes.utils import colors, fonts, sizes
|
9 |
+
from PIL import Image, ImageOps
|
10 |
from transformers import AutoModelForImageSegmentation
|
11 |
from torchvision import transforms
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
|
|
|
13 |
class WhiteTheme(Base):
|
14 |
def __init__(
|
15 |
self,
|
|
|
59 |
shadow_drop="none"
|
60 |
)
|
61 |
|
62 |
+
torch.set_float32_matmul_precision('high')
|
63 |
+
torch.jit.script = lambda f: f
|
64 |
+
|
65 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
66 |
|
67 |
+
def refine_foreground(image, mask, r=90):
|
68 |
+
if mask.size != image.size:
|
69 |
+
mask = mask.resize(image.size)
|
70 |
+
image = np.array(image) / 255.0
|
71 |
+
mask = np.array(mask) / 255.0
|
72 |
+
estimated_foreground = FB_blur_fusion_foreground_estimator_2(image, mask, r=r)
|
73 |
+
image_masked = Image.fromarray((estimated_foreground * 255.0).astype(np.uint8))
|
74 |
+
return image_masked
|
75 |
+
|
76 |
+
def FB_blur_fusion_foreground_estimator_2(image, alpha, r=90):
|
77 |
+
alpha = alpha[:, :, None]
|
78 |
+
F, blur_B = FB_blur_fusion_foreground_estimator(
|
79 |
+
image, image, image, alpha, r)
|
80 |
+
return FB_blur_fusion_foreground_estimator(image, F, blur_B, alpha, r=6)[0]
|
81 |
+
|
82 |
+
def FB_blur_fusion_foreground_estimator(image, F, B, alpha, r=90):
|
83 |
+
if isinstance(image, Image.Image):
|
84 |
+
image = np.array(image) / 255.0
|
85 |
+
blurred_alpha = cv2.blur(alpha, (r, r))[:, :, None]
|
86 |
+
blurred_FA = cv2.blur(F * alpha, (r, r))
|
87 |
+
blurred_F = blurred_FA / (blurred_alpha + 1e-5)
|
88 |
+
blurred_B1A = cv2.blur(B * (1 - alpha), (r, r))
|
89 |
+
blurred_B = blurred_B1A / ((1 - blurred_alpha) + 1e-5)
|
90 |
+
F = blurred_F + alpha * (image - alpha * blurred_F - (1 - alpha) * blurred_B)
|
91 |
+
F = np.clip(F, 0, 1)
|
92 |
+
return F, blurred_B
|
93 |
+
|
94 |
+
class ImagePreprocessor():
|
95 |
+
def __init__(self, resolution=(1024, 1024)) -> None:
|
96 |
+
self.transform_image = transforms.Compose([
|
97 |
+
transforms.Resize(resolution),
|
98 |
+
transforms.ToTensor(),
|
99 |
+
transforms.Normalize([0.485, 0.456, 0.406],
|
100 |
+
[0.229, 0.224, 0.225]),
|
101 |
+
])
|
102 |
+
|
103 |
+
def proc(self, image: Image.Image) -> torch.Tensor:
|
104 |
+
image = self.transform_image(image)
|
105 |
+
return image
|
106 |
+
|
107 |
+
# Load the model
|
108 |
+
birefnet = AutoModelForImageSegmentation.from_pretrained(
|
109 |
+
'zhengpeng7/BiRefNet-matting', trust_remote_code=True)
|
110 |
birefnet.to(device)
|
111 |
+
birefnet.eval()
|
112 |
+
|
113 |
+
def remove_background_wrapper(image):
|
114 |
+
if image is None:
|
115 |
+
raise gr.Error("Please upload an image.")
|
116 |
+
image_ori = Image.fromarray(image).convert('RGB')
|
117 |
+
foreground, background, pred_pil, reverse_mask = remove_background(image_ori)
|
118 |
+
return foreground, background, pred_pil, reverse_mask
|
119 |
+
|
120 |
+
@spaces.GPU
|
121 |
+
def remove_background(image_ori):
|
122 |
+
original_size = image_ori.size
|
123 |
+
image_preprocessor = ImagePreprocessor(resolution=(1024, 1024))
|
124 |
+
image_proc = image_preprocessor.proc(image_ori)
|
125 |
+
image_proc = image_proc.unsqueeze(0)
|
126 |
+
|
127 |
+
with torch.no_grad():
|
128 |
+
preds = birefnet(image_proc.to(device))[-1].sigmoid().cpu()
|
129 |
+
pred = preds[0].squeeze()
|
130 |
+
|
131 |
+
pred_pil = transforms.ToPILImage()(pred)
|
132 |
+
pred_pil = pred_pil.resize(original_size, Image.BICUBIC)
|
133 |
+
|
134 |
+
reverse_mask = ImageOps.invert(pred_pil)
|
135 |
+
|
136 |
+
foreground = image_ori.copy()
|
137 |
+
foreground.putalpha(pred_pil)
|
138 |
+
|
139 |
+
background = image_ori.copy()
|
140 |
+
background.putalpha(reverse_mask)
|
141 |
+
|
142 |
+
torch.cuda.empty_cache()
|
143 |
+
|
144 |
+
return foreground, background, pred_pil, reverse_mask
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
145 |
|
146 |
# Custom CSS for styling
|
147 |
custom_css = """
|
|
|
204 |
}
|
205 |
"""
|
206 |
|
|
|
207 |
with gr.Blocks(css=custom_css, theme=WhiteTheme()) as demo:
|
208 |
gr.HTML('''
|
209 |
<div class="title-container">
|
|
|
213 |
</div>
|
214 |
<script>
|
215 |
(function() {
|
216 |
+
const text = "image";
|
217 |
const typedTextSpan = document.getElementById("typed-text");
|
218 |
let charIndex = 0;
|
219 |
|
|
|
230 |
</script>
|
231 |
''')
|
232 |
|
233 |
+
# Interface setup with input and output
|
234 |
with gr.Row():
|
235 |
with gr.Column():
|
236 |
+
image_input = gr.Image(type="numpy", sources=['upload'], label="Upload Image")
|
237 |
+
btn = gr.Button("Process Image", elem_id="submit-button")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
238 |
|
239 |
with gr.Column():
|
240 |
+
output_foreground = gr.Image(type="pil", label="Foreground")
|
241 |
+
output_background = gr.Image(type="pil", label="Background")
|
242 |
+
output_foreground_mask = gr.Image(type="pil", label="Foreground Mask")
|
243 |
+
output_background_mask = gr.Image(type="pil", label="Background Mask")
|
244 |
+
|
245 |
+
# Link the button to the processing function
|
246 |
+
btn.click(fn=remove_background_wrapper, inputs=image_input, outputs=[
|
247 |
+
output_foreground, output_background, output_foreground_mask, output_background_mask])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
248 |
|
|
|
249 |
demo.launch(debug=True)
|