# Hugging Face Space page header captured with this file (status: Paused).
import os | |
import tempfile | |
import time | |
from functools import lru_cache | |
from typing import Any | |
import gradio as gr | |
import numpy as np | |
import rembg | |
import torch | |
from gradio_litmodel3d import LitModel3D | |
from PIL import Image | |
import sf3d.utils as sf3d_utils | |
from sf3d.system import SF3D | |
# Initialize the rembg session once at import time; remove_background() passes
# this same session object to every rembg.remove() call.
rembg_session = rembg.new_session()
# Constants
# Size (pixels) that conditioning images are resized to in create_batch(),
# square_crop() and resize_foreground().
COND_WIDTH = 512
COND_HEIGHT = 512
# Passed to sf3d_utils.default_cond_c2w(); presumably the conditioning camera's
# distance from the object — confirm against sf3d.
COND_DISTANCE = 1.6
# Field of view in degrees passed to create_intrinsic_from_fov_deg() (the "y"
# suggests vertical FOV — confirm against sf3d).
COND_FOVY_DEG = 40
# RGB values in [0, 1]; create_batch() composites transparent regions of the
# input onto this colour.
BACKGROUND_COLOR = [0.5, 0.5, 0.5]
# Cached camera parameters: computed once at import time and reused by every
# create_batch() call.
c2w_cond = sf3d_utils.default_cond_c2w(COND_DISTANCE)
intrinsic, intrinsic_normed_cond = sf3d_utils.create_intrinsic_from_fov_deg(
    COND_FOVY_DEG, COND_HEIGHT, COND_WIDTH
)
# Load the pretrained SF3D model ("stabilityai/stable-fast-3d", presumably a
# Hugging Face Hub repo id) at import time, switch it to eval mode and move it
# to the GPU.
# NOTE(review): .cuda() is unconditional, so the app requires a CUDA device;
# run_model() also hard-codes CUDA autocast.
model = SF3D.from_pretrained(
    "stabilityai/stable-fast-3d",
    config_name="config.yaml",
    weight_name="model.safetensors",
)
model.eval().cuda()
# Load example files (paths to the demo input images shown under the UI).
# os.listdir() returns entries in arbitrary, filesystem-dependent order, so the
# names are sorted to keep the example gallery stable between runs/machines.
example_files = [
    os.path.join("demo_files/examples", f)
    for f in sorted(os.listdir("demo_files/examples"))
]
# Define functions
def run_model(input_image):
    """Generate a 3D mesh from a preprocessed image and save it as GLB.

    Args:
        input_image: The preprocessed RGBA PIL image (the value held in
            ``background_remove_state``), turned into a model batch by
            ``create_batch``.

    Returns:
        Path of a temporary ``.glb`` file containing the generated mesh. The
        file is created with ``delete=False`` so it outlives this call (Gradio
        serves it afterwards); this function never removes it.
    """
    start = time.perf_counter()

    with torch.no_grad():
        with torch.autocast(device_type="cuda", dtype=torch.float16):
            model_batch = create_batch(input_image)
            model_batch = {k: v.cuda() for k, v in model_batch.items()}
            # 1024 is forwarded unchanged to SF3D.generate_mesh (presumably
            # the texture bake resolution — confirm against sf3d).
            trimesh_mesh, _glob_dict = model.generate_mesh(model_batch, 1024)
            # generate_mesh returns one mesh per batch element; the batch has one.
            trimesh_mesh = trimesh_mesh[0]

    # Reserve a unique .glb path and close our handle immediately. trimesh
    # opens the path itself, and the previous code never closed the handle,
    # leaking one file descriptor per generation.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".glb") as tmp_file:
        glb_path = tmp_file.name
    trimesh_mesh.export(glb_path, file_type="glb", include_normals=True)

    print("Generation took:", time.perf_counter() - start, "s")
    return glb_path
def create_batch(input_image: Image.Image) -> dict[str, Any]:
    """Build a single-sample SF3D conditioning batch from an image.

    The image is converted to RGBA, resized to COND_WIDTH x COND_HEIGHT and
    scaled to [0, 1]. Its RGB is composited over BACKGROUND_COLOR using the
    alpha channel. The cached conditioning camera tensors are attached.

    Args:
        input_image: PIL image. It is converted to RGBA so the last channel is
            always a real alpha mask. Previously an RGB input silently used its
            blue channel as the mask, and other modes failed with shape errors.

    Returns:
        Dict of tensors:
        - ``rgb_cond``: (1, COND_HEIGHT, COND_WIDTH, 3).
        - ``mask_cond``: (1, COND_HEIGHT, COND_WIDTH, 1).
        - ``c2w_cond``, ``intrinsic_cond``, ``intrinsic_normed_cond``: the
          cached camera tensors with two leading singleton dims added
          (presumably batch and view — confirm against SF3D).
    """
    rgba = input_image.convert("RGBA").resize((COND_WIDTH, COND_HEIGHT))
    img_cond = (
        torch.from_numpy(np.asarray(rgba).astype(np.float32) / 255.0)
        .float()
        .clip(0, 1)
    )
    mask_cond = img_cond[:, :, -1:]
    # Alpha-composite: lerp(background, rgb, alpha) = bg + alpha * (rgb - bg).
    rgb_cond = torch.lerp(
        torch.tensor(BACKGROUND_COLOR)[None, None, :], img_cond[:, :, :3], mask_cond
    )

    batch_elem = {
        "rgb_cond": rgb_cond,
        "mask_cond": mask_cond,
        "c2w_cond": c2w_cond.unsqueeze(0),
        "intrinsic_cond": intrinsic.unsqueeze(0),
        "intrinsic_normed_cond": intrinsic_normed_cond.unsqueeze(0),
    }
    # Add batch dimension
    return {k: v.unsqueeze(0) for k, v in batch_elem.items()}
def checkerboard(squares: int, size: int, min_value: float = 0.5):
    """Return a ``(size, size, 3)`` float checkerboard pattern.

    Cells alternate between ``min_value`` and 1.0, with the top-left cell at
    ``min_value``. Each pixel is mapped to one of ``squares`` cells per axis,
    so the output is always exactly ``size`` pixels wide.

    Previously the pattern came out ``squares * (size // squares)`` pixels
    wide: too small when ``size`` is not a multiple of ``squares``, and empty
    when ``squares > size``. When ``size`` is a multiple of ``squares`` the
    result is unchanged.

    Args:
        squares: Number of cells along each axis; must be >= 1.
        size: Side length of the output in pixels.
        min_value: Value of the darker cells.

    Returns:
        ``np.ndarray`` of shape ``(size, size, 3)`` and dtype float64; the
        three channels are identical.

    Raises:
        ValueError: If ``squares`` is less than 1.
    """
    if squares < 1:
        raise ValueError(f"squares must be >= 1, got {squares}")
    # Cell index of each pixel along one axis. When size is a multiple of
    # squares this equals i // (size // squares), i.e. the old repeat() layout.
    # max(size, 1) only guards the empty size == 0 case against division by 0.
    cell = np.arange(size) * squares // max(size, 1)
    # A cell is "light" (1.0) when its row + column index is odd.
    light = (cell[:, None] + cell[None, :]) % 2 == 1
    pattern = np.where(light, 1.0, float(min_value))
    return np.repeat(pattern[:, :, None], 3, axis=-1)
def remove_background(input_image: Image.Image) -> Image.Image:
    """Remove the background from ``input_image`` with rembg.

    Uses the module-level ``rembg_session`` so that one session object is
    shared across calls instead of being created per request.

    Args:
        input_image: PIL image to segment.

    Returns:
        Whatever ``rembg.remove`` returns for a PIL input (presumably an RGBA
        PIL image with a transparent background — confirm against the
        installed rembg version).
    """
    return rembg.remove(input_image, session=rembg_session)
def _center_pad(array: np.ndarray, height: int, width: int) -> np.ndarray:
    """Zero-pad an (H, W, C) array to (height, width, C), keeping it centered."""
    pad_h = height - array.shape[0]
    pad_w = width - array.shape[1]
    top, left = pad_h // 2, pad_w // 2
    return np.pad(
        array,
        ((top, pad_h - top), (left, pad_w - left), (0, 0)),
        mode="constant",
        constant_values=0,
    )


def resize_foreground(
    image: Image.Image,
    ratio: float,
) -> Image.Image:
    """Crop an RGBA image to its visible content and re-frame it.

    The steps are:
    1. Crop to the bounding box of all pixels with alpha > 0.
    2. Zero-pad (transparent black) to a centered square.
    3. Pad again so the content spans ``ratio`` of the side length.
    4. Resize to COND_WIDTH x COND_HEIGHT.

    Args:
        image: RGBA PIL image, i.e. something ``np.array`` turns into an
            ``(H, W, 4)`` array.
        ratio: Fraction of the output side the foreground should occupy;
            must be in (0, 1].

    Returns:
        RGBA PIL image of size (COND_WIDTH, COND_HEIGHT).

    Raises:
        ValueError: If the image is not 4-channel, is fully transparent, or
            ``ratio`` is outside (0, 1].
    """
    if not 0 < ratio <= 1:
        raise ValueError(f"ratio must be in (0, 1], got {ratio}")
    pixels = np.array(image)
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if pixels.ndim != 3 or pixels.shape[-1] != 4:
        raise ValueError(
            f"resize_foreground expects an RGBA image, got array shape {pixels.shape}"
        )
    rows, cols = np.nonzero(pixels[..., 3] > 0)
    if rows.size == 0:
        raise ValueError("resize_foreground: image has no non-transparent pixels")

    # min/max are inclusive pixel indices, so the slice end needs +1. The old
    # `y1:y2, x1:x2` slice dropped the last foreground row and column.
    y1, y2 = rows.min(), rows.max()
    x1, x2 = cols.min(), cols.max()
    fg = pixels[y1 : y2 + 1, x1 : x2 + 1]

    # Pad to square
    size = max(fg.shape[0], fg.shape[1])
    square = _center_pad(fg, size, size)

    # Pad so the foreground occupies `ratio` of the final side length.
    new_size = int(size / ratio)
    framed = _center_pad(square, new_size, new_size)

    return Image.fromarray(framed, mode="RGBA").resize((COND_WIDTH, COND_HEIGHT))
def square_crop(input_image: Image) -> Image:
    """Center-crop ``input_image`` to a square and resize it to the cond size.

    The square's side equals the shorter image dimension. The crop box is
    centered on both axes using floor division, and the cropped square is
    resized to COND_WIDTH x COND_HEIGHT.
    """
    width, height = input_image.size
    side = min(width, height)
    crop_box = (
        (width - side) // 2,
        (height - side) // 2,
        (width + side) // 2,
        (height + side) // 2,
    )
    square = input_image.crop(crop_box)
    return square.resize((COND_WIDTH, COND_HEIGHT))
def show_mask_img(input_image: Image.Image) -> Image.Image:
    """Composite an image over a checkerboard to preview its transparency.

    Transparent areas show a gray/white checkerboard (values 0.5 and 1.0
    scaled by 255), so the cut-out is easy to see.

    Args:
        input_image: PIL image of any size. It is converted to RGBA, so images
            without an alpha band are shown fully opaque.

    Returns:
        RGB PIL image with the same width and height as ``input_image``.
    """
    img_numpy = np.array(input_image.convert("RGBA"))
    height, width = img_numpy.shape[:2]
    alpha = img_numpy[:, :, 3:] / 255.0  # (H, W, 1), broadcasts over RGB
    # The board was hard-coded to 512x512, which broke broadcasting for any
    # other image size. Build a board whose side is the image's longer side
    # rounded up to a multiple of 32, so checkerboard() returns exactly that
    # side, then crop it to the image. A 512x512 image still gets the original
    # 32x32-cell board.
    side = ((max(height, width) + 31) // 32) * 32
    chkb = checkerboard(32, side)[:height, :width] * 255
    new_img = img_numpy[..., :3] * alpha + chkb * (1 - alpha)
    return Image.fromarray(new_img.astype(np.uint8), mode="RGB")
def run_button(run_btn, input_image, background_state, foreground_ratio):
    """Handle clicks on the dual-purpose "Run" / "Remove Background" button.

    Args:
        run_btn: Current label of the button; selects the action.
        input_image: The uploaded image (used for background removal).
        background_state: Preprocessed RGBA image that is fed to the model on
            "Run".
        foreground_ratio: Foreground-ratio slider value.

    Returns:
        A 6-tuple of values/updates for (run_btn, img_proc_state,
        background_remove_state, preview_removal, output_3d, hdr_row).

    Raises:
        gr.Error: On "Run" when no preprocessed image is available, shown to
            the user instead of a crash inside the model code.
    """
    if run_btn == "Run":
        if background_state is None:
            raise gr.Error("Please upload an image before running the model.")
        glb_file: str = run_model(background_state)
        return (
            gr.update(),
            gr.update(),
            gr.update(),
            gr.update(),
            gr.update(value=glb_file, visible=True),
            gr.update(visible=True),
        )
    if run_btn == "Remove Background":
        rem_removed = remove_background(input_image)
        sqr_crop = square_crop(rem_removed)
        fr_res = resize_foreground(sqr_crop, foreground_ratio)
        return (
            gr.update(value="Run", visible=True),
            sqr_crop,
            fr_res,
            gr.update(value=show_mask_img(fr_res), visible=True),
            gr.update(value=None, visible=False),
            gr.update(visible=False),
        )
    # Unexpected label: leave all six outputs unchanged instead of implicitly
    # returning None for an event that declares six outputs.
    return tuple(gr.update() for _ in range(6))
def _reset_outputs(run_btn_update):
    """Return the 6 outputs that empty the states and hide preview/model/HDR."""
    return (
        run_btn_update,
        None,
        None,
        gr.update(value=None, visible=False),
        gr.update(value=None, visible=False),
        gr.update(visible=False),
    )


def requires_bg_remove(image, fr):
    """Decide, on every input-image change, whether background removal is needed.

    If the image already has transparency (some pixel with alpha == 0), it is
    square-cropped and foreground-resized right away and the button becomes
    "Run". Otherwise the button becomes "Remove Background". In every case the
    previously generated 3D model is hidden and cleared, matching what
    ``run_button`` does after background removal.

    Args:
        image: The uploaded PIL image, or None when the input was cleared.
        fr: Foreground-ratio slider value.

    Returns:
        A 6-tuple of values/updates for (run_btn, img_proc_state,
        background_remove_state, preview_removal, output_3d, hdr_row).
    """
    if image is None:
        return _reset_outputs(gr.update(visible=False, value="Run"))

    # Images without an alpha band cannot already be cut out, and
    # getchannel("A") would raise ValueError for them.
    has_transparency = (
        "A" in image.getbands() and np.array(image.getchannel("A")).min() == 0
    )
    if has_transparency:
        print("Already has alpha")
        sqr_crop = square_crop(image)
        fr_res = resize_foreground(sqr_crop, fr)
        return (
            gr.update(value="Run", visible=True),
            sqr_crop,
            fr_res,
            gr.update(value=show_mask_img(fr_res), visible=True),
            gr.update(value=None, visible=False),
            gr.update(visible=False),
        )

    return _reset_outputs(gr.update(value="Remove Background", visible=True))
def update_foreground_ratio(img_proc, fr):
    """Re-frame the processed image when the foreground-ratio slider moves.

    Args:
        img_proc: Square-cropped RGBA image from ``img_proc_state``, or None
            if no image has been processed yet.
        fr: New foreground ratio.

    Returns:
        (new ``background_remove_state`` value, ``preview_removal`` update).
    """
    if img_proc is None:
        # The slider can be moved before any image is uploaded. Previously
        # this crashed inside resize_foreground; now both outputs are left
        # unchanged.
        return gr.update(), gr.update()
    foreground_res = resize_foreground(img_proc, fr)
    return (
        foreground_res,
        gr.update(value=show_mask_img(foreground_res)),
    )
# Define custom theme
class CustomTheme(gr.themes.Base):
    """Dark Gradio theme: #191a1e backgrounds, white text, Poppins font."""

    def __init__(self):
        # Pass the font through Base's documented `font` argument: Poppins
        # from Google Fonts, with generic sans-serif fallbacks. The previous
        # code assigned a ("Poppins", <stylesheet URL>) tuple to `self.font`
        # after initialisation. That is not a font value Gradio documents, so
        # the Poppins stylesheet was presumably never loaded.
        super().__init__(
            font=(gr.themes.GoogleFont("Poppins"), "ui-sans-serif", "sans-serif"),
        )
        # Colour overrides, assigned as attributes after the base initialiser.
        # NOTE(review): `primary_hue` is normally a constructor argument (a
        # hue object), and some names below may not be Gradio theme variables.
        # Confirm which ones take effect in the installed Gradio version.
        self.primary_hue = "#191a1e"
        self.background_fill_primary = "#191a1e"
        self.background_fill_secondary = "#191a1e"
        self.background_fill_tertiary = "#191a1e"
        self.text_color_primary = "#FFFFFF"
        self.text_color_secondary = "#FFFFFF"
        self.text_color_tertiary = "#FFFFFF"
        self.input_background_fill = "#191a1e"
        self.input_text_color = "#FFFFFF"
# Custom CSS, passed to gr.Blocks(css=css) below. It:
# - forces the dark background (#191a1e) on the page body, the Gradio
#   container and .gr-block elements (with a #5271FF border on blocks);
# - hides the footer;
# - sets the Poppins font family and white text on inputs, buttons and
#   headings (the CSS itself does not load the font);
# - styles elements with the "generate-button" class.
# The two non-English CSS comments are part of the string and are left
# untouched. They read "Apply the background colour to the Gradio container"
# and "Apply the background colour to blocks".
css = """
body {
background-color: #191a1e !important;
margin: 0;
padding: 0;
}
/* Применяем фоновый цвет для контейнера Gradio */
.gradio-container {
background-color: #191a1e !important;
}
/* Применяем фоновый цвет для блоков */
.gr-block {
background-color: #191a1e !important;
border: 1px solid #5271FF !important;
}
/* Hide the footer */
footer {
visibility: hidden;
height: 0;
margin: 0;
padding: 0;
overflow: hidden;
}
/* Apply fonts */
body, input, button, textarea, select, .gr-button {
font-family: 'Poppins', sans-serif;
color: #FFFFFF;
}
/* Header styles */
h1, h2, h3, h4, h5, h6 {
font-family: 'Poppins', sans-serif;
font-weight: 700;
color: #FFFFFF;
}
/* Button styles */
.generate-button {
background-color: #5271FF !important;
color: #FFFFFF !important;
border: none;
font-weight: bold;
}
"""
# Build the Gradio interface
def _env_map_update(hdr_file):
    """Point the 3D viewer's environment map at the selected HDR file.

    Gradio may pass the File value as a plain path string or as an object
    exposing the path via ``.name``, depending on version; both are accepted.
    A cleared file input (None) clears the environment map.
    """
    if hdr_file is None:
        return gr.update(env_map=None)
    return gr.update(env_map=getattr(hdr_file, "name", hdr_file))


with gr.Blocks(theme=CustomTheme(), css=css) as demo:
    # Per-session state: the square-cropped RGBA image, and the
    # foreground-resized image that is fed to the model on "Run".
    img_proc_state = gr.State()
    background_remove_state = gr.State()

    with gr.Row(variant="panel"):
        # Left column: input, background-removal preview and controls.
        with gr.Column():
            with gr.Row():
                input_img = gr.Image(
                    type="pil",
                    label="Input Image",
                    sources="upload",
                    image_mode="RGBA",
                )
                preview_removal = gr.Image(
                    label="Preview Background Removal",
                    type="pil",
                    image_mode="RGB",
                    interactive=False,
                    visible=False,
                )

            foreground_ratio = gr.Slider(
                label="Foreground Ratio",
                minimum=0.5,
                maximum=1.0,
                value=0.85,
                step=0.05,
            )
            foreground_ratio.change(
                update_foreground_ratio,
                inputs=[img_proc_state, foreground_ratio],
                outputs=[background_remove_state, preview_removal],
            )

            # One button with two roles. run_button() reads its label
            # ("Remove Background" or "Run") to pick the action.
            run_btn = gr.Button(
                "Run",
                variant="primary",
                visible=False,
                elem_classes="generate-button",
            )

        # Right column: 3D viewer and HDR environment-map controls.
        with gr.Column():
            output_3d = LitModel3D(
                label="3D Model",
                visible=False,
                clear_color=[0.0, 0.0, 0.0, 0.0],
                tonemapping="aces",
                contrast=1.0,
                scale=1.0,
            )
            with gr.Column(visible=False, scale=1.0) as hdr_row:
                gr.Markdown(
                    """
                    ## HDR Environment Map
                    Select an HDR environment map to light the 3D model. You can also upload your own HDR environment maps.
                    """
                )
                with gr.Row():
                    hdr_illumination_file = gr.File(
                        label="HDR Env Map", file_types=[".hdr"], file_count="single"
                    )
                    # Sorted so the examples appear in a stable order;
                    # os.listdir() order is filesystem-dependent.
                    example_hdris = [
                        os.path.join("demo_files/hdri", f)
                        for f in sorted(os.listdir("demo_files/hdri"))
                    ]
                    hdr_illumination_example = gr.Examples(
                        examples=example_hdris,
                        inputs=hdr_illumination_file,
                    )
                    hdr_illumination_file.change(
                        _env_map_update,
                        inputs=hdr_illumination_file,
                        outputs=[output_3d],
                    )

    examples = gr.Examples(
        examples=example_files,
        inputs=input_img,
    )

    # A new/cleared input decides whether background removal is still needed.
    input_img.change(
        requires_bg_remove,
        inputs=[input_img, foreground_ratio],
        outputs=[
            run_btn,
            img_proc_state,
            background_remove_state,
            preview_removal,
            output_3d,
            hdr_row,
        ],
    )
    run_btn.click(
        run_button,
        inputs=[
            run_btn,
            input_img,
            background_remove_state,
            foreground_ratio,
        ],
        outputs=[
            run_btn,
            img_proc_state,
            background_remove_state,
            preview_removal,
            output_3d,
            hdr_row,
        ],
    )
# Launch the interface only when the file is executed as a script, so that
# importing this module does not also start a web server.
if __name__ == "__main__":
    demo.launch()