Spaces:
Paused
Paused
Update app.py
Browse files
app.py
CHANGED
@@ -14,20 +14,23 @@ from PIL import Image
|
|
14 |
import sf3d.utils as sf3d_utils
|
15 |
from sf3d.system import SF3D
|
16 |
|
|
|
17 |
rembg_session = rembg.new_session()
|
18 |
|
|
|
19 |
COND_WIDTH = 512
|
20 |
COND_HEIGHT = 512
|
21 |
COND_DISTANCE = 1.6
|
22 |
COND_FOVY_DEG = 40
|
23 |
BACKGROUND_COLOR = [0.5, 0.5, 0.5]
|
24 |
|
25 |
-
# Cached
|
26 |
c2w_cond = sf3d_utils.default_cond_c2w(COND_DISTANCE)
|
27 |
intrinsic, intrinsic_normed_cond = sf3d_utils.create_intrinsic_from_fov_deg(
|
28 |
COND_FOVY_DEG, COND_HEIGHT, COND_WIDTH
|
29 |
)
|
30 |
|
|
|
31 |
model = SF3D.from_pretrained(
|
32 |
"stabilityai/stable-fast-3d",
|
33 |
config_name="config.yaml",
|
@@ -35,11 +38,12 @@ model = SF3D.from_pretrained(
|
|
35 |
)
|
36 |
model.eval().cuda()
|
37 |
|
|
|
38 |
example_files = [
|
39 |
os.path.join("demo_files/examples", f) for f in os.listdir("demo_files/examples")
|
40 |
]
|
41 |
|
42 |
-
|
43 |
def run_model(input_image):
|
44 |
start = time.time()
|
45 |
with torch.no_grad():
|
@@ -49,16 +53,13 @@ def run_model(input_image):
|
|
49 |
trimesh_mesh, _glob_dict = model.generate_mesh(model_batch, 1024)
|
50 |
trimesh_mesh = trimesh_mesh[0]
|
51 |
|
52 |
-
# Create new
|
53 |
tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".glb")
|
54 |
-
|
55 |
trimesh_mesh.export(tmp_file.name, file_type="glb", include_normals=True)
|
56 |
|
57 |
print("Generation took:", time.time() - start, "s")
|
58 |
-
|
59 |
return tmp_file.name
|
60 |
|
61 |
-
|
62 |
def create_batch(input_image: Image) -> dict[str, Any]:
|
63 |
img_cond = (
|
64 |
torch.from_numpy(
|
@@ -80,11 +81,10 @@ def create_batch(input_image: Image) -> dict[str, Any]:
|
|
80 |
"intrinsic_cond": intrinsic.unsqueeze(0),
|
81 |
"intrinsic_normed_cond": intrinsic_normed_cond.unsqueeze(0),
|
82 |
}
|
83 |
-
# Add batch
|
84 |
batched = {k: v.unsqueeze(0) for k, v in batch_elem.items()}
|
85 |
return batched
|
86 |
|
87 |
-
|
88 |
@lru_cache
|
89 |
def checkerboard(squares: int, size: int, min_value: float = 0.5):
|
90 |
base = np.zeros((squares, squares)) + min_value
|
@@ -98,11 +98,9 @@ def checkerboard(squares: int, size: int, min_value: float = 0.5):
|
|
98 |
.repeat(3, axis=-1)
|
99 |
)
|
100 |
|
101 |
-
|
102 |
def remove_background(input_image: Image) -> Image:
|
103 |
return rembg.remove(input_image, session=rembg_session)
|
104 |
|
105 |
-
|
106 |
def resize_foreground(
|
107 |
image: Image,
|
108 |
ratio: float,
|
@@ -126,26 +124,25 @@ def resize_foreground(
|
|
126 |
fg,
|
127 |
((ph0, ph1), (pw0, pw1), (0, 0)),
|
128 |
mode="constant",
|
129 |
-
constant_values=
|
130 |
)
|
131 |
|
132 |
# Compute padding according to the ratio
|
133 |
new_size = int(new_image.shape[0] / ratio)
|
134 |
-
# Pad to size
|
135 |
ph0, pw0 = (new_size - size) // 2, (new_size - size) // 2
|
136 |
ph1, pw1 = new_size - size - ph0, new_size - size - pw0
|
137 |
new_image = np.pad(
|
138 |
new_image,
|
139 |
((ph0, ph1), (pw0, pw1), (0, 0)),
|
140 |
mode="constant",
|
141 |
-
constant_values=
|
142 |
)
|
143 |
new_image = Image.fromarray(new_image, mode="RGBA").resize(
|
144 |
(COND_WIDTH, COND_HEIGHT)
|
145 |
)
|
146 |
return new_image
|
147 |
|
148 |
-
|
149 |
def square_crop(input_image: Image) -> Image:
|
150 |
# Perform a center square crop
|
151 |
min_size = min(input_image.size)
|
@@ -157,7 +154,6 @@ def square_crop(input_image: Image) -> Image:
|
|
157 |
(COND_WIDTH, COND_HEIGHT)
|
158 |
)
|
159 |
|
160 |
-
|
161 |
def show_mask_img(input_image: Image) -> Image:
|
162 |
img_numpy = np.array(input_image)
|
163 |
alpha = img_numpy[:, :, 3] / 255.0
|
@@ -165,11 +161,9 @@ def show_mask_img(input_image: Image) -> Image:
|
|
165 |
new_img = img_numpy[..., :3] * alpha[:, :, None] + chkb * (1 - alpha[:, :, None])
|
166 |
return Image.fromarray(new_img.astype(np.uint8), mode="RGB")
|
167 |
|
168 |
-
|
169 |
def run_button(run_btn, input_image, background_state, foreground_ratio):
|
170 |
if run_btn == "Run":
|
171 |
glb_file: str = run_model(background_state)
|
172 |
-
|
173 |
return (
|
174 |
gr.update(),
|
175 |
gr.update(),
|
@@ -180,10 +174,8 @@ def run_button(run_btn, input_image, background_state, foreground_ratio):
|
|
180 |
)
|
181 |
elif run_btn == "Remove Background":
|
182 |
rem_removed = remove_background(input_image)
|
183 |
-
|
184 |
sqr_crop = square_crop(rem_removed)
|
185 |
fr_res = resize_foreground(sqr_crop, foreground_ratio)
|
186 |
-
|
187 |
return (
|
188 |
gr.update(value="Run", visible=True),
|
189 |
sqr_crop,
|
@@ -193,7 +185,6 @@ def run_button(run_btn, input_image, background_state, foreground_ratio):
|
|
193 |
gr.update(visible=False),
|
194 |
)
|
195 |
|
196 |
-
|
197 |
def requires_bg_remove(image, fr):
|
198 |
if image is None:
|
199 |
return (
|
@@ -228,7 +219,6 @@ def requires_bg_remove(image, fr):
|
|
228 |
gr.update(visible=False),
|
229 |
)
|
230 |
|
231 |
-
|
232 |
def update_foreground_ratio(img_proc, fr):
|
233 |
foreground_res = resize_foreground(img_proc, fr)
|
234 |
return (
|
@@ -236,38 +226,27 @@ def update_foreground_ratio(img_proc, fr):
|
|
236 |
gr.update(value=show_mask_img(foreground_res)),
|
237 |
)
|
238 |
|
239 |
-
#
|
240 |
-
from gradio.themes.colors import Color
|
241 |
-
|
242 |
-
primary_hue_color = Color(
|
243 |
-
name="custom",
|
244 |
-
c50="#e7e7e8",
|
245 |
-
c100="#c2c2c4",
|
246 |
-
c200="#9d9da0",
|
247 |
-
c300="#78787b",
|
248 |
-
c400="#525357",
|
249 |
-
c500="#2d2e33",
|
250 |
-
c600="#191a1e", # Base color
|
251 |
-
c700="#141517",
|
252 |
-
c800="#0f1012",
|
253 |
-
c900="#0a0a0c",
|
254 |
-
c950="#050506",
|
255 |
-
)
|
256 |
-
|
257 |
-
# Define the custom theme
|
258 |
class CustomTheme(gr.themes.Base):
|
259 |
-
|
260 |
-
|
261 |
-
|
262 |
-
|
263 |
-
|
264 |
-
|
265 |
-
|
266 |
-
|
267 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
268 |
|
|
|
269 |
css = """
|
270 |
-
/*
|
271 |
body {
|
272 |
background-color: #191a1e !important;
|
273 |
margin: 0;
|
@@ -275,43 +254,46 @@ body {
|
|
275 |
}
|
276 |
|
277 |
/* Apply fonts */
|
278 |
-
@import url('https://fonts.googleapis.com/css2?family=Poppins:wght@400;500;700&display=swap');
|
279 |
body, input, button, textarea, select, .gr-button {
|
280 |
font-family: 'Poppins', sans-serif;
|
281 |
color: #FFFFFF;
|
282 |
}
|
283 |
|
284 |
-
/* Header
|
285 |
h1, h2, h3, h4, h5, h6 {
|
286 |
font-family: 'Poppins', sans-serif;
|
287 |
font-weight: 700;
|
288 |
color: #FFFFFF;
|
289 |
}
|
290 |
|
291 |
-
/* Button
|
292 |
-
.
|
293 |
background-color: #5271FF !important;
|
294 |
color: #FFFFFF !important;
|
295 |
border: none;
|
296 |
font-weight: bold;
|
297 |
}
|
298 |
|
299 |
-
/* Container
|
300 |
.gr-block {
|
301 |
background-color: #1c1c24 !important;
|
302 |
border: 1px solid #5271FF !important;
|
303 |
}
|
304 |
"""
|
305 |
|
|
|
306 |
with gr.Blocks(theme=CustomTheme(), css=css) as demo:
|
307 |
img_proc_state = gr.State()
|
308 |
background_remove_state = gr.State()
|
309 |
-
|
310 |
with gr.Row(variant="panel"):
|
311 |
with gr.Column():
|
312 |
with gr.Row():
|
313 |
input_img = gr.Image(
|
314 |
-
type="pil",
|
|
|
|
|
|
|
315 |
)
|
316 |
preview_removal = gr.Image(
|
317 |
label="Preview Background Removal",
|
@@ -335,7 +317,12 @@ with gr.Blocks(theme=CustomTheme(), css=css) as demo:
|
|
335 |
outputs=[background_remove_state, preview_removal],
|
336 |
)
|
337 |
|
338 |
-
run_btn = gr.Button(
|
|
|
|
|
|
|
|
|
|
|
339 |
|
340 |
with gr.Column():
|
341 |
output_3d = LitModel3D(
|
@@ -347,10 +334,13 @@ with gr.Blocks(theme=CustomTheme(), css=css) as demo:
|
|
347 |
scale=1.0,
|
348 |
)
|
349 |
with gr.Column(visible=False, scale=1.0) as hdr_row:
|
350 |
-
gr.Markdown(
|
|
|
|
|
351 |
|
352 |
-
Select an HDR environment map to light the 3D model. You can also upload your own HDR environment maps.
|
353 |
-
"""
|
|
|
354 |
|
355 |
with gr.Row():
|
356 |
hdr_illumination_file = gr.File(
|
@@ -407,4 +397,5 @@ Select an HDR environment map to light the 3D model. You can also upload your ow
|
|
407 |
],
|
408 |
)
|
409 |
|
|
|
410 |
demo.launch()
|
|
|
14 |
import sf3d.utils as sf3d_utils
|
15 |
from sf3d.system import SF3D
|
16 |
|
17 |
+
# Initialize the rembg session
|
18 |
rembg_session = rembg.new_session()
|
19 |
|
20 |
+
# Constants
|
21 |
COND_WIDTH = 512
|
22 |
COND_HEIGHT = 512
|
23 |
COND_DISTANCE = 1.6
|
24 |
COND_FOVY_DEG = 40
|
25 |
BACKGROUND_COLOR = [0.5, 0.5, 0.5]
|
26 |
|
27 |
+
# Cached camera parameters
|
28 |
c2w_cond = sf3d_utils.default_cond_c2w(COND_DISTANCE)
|
29 |
intrinsic, intrinsic_normed_cond = sf3d_utils.create_intrinsic_from_fov_deg(
|
30 |
COND_FOVY_DEG, COND_HEIGHT, COND_WIDTH
|
31 |
)
|
32 |
|
33 |
+
# Load the model
|
34 |
model = SF3D.from_pretrained(
|
35 |
"stabilityai/stable-fast-3d",
|
36 |
config_name="config.yaml",
|
|
|
38 |
)
|
39 |
model.eval().cuda()
|
40 |
|
41 |
+
# Load example files
|
42 |
example_files = [
|
43 |
os.path.join("demo_files/examples", f) for f in os.listdir("demo_files/examples")
|
44 |
]
|
45 |
|
46 |
+
# Define functions
|
47 |
def run_model(input_image):
|
48 |
start = time.time()
|
49 |
with torch.no_grad():
|
|
|
53 |
trimesh_mesh, _glob_dict = model.generate_mesh(model_batch, 1024)
|
54 |
trimesh_mesh = trimesh_mesh[0]
|
55 |
|
56 |
+
# Create new temporary file
|
57 |
tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".glb")
|
|
|
58 |
trimesh_mesh.export(tmp_file.name, file_type="glb", include_normals=True)
|
59 |
|
60 |
print("Generation took:", time.time() - start, "s")
|
|
|
61 |
return tmp_file.name
|
62 |
|
|
|
63 |
def create_batch(input_image: Image) -> dict[str, Any]:
|
64 |
img_cond = (
|
65 |
torch.from_numpy(
|
|
|
81 |
"intrinsic_cond": intrinsic.unsqueeze(0),
|
82 |
"intrinsic_normed_cond": intrinsic_normed_cond.unsqueeze(0),
|
83 |
}
|
84 |
+
# Add batch dimension
|
85 |
batched = {k: v.unsqueeze(0) for k, v in batch_elem.items()}
|
86 |
return batched
|
87 |
|
|
|
88 |
@lru_cache
|
89 |
def checkerboard(squares: int, size: int, min_value: float = 0.5):
|
90 |
base = np.zeros((squares, squares)) + min_value
|
|
|
98 |
.repeat(3, axis=-1)
|
99 |
)
|
100 |
|
|
|
101 |
def remove_background(input_image: Image) -> Image:
|
102 |
return rembg.remove(input_image, session=rembg_session)
|
103 |
|
|
|
104 |
def resize_foreground(
|
105 |
image: Image,
|
106 |
ratio: float,
|
|
|
124 |
fg,
|
125 |
((ph0, ph1), (pw0, pw1), (0, 0)),
|
126 |
mode="constant",
|
127 |
+
constant_values=0,
|
128 |
)
|
129 |
|
130 |
# Compute padding according to the ratio
|
131 |
new_size = int(new_image.shape[0] / ratio)
|
132 |
+
# Pad to new size
|
133 |
ph0, pw0 = (new_size - size) // 2, (new_size - size) // 2
|
134 |
ph1, pw1 = new_size - size - ph0, new_size - size - pw0
|
135 |
new_image = np.pad(
|
136 |
new_image,
|
137 |
((ph0, ph1), (pw0, pw1), (0, 0)),
|
138 |
mode="constant",
|
139 |
+
constant_values=0,
|
140 |
)
|
141 |
new_image = Image.fromarray(new_image, mode="RGBA").resize(
|
142 |
(COND_WIDTH, COND_HEIGHT)
|
143 |
)
|
144 |
return new_image
|
145 |
|
|
|
146 |
def square_crop(input_image: Image) -> Image:
|
147 |
# Perform a center square crop
|
148 |
min_size = min(input_image.size)
|
|
|
154 |
(COND_WIDTH, COND_HEIGHT)
|
155 |
)
|
156 |
|
|
|
157 |
def show_mask_img(input_image: Image) -> Image:
|
158 |
img_numpy = np.array(input_image)
|
159 |
alpha = img_numpy[:, :, 3] / 255.0
|
|
|
161 |
new_img = img_numpy[..., :3] * alpha[:, :, None] + chkb * (1 - alpha[:, :, None])
|
162 |
return Image.fromarray(new_img.astype(np.uint8), mode="RGB")
|
163 |
|
|
|
164 |
def run_button(run_btn, input_image, background_state, foreground_ratio):
|
165 |
if run_btn == "Run":
|
166 |
glb_file: str = run_model(background_state)
|
|
|
167 |
return (
|
168 |
gr.update(),
|
169 |
gr.update(),
|
|
|
174 |
)
|
175 |
elif run_btn == "Remove Background":
|
176 |
rem_removed = remove_background(input_image)
|
|
|
177 |
sqr_crop = square_crop(rem_removed)
|
178 |
fr_res = resize_foreground(sqr_crop, foreground_ratio)
|
|
|
179 |
return (
|
180 |
gr.update(value="Run", visible=True),
|
181 |
sqr_crop,
|
|
|
185 |
gr.update(visible=False),
|
186 |
)
|
187 |
|
|
|
188 |
def requires_bg_remove(image, fr):
|
189 |
if image is None:
|
190 |
return (
|
|
|
219 |
gr.update(visible=False),
|
220 |
)
|
221 |
|
|
|
222 |
def update_foreground_ratio(img_proc, fr):
|
223 |
foreground_res = resize_foreground(img_proc, fr)
|
224 |
return (
|
|
|
226 |
gr.update(value=show_mask_img(foreground_res)),
|
227 |
)
|
228 |
|
229 |
+
# Define custom theme
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
230 |
class CustomTheme(gr.themes.Base):
|
231 |
+
def __init__(self):
|
232 |
+
super().__init__()
|
233 |
+
self.primary_hue = "#191a1e"
|
234 |
+
self.background_fill_primary = "#191a1e"
|
235 |
+
self.background_fill_secondary = "#191a1e"
|
236 |
+
self.background_fill_tertiary = "#191a1e"
|
237 |
+
self.text_color_primary = "#FFFFFF"
|
238 |
+
self.text_color_secondary = "#FFFFFF"
|
239 |
+
self.text_color_tertiary = "#FFFFFF"
|
240 |
+
self.input_background_fill = "#191a1e"
|
241 |
+
self.input_text_color = "#FFFFFF"
|
242 |
+
self.font = (
|
243 |
+
"Poppins",
|
244 |
+
"https://fonts.googleapis.com/css2?family=Poppins:wght@400;500;700&display=swap",
|
245 |
+
)
|
246 |
|
247 |
+
# Custom CSS
|
248 |
css = """
|
249 |
+
/* Set background color for the entire page */
|
250 |
body {
|
251 |
background-color: #191a1e !important;
|
252 |
margin: 0;
|
|
|
254 |
}
|
255 |
|
256 |
/* Apply fonts */
|
|
|
257 |
body, input, button, textarea, select, .gr-button {
|
258 |
font-family: 'Poppins', sans-serif;
|
259 |
color: #FFFFFF;
|
260 |
}
|
261 |
|
262 |
+
/* Header styles */
|
263 |
h1, h2, h3, h4, h5, h6 {
|
264 |
font-family: 'Poppins', sans-serif;
|
265 |
font-weight: 700;
|
266 |
color: #FFFFFF;
|
267 |
}
|
268 |
|
269 |
+
/* Button styles */
|
270 |
+
.generate-button {
|
271 |
background-color: #5271FF !important;
|
272 |
color: #FFFFFF !important;
|
273 |
border: none;
|
274 |
font-weight: bold;
|
275 |
}
|
276 |
|
277 |
+
/* Container styles */
|
278 |
.gr-block {
|
279 |
background-color: #1c1c24 !important;
|
280 |
border: 1px solid #5271FF !important;
|
281 |
}
|
282 |
"""
|
283 |
|
284 |
+
# Build the Gradio interface
|
285 |
with gr.Blocks(theme=CustomTheme(), css=css) as demo:
|
286 |
img_proc_state = gr.State()
|
287 |
background_remove_state = gr.State()
|
288 |
+
|
289 |
with gr.Row(variant="panel"):
|
290 |
with gr.Column():
|
291 |
with gr.Row():
|
292 |
input_img = gr.Image(
|
293 |
+
type="pil",
|
294 |
+
label="Input Image",
|
295 |
+
sources="upload",
|
296 |
+
image_mode="RGBA",
|
297 |
)
|
298 |
preview_removal = gr.Image(
|
299 |
label="Preview Background Removal",
|
|
|
317 |
outputs=[background_remove_state, preview_removal],
|
318 |
)
|
319 |
|
320 |
+
run_btn = gr.Button(
|
321 |
+
"Run",
|
322 |
+
variant="primary",
|
323 |
+
visible=False,
|
324 |
+
elem_classes="generate-button",
|
325 |
+
)
|
326 |
|
327 |
with gr.Column():
|
328 |
output_3d = LitModel3D(
|
|
|
334 |
scale=1.0,
|
335 |
)
|
336 |
with gr.Column(visible=False, scale=1.0) as hdr_row:
|
337 |
+
gr.Markdown(
|
338 |
+
"""
|
339 |
+
## HDR Environment Map
|
340 |
|
341 |
+
Select an HDR environment map to light the 3D model. You can also upload your own HDR environment maps.
|
342 |
+
"""
|
343 |
+
)
|
344 |
|
345 |
with gr.Row():
|
346 |
hdr_illumination_file = gr.File(
|
|
|
397 |
],
|
398 |
)
|
399 |
|
400 |
+
# Launch the interface
|
401 |
demo.launch()
|