# Source: LHM / app.py (Hugging Face Space), by DyrusQZ
# Commit b206b0b: "move detail func in app" (14.6 kB)
# Copyright (c) 2023-2024, Qi Zuo
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from PIL import Image
import numpy as np
import gradio as gr
import base64
import spaces
import subprocess
import os
from engine.pose_estimation.pose_estimator import PoseEstimator
from LHM.utils.face_detector import VGGHeadDetector
from LHM.utils.hf_hub import wrap_model_hub
def parse_configs():
    """Assemble the runtime config from CLI flags, environment variables and YAML files.

    Recognised flags are ``--config`` (training config) and ``--infer``
    (inference config); every other ``key=value`` argument is parsed with
    ``OmegaConf.from_cli``. ``APP_INFER`` overrides ``--infer`` and
    ``APP_MODEL_NAME`` overrides ``model_name`` from the command line. When
    ``--config`` is omitted the inference config doubles as the training
    config.

    Merge order (later wins): values derived from the training config, the
    inference config (dump-directory defaults are only filled in when
    missing), ``motion_video_read_fps = 6``, then the command-line overrides.

    Returns:
        tuple: ``(cfg, cfg_train)`` -- the merged config and the loaded
        training config (``None`` when no config file was given).

    Raises:
        AssertionError: if the merged config has no ``model_name``. When a
            config file is used, a missing CLI ``model_name`` may fail earlier,
            while the dump paths are derived.
    """
    # Both names are used below but were never imported by this file.
    import argparse

    from omegaconf import OmegaConf

    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str)
    parser.add_argument("--infer", type=str)
    args, unknown = parser.parse_known_args()

    cfg = OmegaConf.create()
    # Unrecognised CLI arguments are treated as dotted key=value overrides.
    cli_cfg = OmegaConf.from_cli(unknown)

    # parse from ENV (these win over the corresponding CLI values)
    if os.environ.get("APP_INFER") is not None:
        args.infer = os.environ.get("APP_INFER")
    if os.environ.get("APP_MODEL_NAME") is not None:
        cli_cfg.model_name = os.environ.get("APP_MODEL_NAME")

    args.config = args.infer if args.config is None else args.config

    # Initialised so the final return cannot hit an unbound local when neither
    # --config nor --infer is available.
    cfg_train = None
    if args.config is not None:
        cfg_train = OmegaConf.load(args.config)
        cfg.source_size = cfg_train.dataset.source_image_res
        # Older training configs have no src_head_size; fall back to 112.
        cfg.src_head_size = cfg_train.dataset.get("src_head_size", 112)
        cfg.render_size = cfg_train.dataset.render_image.high
        # normpath drops a trailing separator: ".../step_060000/" must give
        # "060000", whereas basename() of the raw path would return "".
        checkpoint_tag = os.path.basename(
            os.path.normpath(cli_cfg.model_name)
        ).split("_")[-1]
        _relative_path = os.path.join(
            cfg_train.experiment.parent,
            cfg_train.experiment.child,
            checkpoint_tag,
        )
        cfg.save_tmp_dump = os.path.join("exps", "save_tmp", _relative_path)
        cfg.image_dump = os.path.join("exps", "images", _relative_path)
        cfg.video_dump = os.path.join("exps", "videos", _relative_path)  # output path

    if args.infer is not None:
        cfg_infer = OmegaConf.load(args.infer)
        cfg.merge_with(cfg_infer)
        cfg.setdefault(
            "save_tmp_dump", os.path.join("exps", cli_cfg.model_name, "save_tmp")
        )
        cfg.setdefault("image_dump", os.path.join("exps", cli_cfg.model_name, "images"))
        cfg.setdefault(
            "video_dump", os.path.join("dumps", cli_cfg.model_name, "videos")
        )
        cfg.setdefault("mesh_dump", os.path.join("dumps", cli_cfg.model_name, "meshes"))

    cfg.motion_video_read_fps = 6
    cfg.merge_with(cli_cfg)
    cfg.setdefault("logger", "INFO")

    assert cfg.get("model_name") is not None, "model_name is required"
    return cfg, cfg_train
def _build_model(cfg):
    """Instantiate the LHM network and load the weights named by ``cfg.model_name``.

    The architecture is looked up in ``LHM.models.model_dict`` under the fixed
    key ``"human_lrm_sapdino_bh_sd3_5"``. It is wrapped with
    ``wrap_model_hub``, and the wrapped class's ``from_pretrained`` does the
    loading.
    """
    from LHM.models import model_dict

    architecture = model_dict["human_lrm_sapdino_bh_sd3_5"]
    return wrap_model_hub(architecture).from_pretrained(cfg.model_name)
def launch_pretrained():
    """Fetch and unpack the runtime archives from the ``DyrusQZ/LHM_Runtime`` Hub repo.

    The archives are processed one after another: ``assets.tar``,
    ``LHM-0.5B.tar``, then ``LHM_prior_model.tar``. Each is downloaded into the
    current directory, extracted there with ``tar -xvf`` and then deleted.

    Raises:
        subprocess.CalledProcessError: if ``tar`` exits non-zero. The archive
            is left on disk in that case so the failure can be inspected.
    """
    from huggingface_hub import hf_hub_download

    for archive in ("assets.tar", "LHM-0.5B.tar", "LHM_prior_model.tar"):
        local_path = hf_hub_download(
            repo_id="DyrusQZ/LHM_Runtime",
            repo_type='model',
            filename=archive,
            local_dir="./",
        )
        # An argv list avoids the shell. check=True surfaces extraction failures
        # that os.system("tar ... && rm ...") used to ignore silently.
        subprocess.run(["tar", "-xvf", local_path], check=True)
        os.remove(local_path)
def launch_env_not_compile_with_cuda():
    """Install the extra Python packages the Space needs, in a fixed order.

    Steps, in order:

    1. Install ``chumpy``.
    2. Replace any installed ``basicsr`` with the hitsz-zuoqi BasicSR fork.
    3. Pin ``numpy==1.23.0``.
    4. Install ``pytorch3d`` from the prebuilt ``py310_cu121_pyt251`` wheel
       index.

    The installation is best-effort, as before: a failing step is reported
    on stdout but does not stop the remaining steps.
    """
    import sys

    pip_steps = [
        ["install", "chumpy"],
        ["uninstall", "-y", "basicsr"],
        ["install", "git+https://github.com/hitsz-zuoqi/BasicSR/"],
        ["install", "numpy==1.23.0"],
        [
            "install", "--no-index", "--no-cache-dir", "pytorch3d",
            "-f", "https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/py310_cu121_pyt251/download.html",
        ],
    ]
    for step in pip_steps:
        # `python -m pip` installs into the environment of the interpreter
        # running this file. A bare `pip` on PATH may belong to another Python.
        cmd = [sys.executable, "-m", "pip", *step]
        result = subprocess.run(cmd)
        if result.returncode != 0:
            print(f"[launch_env_not_compile_with_cuda] exit code {result.returncode}: {' '.join(cmd)}")
def assert_input_image(input_image):
    """Guard step for the Generate button: require that an input image is present.

    Raises:
        gr.Error: with a user-facing message when ``input_image`` is ``None``.
    """
    if input_image is not None:
        return
    raise gr.Error("No image selected or uploaded!")
def prepare_working_dir():
    """Create a new ``tempfile.TemporaryDirectory`` for a single generation request.

    Consumers read the directory path from the returned object's ``.name``.
    """
    from tempfile import TemporaryDirectory

    return TemporaryDirectory()
def init_preprocessor():
    """Create the module-level ``preprocessor`` object that ``preprocess_fn`` calls."""
    global preprocessor
    from LHM.utils.preprocess import Preprocessor as _Preprocessor

    preprocessor = _Preprocessor()
def preprocess_fn(image_in: np.ndarray, remove_bg: bool, recenter: bool, working_dir):
    """Save ``image_in`` and run it through the module-level ``preprocessor``.

    ``init_preprocessor()`` must have been called first, since it is what
    defines ``preprocessor``.

    Args:
        image_in: image array accepted by ``PIL.Image.fromarray``; it is
            written to ``raw.png`` inside ``working_dir``.
        remove_bg: forwarded as ``rmbg`` to ``preprocessor.preprocess``.
        recenter: forwarded as ``recenter`` to ``preprocessor.preprocess``.
        working_dir: object whose ``.name`` is the output directory, e.g. the
            ``TemporaryDirectory`` returned by ``prepare_working_dir``.

    Returns:
        str: path of ``rembg.png`` inside ``working_dir``, where the
        preprocessor was asked to save its result.

    Raises:
        RuntimeError: if ``preprocessor.preprocess`` reports failure.
    """
    image_raw = os.path.join(working_dir.name, "raw.png")
    with Image.fromarray(image_in) as img:
        img.save(image_raw)
    image_out = os.path.join(working_dir.name, "rembg.png")
    success = preprocessor.preprocess(image_path=image_raw, save_path=image_out, rmbg=remove_bg, recenter=recenter)
    # Explicit raise instead of `assert`: asserts are stripped under `python -O`,
    # which would hand back a path to an output that may never have been written.
    if not success:
        raise RuntimeError("Failed under preprocess_fn!")
    return image_out
def get_image_base64(path, mime_type=None):
    """Read an image file and return it as a base64 ``data:`` URL for inline HTML.

    Args:
        path: file to embed.
        mime_type: MIME type written into the URL. When omitted, it is guessed
            from the file extension, falling back to ``image/png`` (the
            previously hard-coded value) when the extension is unknown.

    Returns:
        str: ``"data:<mime_type>;base64,<payload>"``.
    """
    import mimetypes

    if mime_type is None:
        # A .jpg labelled as image/png is a mislabelled data URL; honour the extension.
        mime_type = mimetypes.guess_type(path)[0] or "image/png"
    with open(path, "rb") as image_file:
        encoded_string = base64.b64encode(image_file.read()).decode()
    return f"data:{mime_type};base64,{encoded_string}"
def demo_lhm(pose_estimator, face_detector, lhm_model, cfg):
    """Build the LHM Gradio Blocks UI, wire the Generate button and launch the app.

    The UI has four panels: input image, driving-motion video, processed
    image and rendered video. The Generate button runs
    ``assert_input_image -> prepare_working_dir -> core_fn``. After building
    the UI, the function calls ``demo.queue()`` and then ``demo.launch()``.

    Args:
        pose_estimator: callable invoked with an image path. The object it
            returns is read for ``is_full_body`` and ``msg`` (and ``beta`` in
            the disabled inference call).
        face_detector: currently unused in this function.
        lhm_model: currently unused in this function.
        cfg: config from ``parse_configs``. ``source_size`` and
            ``render_size`` are read, plus the optional ``motion_img_need_mask``
            and ``vis_motion``.

    NOTE(review): the inference step inside ``core_fn`` is unfinished; see the
    notes there.
    """
    @spaces.GPU
    def core_fn(image: np.ndarray, video_params, working_dir):
        """Gradio callback for the Generate button.

        Args:
            image: value of ``input_image``. That component is declared with
                ``type="numpy"``, so this is an array; it is saved to
                ``raw.png`` via ``Image.fromarray``.
            video_params: path of the selected sample-motion video. The part of
                its basename before the first ``_`` selects
                ``./assets/sample_motion/<name>/smplx_params``.
            working_dir: the ``TemporaryDirectory`` from ``prepare_working_dir``.
                All files are written under ``working_dir.name``.

        Returns:
            ``(dump_image_path, dump_video_path)``, but only when ``output.mp4``
            already exists in ``working_dir``.

        NOTE(review): incomplete as written. ``parsing``, ``cv2``,
        ``img_path`` and ``remove`` are not defined or imported in this file,
        so execution past the cached-output check raises ``NameError``. That
        path also has no ``return``, because the inference call is commented
        out at the bottom. TODO: restore inference and return the
        processed-image / video paths.
        """
        image_raw = os.path.join(working_dir.name, "raw.png")
        with Image.fromarray(image) as img:
            img.save(image_raw)
        # Motion folder = basename prefix before the first "_",
        # e.g. ".../ex5/ex5_origin.mp4" -> "ex5".
        base_vid = os.path.basename(video_params).split("_")[0]
        smplx_params_dir = os.path.join("./assets/sample_motion", base_vid, "smplx_params")
        dump_video_path = os.path.join(working_dir.name, "output.mp4")
        dump_image_path = os.path.join(working_dir.name, "output.png")
        # prepare dump paths
        omit_prefix = os.path.dirname(image_raw)
        image_name = os.path.basename(image_raw)
        uid = image_name.split(".")[0]
        # omit_prefix equals dirname(image_raw), so subdir_path is always "" here.
        subdir_path = os.path.dirname(image_raw).replace(omit_prefix, "")
        subdir_path = (
            subdir_path[1:] if subdir_path.startswith("/") else subdir_path
        )
        print("subdir_path and uid:", subdir_path, uid)
        motion_seqs_dir = smplx_params_dir
        # Parent folder of smplx_params, i.e. the same value as base_vid.
        motion_name = os.path.dirname(
            motion_seqs_dir[:-1] if motion_seqs_dir[-1] == "/" else motion_seqs_dir
        )
        motion_name = os.path.basename(motion_name)
        # dirname(dump_image_path) is working_dir.name, which already exists.
        dump_image_dir = os.path.dirname(dump_image_path)
        os.makedirs(dump_image_dir, exist_ok=True)
        print(image_raw, motion_seqs_dir, dump_image_dir, dump_video_path)
        shape_pose = pose_estimator(image_raw)
        # NOTE(review): `assert` is stripped under `python -O`; a gr.Error would
        # survive that and show the message to the user.
        assert shape_pose.is_full_body, f"The input image is illegal, {shape_pose.msg}"
        # Reuse an existing result in this working dir. Each Generate click runs
        # prepare_working_dir first, so this is normally False.
        if os.path.exists(dump_video_path):
            return dump_image_path, dump_video_path
        # Render settings: read here but not used by the code that follows yet.
        source_size = cfg.source_size
        render_size = cfg.render_size
        render_fps = 30
        aspect_standard = 5.0 / 3
        motion_img_need_mask = cfg.get("motion_img_need_mask", False)  # False
        vis_motion = cfg.get("vis_motion", False)  # False
        # NOTE(review): `parsing`, `cv2`, `img_path` and `remove` are undefined in
        # this file, so this raises NameError. `input` also shadows the builtin.
        # `output[:,:,3]` presumably takes an alpha channel -- confirm once
        # `remove` is identified.
        parsing_mask = parsing(image_raw)
        input = cv2.imread(img_path)
        output = remove(input)
        alpha = output[:,:,3]
        # Disabled inference call, kept for reference:
        # self.infer_single(
        #     image_path,
        #     motion_seqs_dir=motion_seqs_dir,
        #     motion_img_dir=None,
        #     motion_video_read_fps=30,
        #     export_video=False,
        #     export_mesh=False,
        #     dump_tmp_dir=dump_image_dir,
        #     dump_image_dir=dump_image_dir,
        #     dump_video_path=dump_video_path,
        #     shape_param=shape_pose.beta,
        # )
        # status = spaces.GPU(infer_impl(
        #     gradio_demo_image=image_raw,
        #     gradio_motion_file=smplx_params_dir,
        #     gradio_masked_image=dump_image_path,
        #     gradio_video_save_path=dump_video_path
        # ))
        # if status:
        #     return dump_image_path, dump_video_path
        # else:
        #     return None, None

    # Title/description strings; not referenced elsewhere in this function.
    _TITLE = '''LHM: Large Animatable Human Model'''
    _DESCRIPTION = '''
<strong>Reconstruct a human avatar in 0.2 seconds with A100!</strong>
'''

    with gr.Blocks(analytics_enabled=False) as demo:
        # Header: logo inlined as a data URL, followed by the title.
        logo_url = "./assets/rgba_logo_new.png"
        logo_base64 = get_image_base64(logo_url)
        gr.HTML(
            f"""
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
<div>
<h1> <img src="{logo_base64}" style='height:35px; display:inline-block;'/> Large Animatable Human Model </h1>
</div>
</div>
"""
        )
        gr.HTML(
            """<p><h4 style="color: red;"> Notes: Please input full-body image in case of detection errors.</h4></p>"""
        )

        # DISPLAY: input image | motion video | processed image | rendered video
        with gr.Row():
            with gr.Column(variant='panel', scale=1):
                with gr.Tabs(elem_id="openlrm_input_image"):
                    with gr.TabItem('Input Image'):
                        with gr.Row():
                            input_image = gr.Image(label="Input Image", image_mode="RGBA", height=480, width=270, sources="upload", type="numpy", elem_id="content_image")
                # EXAMPLES
                with gr.Row():
                    examples = [
                        ['assets/sample_input/joker.jpg'],
                        ['assets/sample_input/anime.png'],
                        ['assets/sample_input/basket.png'],
                        ['assets/sample_input/ai_woman1.JPG'],
                        ['assets/sample_input/anime2.JPG'],
                        ['assets/sample_input/anime3.JPG'],
                        ['assets/sample_input/boy1.png'],
                        ['assets/sample_input/choplin.jpg'],
                        ['assets/sample_input/eins.JPG'],
                        ['assets/sample_input/girl1.png'],
                        ['assets/sample_input/girl2.png'],
                        ['assets/sample_input/robot.jpg'],
                    ]
                    gr.Examples(
                        examples=examples,
                        inputs=[input_image],
                        examples_per_page=20,
                    )

            with gr.Column():
                with gr.Tabs(elem_id="openlrm_input_video"):
                    with gr.TabItem('Input Video'):
                        with gr.Row():
                            # Read-only: chosen only through the examples below. Its path
                            # selects the SMPL-X motion folder in core_fn.
                            video_input = gr.Video(label="Input Video",height=480, width=270, interactive=False)
                examples = [
                    # './assets/sample_motion/danaotiangong/danaotiangong_origin.mp4',
                    './assets/sample_motion/ex5/ex5_origin.mp4',
                    './assets/sample_motion/girl2/girl2_origin.mp4',
                    './assets/sample_motion/jntm/jntm_origin.mp4',
                    './assets/sample_motion/mimo1/mimo1_origin.mp4',
                    './assets/sample_motion/mimo2/mimo2_origin.mp4',
                    './assets/sample_motion/mimo4/mimo4_origin.mp4',
                    './assets/sample_motion/mimo5/mimo5_origin.mp4',
                    './assets/sample_motion/mimo6/mimo6_origin.mp4',
                    './assets/sample_motion/nezha/nezha_origin.mp4',
                    './assets/sample_motion/taiji/taiji_origin.mp4'
                ]
                gr.Examples(
                    examples=examples,
                    inputs=[video_input],
                    examples_per_page=20,
                )

            with gr.Column(variant='panel', scale=1):
                with gr.Tabs(elem_id="openlrm_processed_image"):
                    with gr.TabItem('Processed Image'):
                        with gr.Row():
                            processed_image = gr.Image(label="Processed Image", image_mode="RGBA", type="filepath", elem_id="processed_image", height=480, width=270, interactive=False)

            with gr.Column(variant='panel', scale=1):
                with gr.Tabs(elem_id="openlrm_render_video"):
                    with gr.TabItem('Rendered Video'):
                        with gr.Row():
                            output_video = gr.Video(label="Rendered Video", format="mp4", height=480, width=270, autoplay=True)

        # SETTING
        with gr.Row():
            with gr.Column(variant='panel', scale=1):
                submit = gr.Button('Generate', elem_id="openlrm_generate", variant='primary')

        # Per-session slot for the TemporaryDirectory created by prepare_working_dir.
        working_dir = gr.State()
        # Each .success() step runs only if the previous one raised no error.
        submit.click(
            fn=assert_input_image,
            inputs=[input_image],
            queue=False,
        ).success(
            fn=prepare_working_dir,
            outputs=[working_dir],
            queue=False,
        ).success(
            fn=core_fn,
            inputs=[input_image, video_input, working_dir],  # video_params refer to smpl dir
            outputs=[processed_image, output_video],
        )

    demo.queue()
    demo.launch()
def launch_gradio_app():
    """Export the app settings, load every model onto CUDA and start the demo.

    ``APP_MODEL_NAME`` and ``APP_INFER`` are read back by ``parse_configs``.
    ``APP_ENABLED``, ``APP_TYPE`` and ``NUMBA_THREADING_LAYER`` are exported for
    other consumers and are not read by the code in this file.
    """
    app_env = {
        "APP_ENABLED": "1",
        "APP_MODEL_NAME": "./exps/releases/video_human_benchmark/human-lrm-500M/step_060000/",
        "APP_INFER": "./configs/inference/human-lrm-500M.yaml",
        "APP_TYPE": "infer.human_lrm",
        "NUMBA_THREADING_LAYER": 'omp',
    }
    os.environ.update(app_env)

    gpu = 'cuda'

    # Both detectors are constructed on the CPU and then moved to the GPU.
    face_detector = VGGHeadDetector(
        "./pretrained_models/gagatracker/vgghead/vgg_heads_l.trcd",
        device='cpu',
    )
    face_detector.to(gpu)

    pose_estimator = PoseEstimator("./pretrained_models/human_model_files/", device='cpu')
    pose_estimator.to(gpu)
    # Set explicitly; presumably PoseEstimator.to() does not update `.device` -- confirm.
    pose_estimator.device = gpu

    cfg, _cfg_train = parse_configs()
    lhm = _build_model(cfg)
    lhm.to(gpu)

    demo_lhm(pose_estimator, face_detector, lhm, cfg)
if __name__ == '__main__':
    # One-off Space bootstrap steps. Uncomment on a fresh runtime that still lacks
    # the downloaded assets/weights or the extra pip packages.
    # launch_pretrained()
    # launch_env_not_compile_with_cuda()
    launch_gradio_app()