poser-tf / app.py
leonelhs's picture
show pose body labels
9c8f48a
raw
history blame
2.39 kB
import PIL.Image
import PIL.ImageOps
import gradio as gr
import numpy as np
import tensorflow as tf
from poser import draw_bones, movenet
def predict(image: PIL.Image.Image):
    """Run MoveNet on ``image`` and draw the detected skeleton onto it.

    The input is cropped and resized to a 1280x1280 square. A batched copy is
    resized and padded to ``input_size`` x ``input_size`` for ``movenet``, and
    the resulting keypoints are passed to ``draw_bones`` together with the
    square image.

    Args:
        image: Picture delivered by the ``gr.Image(type="pil")`` input
            component; ``None`` when Run is clicked without an upload.

    Returns:
        A tuple ``(image, joints, points)``: the square image handed to
        ``draw_bones``, the ``joints`` sequence returned by ``draw_bones``
        (each entry unpacked as a 3-tuple), and a list of ``"x#y"`` strings
        built from the second and third fields of each joint.

    Raises:
        ValueError: If no image was provided.
    """
    if image is None:
        # Fail with a readable message instead of an AttributeError raised
        # from inside Pillow.
        raise ValueError("Please upload an input image first.")

    input_size = 256  # side length of the square tensor fed to movenet
    size = (1280, 1280)
    # ImageOps.fit crops and resizes to exactly `size`; ImageOps.contain
    # would instead keep the whole picture at its original aspect ratio.
    image = PIL.ImageOps.fit(image, size, PIL.Image.LANCZOS)
    image_tf = tf.keras.preprocessing.image.img_to_array(image)
    # Add a batch axis, then resize and pad to the model's square input size
    # while keeping the aspect ratio.
    input_image = tf.expand_dims(image_tf, axis=0)
    input_image = tf.image.resize_with_pad(input_image, input_size, input_size)
    keypoints = np.array(movenet(input_image))
    # Use the fitted PIL image directly. Rebuilding it with array_to_img
    # (scale=True by default) would min-max stretch the pixel values and
    # change the colours of low-contrast inputs.
    # NOTE(review): the returned image is presumably annotated in place by
    # draw_bones, since only `joints` comes back from it — confirm in poser.
    joints = draw_bones(image, keypoints)
    points = [f"{x}#{y}" for _, x, y in joints]
    return image, joints, points
# HTML footer crediting the TensorFlow Hub MoveNet tutorial; each list entry
# is one line, and the empty first/last entries add the leading and trailing
# newlines.
footer = "\n".join(
    [
        "",
        "<center>",
        "<b>",
        "Demo for <a href='https://www.tensorflow.org/hub/tutorials/movenet'>MoveNet</a>",
        "</b>",
        "</center>",
        "",
    ]
)
# Build the Gradio UI: input/output panes, position read-outs, an example
# gallery and the footer, then start the server.
with gr.Blocks(title="MoveNet") as app:
    gr.HTML("<center><h1>Human Pose Estimation with MoveNet</h1></center>")
    gr.HTML("<center><h3>MoveNet: Ultra fast and accurate pose detection model</h3></center>")
    # NOTE(review): Component.style() is deprecated in later Gradio 3.x and
    # removed in 4.x (equal_height moves to the gr.Row constructor); kept as-is
    # for the Gradio version this Space pins.
    with gr.Row().style(equal_height=False):
        with gr.Column():
            input_img = gr.Image(type="pil", label="Input image")
            run_btn = gr.Button(variant="primary")
        with gr.Column():
            output_img = gr.Image(type="numpy", label="Output image")
            with gr.Accordion("See Positions", open=False):
                positions = gr.Dataframe(
                    interactive=True,
                    # predict() unpacks each joint as (p, x, y); column 0 is
                    # typed "str" below, presumably the body-part label, so the
                    # headers follow that same order.
                    headers=["label", "x", "y"],
                    datatype=["str", "number", "number"],
                    row_count=16,
                    col_count=(3, "fixed"),
                )
                data = gr.Textbox(label="Positions", lines=17)
            gr.ClearButton(components=[input_img, output_img, positions, data], variant="stop")
    run_btn.click(predict, [input_img], [output_img, positions, data])
    with gr.Row():
        blobs = [[f"examples/{x:02d}.jpg"] for x in range(1, 4)]
        examples = gr.Dataset(components=[input_img], samples=blobs)
        # Clicking a sample passes that sample (a one-element list holding the
        # file path); forward the path to the input image.
        examples.click(lambda x: x[0], [examples], [input_img])
    with gr.Row():
        gr.HTML(footer)

# Enable the request queue *before* launching: launch(debug=True) blocks the
# main thread until the server stops, so a queue() call placed after it never
# takes effect while the app is serving.
app.queue()
app.launch(share=False, debug=True, show_error=True)