Zaiiida's picture
Update app.py
4c76b22 verified
raw
history blame
12.1 kB
import os
import tempfile
import time
from functools import lru_cache
from typing import Any
import gradio as gr
import numpy as np
import rembg
import torch
from gradio_litmodel3d import LitModel3D
from PIL import Image
import sf3d.utils as sf3d_utils
from sf3d.system import SF3D
# One rembg session is created up front and passed to every rembg.remove call.
rembg_session = rembg.new_session()

# Settings of the conditioning view: target image size, camera distance and
# vertical field of view, plus the colour blended behind the foreground.
COND_WIDTH = COND_HEIGHT = 512
COND_DISTANCE = 1.6
COND_FOVY_DEG = 40
BACKGROUND_COLOR = [0.5, 0.5, 0.5]

# Camera parameters are computed once here and reused for every batch.
c2w_cond = sf3d_utils.default_cond_c2w(COND_DISTANCE)
intrinsic, intrinsic_normed_cond = sf3d_utils.create_intrinsic_from_fov_deg(
    COND_FOVY_DEG,
    COND_HEIGHT,
    COND_WIDTH,
)

# Pretrained SF3D model, put in eval mode and moved to CUDA.
model = SF3D.from_pretrained(
    "stabilityai/stable-fast-3d",
    config_name="config.yaml",
    weight_name="model.safetensors",
)
model.eval().cuda()

# Every entry of demo_files/examples is offered as an example input.
example_files = [
    os.path.join("demo_files/examples", entry)
    for entry in os.listdir("demo_files/examples")
]
# Define functions
def run_model(input_image):
    """Generate a mesh for ``input_image`` and export it as a GLB file.

    Args:
        input_image: Image forwarded to ``create_batch``.
            NOTE(review): presumably the foreground-resized RGBA image kept
            in the UI state — confirm against the caller.

    Returns:
        Path of a newly created temporary ``.glb`` file. The file is created
        with ``delete=False``, so it is not removed automatically.
    """
    start = time.time()
    with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.float16):
        model_batch = create_batch(input_image)
        model_batch = {k: v.cuda() for k, v in model_batch.items()}
        trimesh_mesh, _glob_dict = model.generate_mesh(model_batch, 1024)
        # generate_mesh returns one entry per batch element; the batch has one.
        trimesh_mesh = trimesh_mesh[0]

    # Only the path is needed. Close the handle immediately so the descriptor
    # is not leaked and the exporter can reopen the file by name.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".glb") as tmp_file:
        glb_path = tmp_file.name
    trimesh_mesh.export(glb_path, file_type="glb", include_normals=True)
    print("Generation took:", time.time() - start, "s")
    return glb_path
def create_batch(input_image: Image) -> dict[str, Any]:
    """Build the single-element model batch for one conditioning image.

    The image is resized to ``COND_WIDTH`` x ``COND_HEIGHT`` and scaled to
    [0, 1]. Its last channel is used as the mask, and its first three
    channels are blended over ``BACKGROUND_COLOR`` with that mask.
    NOTE(review): assumes ``input_image`` is RGBA — confirm against callers.

    Returns:
        Dict of tensors with keys ``rgb_cond``, ``mask_cond``, ``c2w_cond``,
        ``intrinsic_cond`` and ``intrinsic_normed_cond``.
    """
    resized = input_image.resize((COND_WIDTH, COND_HEIGHT))
    pixels = np.asarray(resized).astype(np.float32) / 255.0
    img_cond = torch.from_numpy(pixels).float().clip(0, 1)

    mask_cond = img_cond[:, :, -1:]
    background = torch.tensor(BACKGROUND_COLOR)[None, None, :]
    rgb_cond = torch.lerp(background, img_cond[:, :, :3], mask_cond)

    # Image tensors get one leading axis (the batch). Camera tensors get two:
    # the extra axis they already had in the per-sample dict, plus the batch.
    return {
        "rgb_cond": rgb_cond.unsqueeze(0),
        "mask_cond": mask_cond.unsqueeze(0),
        "c2w_cond": c2w_cond.unsqueeze(0).unsqueeze(0),
        "intrinsic_cond": intrinsic.unsqueeze(0).unsqueeze(0),
        "intrinsic_normed_cond": intrinsic_normed_cond.unsqueeze(0).unsqueeze(0),
    }
@lru_cache
def checkerboard(squares: int, size: int, min_value: float = 0.5):
    """Return a 3-channel checkerboard as a float array.

    Args:
        squares: Number of tiles along each side.
        size: Requested side length in pixels. Each tile is
            ``size // squares`` pixels wide, so the result has side
            ``squares * (size // squares)``.
        min_value: Value of the darker tiles; the other tiles are 1.

    Returns:
        Array of shape ``(side, side, 3)`` whose three channels are equal.
        Results are memoised, so repeated calls return the same array object.
    """
    row_idx, col_idx = np.indices((squares, squares))
    tiles = np.zeros((squares, squares)) + min_value
    # Tiles whose row and column parity differ are the bright ones.
    tiles[(row_idx + col_idx) % 2 == 1] = 1

    scale = size // squares
    enlarged = np.repeat(np.repeat(tiles, scale, axis=0), scale, axis=1)
    return np.repeat(enlarged[:, :, None], 3, axis=-1)
def remove_background(input_image: Image) -> Image:
    """Run ``rembg.remove`` on ``input_image`` with the shared module session."""
    cutout = rembg.remove(input_image, session=rembg_session)
    return cutout
def resize_foreground(
    image: Image,
    ratio: float,
) -> Image:
    """Re-frame an RGBA image so its foreground fills ``ratio`` of the canvas.

    The bounding box of all pixels with alpha > 0 is cropped, centred on a
    transparent square, padded so the square occupies ``ratio`` of the final
    side length, and resized to ``COND_WIDTH`` x ``COND_HEIGHT``.

    Args:
        image: RGBA image (anything ``np.array`` turns into an H x W x 4 array).
        ratio: Fraction of the output side covered by the foreground square.
            NOTE(review): the UI slider supplies 0.5-1.0; values above 1
            generally produce negative padding, which ``np.pad`` rejects.

    Returns:
        A new RGBA ``PIL`` image of size ``(COND_WIDTH, COND_HEIGHT)``.

    Raises:
        ValueError: If the input does not have four channels, or contains no
            pixel with alpha > 0.
    """
    pixels = np.array(image)
    # Explicit check instead of ``assert`` so it still runs under ``python -O``.
    if pixels.ndim != 3 or pixels.shape[-1] != 4:
        raise ValueError("resize_foreground expects an RGBA image with 4 channels")

    rows, cols = np.where(pixels[..., 3] > 0)
    if rows.size == 0:
        raise ValueError("image has no foreground: alpha is 0 for every pixel")
    y1, y2 = rows.min(), rows.max()
    x1, x2 = cols.min(), cols.max()

    # Crop the foreground. The max indices are inclusive, so add 1 to keep
    # the last foreground row and column inside the slice.
    fg = pixels[y1 : y2 + 1, x1 : x2 + 1]

    # Pad to square
    size = max(fg.shape[0], fg.shape[1])
    ph0, pw0 = (size - fg.shape[0]) // 2, (size - fg.shape[1]) // 2
    ph1, pw1 = size - fg.shape[0] - ph0, size - fg.shape[1] - pw0
    new_image = np.pad(
        fg,
        ((ph0, ph1), (pw0, pw1), (0, 0)),
        mode="constant",
        constant_values=0,
    )

    # Compute padding according to the ratio
    new_size = int(new_image.shape[0] / ratio)

    # Pad to new size
    ph0, pw0 = (new_size - size) // 2, (new_size - size) // 2
    ph1, pw1 = new_size - size - ph0, new_size - size - pw0
    new_image = np.pad(
        new_image,
        ((ph0, ph1), (pw0, pw1), (0, 0)),
        mode="constant",
        constant_values=0,
    )
    new_image = Image.fromarray(new_image, mode="RGBA").resize(
        (COND_WIDTH, COND_HEIGHT)
    )
    return new_image
def square_crop(input_image: Image) -> Image:
    """Centre-crop ``input_image`` to a square and resize it to the conditioning size.

    The square's side is the shorter image dimension; the result is resized
    to ``(COND_WIDTH, COND_HEIGHT)``.
    """
    width, height = input_image.size[0], input_image.size[1]
    side = min(input_image.size)
    crop_box = (
        (width - side) // 2,   # left
        (height - side) // 2,  # top
        (width + side) // 2,   # right
        (height + side) // 2,  # bottom
    )
    cropped = input_image.crop(crop_box)
    return cropped.resize((COND_WIDTH, COND_HEIGHT))
def show_mask_img(input_image: Image) -> Image:
    """Composite an RGBA image over a checkerboard and return it as RGB.

    The alpha channel weights the image's colour against a 32-tile, 512-pixel
    checkerboard, so transparent regions show the checker pattern.
    """
    rgba = np.array(input_image)
    alpha = rgba[:, :, 3] / 255.0
    opacity = alpha[:, :, None]
    backdrop = checkerboard(32, 512) * 255
    composite = rgba[..., :3] * opacity + backdrop * (1 - opacity)
    return Image.fromarray(composite.astype(np.uint8), mode="RGB")
def run_button(run_btn, input_image, background_state, foreground_ratio):
    """Handle a click on the main button; its current label selects the action.

    * ``"Remove Background"``: strip the background from ``input_image``,
      square-crop and re-frame it, show the preview, relabel the button to
      ``"Run"`` and hide the 3D outputs.
    * ``"Run"``: generate the mesh from ``background_state`` and reveal the
      3D viewer and the HDR panel, leaving the first four outputs unchanged.

    Returns:
        A 6-tuple matching the outputs wired to the click event (button,
        cropped-image state, framed-image state, preview, 3D model, HDR
        panel). Any other label falls through and returns ``None``.
    """
    if run_btn == "Remove Background":
        cutout = remove_background(input_image)
        cropped = square_crop(cutout)
        framed = resize_foreground(cropped, foreground_ratio)
        preview = show_mask_img(framed)
        return (
            gr.update(value="Run", visible=True),
            cropped,
            framed,
            gr.update(value=preview, visible=True),
            gr.update(value=None, visible=False),
            gr.update(visible=False),
        )

    if run_btn != "Run":
        return None

    glb_path: str = run_model(background_state)
    untouched = [gr.update() for _ in range(4)]
    return (
        *untouched,
        gr.update(value=glb_path, visible=True),
        gr.update(visible=True),
    )
def requires_bg_remove(image, fr):
    """Decide what the UI should offer after the input image changes.

    * No image: hide the button and clear every dependent output.
    * Image whose alpha channel contains a 0: treat it as already cut out,
      crop and re-frame it with ratio ``fr``, show the preview and offer "Run".
    * Otherwise: offer "Remove Background" and clear the dependent outputs.

    Returns:
        A 6-tuple matching the outputs wired to the image change event
        (button, cropped-image state, framed-image state, preview, 3D model,
        HDR panel).
    """

    def cleared(button_update):
        # Shared shape of the two branches that have no processed image yet.
        return (
            button_update,
            None,
            None,
            gr.update(value=None, visible=False),
            gr.update(visible=False),
            gr.update(visible=False),
        )

    if image is None:
        return cleared(gr.update(visible=False, value="Run"))

    lowest_alpha = np.array(image.getchannel("A")).min()
    if not lowest_alpha == 0:
        return cleared(gr.update(value="Remove Background", visible=True))

    print("Already has alpha")
    cropped = square_crop(image)
    framed = resize_foreground(cropped, fr)
    return (
        gr.update(value="Run", visible=True),
        cropped,
        framed,
        gr.update(value=show_mask_img(framed), visible=True),
        gr.update(visible=False),
        gr.update(visible=False),
    )
def update_foreground_ratio(img_proc, fr):
    """Re-frame the cropped image after the foreground-ratio slider moves.

    Args:
        img_proc: Square-cropped RGBA image held in the UI state, or ``None``
            when no image has been processed yet.
        fr: New foreground ratio from the slider.

    Returns:
        ``(framed_image, preview_update)`` for the framed-image state and the
        preview component. When ``img_proc`` is ``None`` both entries are
        no-op ``gr.update()`` values.
    """
    if img_proc is None:
        # The slider is always visible, so it can be moved before an image
        # has been uploaded or cut out. There is nothing to re-frame then,
        # so leave both outputs untouched instead of failing.
        return gr.update(), gr.update()

    foreground_res = resize_foreground(img_proc, fr)
    return (
        foreground_res,
        gr.update(value=show_mask_img(foreground_res)),
    )
# Define custom theme
class CustomTheme(gr.themes.Base):
    """Dark variant of ``gr.themes.Base`` with white text and the Poppins font.

    After the base initialiser runs, a fixed set of colour and font
    attributes is assigned on the instance.
    NOTE(review): whether Gradio reads each of these attribute names when
    rendering the theme is not visible from this file — confirm.
    """

    def __init__(self):
        super().__init__()
        dark = "#191a1e"
        white = "#FFFFFF"
        overrides = {
            "primary_hue": dark,
            "background_fill_primary": dark,
            "background_fill_secondary": dark,
            "background_fill_tertiary": dark,
            "text_color_primary": white,
            "text_color_secondary": white,
            "text_color_tertiary": white,
            "input_background_fill": dark,
            "input_text_color": white,
            "font": (
                "Poppins",
                "https://fonts.googleapis.com/css2?family=Poppins:wght@400;500;700&display=swap",
            ),
        }
        for attribute, value in overrides.items():
            setattr(self, attribute, value)
# Custom CSS passed to gr.Blocks(css=...): dark backgrounds for the page, the
# Gradio container and blocks, a hidden footer, the Poppins font with white
# text, and the .generate-button class applied to the main button.
css = """
body {
background-color: #191a1e !important;
margin: 0;
padding: 0;
}
/* Применяем фоновый цвет для контейнера Gradio */
.gradio-container {
background-color: #191a1e !important;
}
/* Применяем фоновый цвет для блоков */
.gr-block {
background-color: #191a1e !important;
border: 1px solid #5271FF !important;
}
/* Hide the footer */
footer {
visibility: hidden;
height: 0;
margin: 0;
padding: 0;
overflow: hidden;
}
/* Apply fonts */
body, input, button, textarea, select, .gr-button {
font-family: 'Poppins', sans-serif;
color: #FFFFFF;
}
/* Header styles */
h1, h2, h3, h4, h5, h6 {
font-family: 'Poppins', sans-serif;
font-weight: 700;
color: #FFFFFF;
}
/* Button styles */
.generate-button {
background-color: #5271FF !important;
color: #FFFFFF !important;
border: none;
font-weight: bold;
}
"""
# Build the Gradio interface
# NOTE(review): the nesting of the layout containers below is inferred from
# the statement order — confirm it matches the intended layout.
with gr.Blocks(theme=CustomTheme(), css=css) as demo:
    # Square-cropped image; update_foreground_ratio re-frames from this.
    img_proc_state = gr.State()
    # Foreground-resized image; run_button passes this to run_model.
    background_remove_state = gr.State()
    with gr.Row(variant="panel"):
        # Input side: uploaded image, background-removal preview and controls.
        with gr.Column():
            with gr.Row():
                input_img = gr.Image(
                    type="pil",
                    label="Input Image",
                    sources="upload",
                    image_mode="RGBA",
                )
                preview_removal = gr.Image(
                    label="Preview Background Removal",
                    type="pil",
                    image_mode="RGB",
                    interactive=False,
                    visible=False,
                )
            foreground_ratio = gr.Slider(
                label="Foreground Ratio",
                minimum=0.5,
                maximum=1.0,
                value=0.85,
                step=0.05,
            )
            # Moving the slider re-frames the cropped image and refreshes the preview.
            foreground_ratio.change(
                update_foreground_ratio,
                inputs=[img_proc_state, foreground_ratio],
                outputs=[background_remove_state, preview_removal],
            )
            # Hidden until an image is loaded. Its label ("Run" or
            # "Remove Background") selects the branch taken in run_button.
            run_btn = gr.Button(
                "Run",
                variant="primary",
                visible=False,
                elem_classes="generate-button",
            )
        # Output side: 3D viewer plus the HDR environment-map picker.
        with gr.Column():
            output_3d = LitModel3D(
                label="3D Model",
                visible=False,
                clear_color=[0.0, 0.0, 0.0, 0.0],
                tonemapping="aces",
                contrast=1.0,
                scale=1.0,
            )
            # Shown by run_button once a model has been generated.
            with gr.Column(visible=False, scale=1.0) as hdr_row:
                gr.Markdown(
                    """
## HDR Environment Map
Select an HDR environment map to light the 3D model. You can also upload your own HDR environment maps.
"""
                )
                with gr.Row():
                    hdr_illumination_file = gr.File(
                        label="HDR Env Map", file_types=[".hdr"], file_count="single"
                    )
                    # Every entry of demo_files/hdri is offered as an example map.
                    example_hdris = [
                        os.path.join("demo_files/hdri", f)
                        for f in os.listdir("demo_files/hdri")
                    ]
                    hdr_illumination_example = gr.Examples(
                        examples=example_hdris,
                        inputs=hdr_illumination_file,
                    )
                    # Selecting or clearing a file updates the viewer's env_map.
                    hdr_illumination_file.change(
                        lambda x: gr.update(env_map=x.name if x is not None else None),
                        inputs=hdr_illumination_file,
                        outputs=[output_3d],
                    )
    examples = gr.Examples(
        examples=example_files,
        inputs=input_img,
    )
    # A new (or cleared) input image resets the button, states and outputs.
    input_img.change(
        requires_bg_remove,
        inputs=[input_img, foreground_ratio],
        outputs=[
            run_btn,
            img_proc_state,
            background_remove_state,
            preview_removal,
            output_3d,
            hdr_row,
        ],
    )
    # The button's own label is passed as the first input so run_button can
    # tell which action was requested.
    run_btn.click(
        run_button,
        inputs=[
            run_btn,
            input_img,
            background_remove_state,
            foreground_ratio,
        ],
        outputs=[
            run_btn,
            img_proc_state,
            background_remove_state,
            preview_removal,
            output_3d,
            hdr_row,
        ],
    )
# Launch the interface
demo.launch()