SuperPoint / app.py
merve's picture
merve HF staff
Create app.py
282dd93 verified
raw
history blame
1.3 kB
import uuid

import gradio as gr
import matplotlib.pyplot as plt
import torch
from matplotlib.figure import Figure
from transformers import AutoImageProcessor, SuperPointForKeypointDetection
# Hugging Face Hub checkpoint shared by the image processor and the model.
# Kept in one place so the two can never be loaded from different checkpoints.
CHECKPOINT = "magic-leap-community/superpoint"

# Loaded once at import time and reused by every request handled by `infer`.
processor = AutoImageProcessor.from_pretrained(CHECKPOINT)
model = SuperPointForKeypointDetection.from_pretrained(CHECKPOINT)
def infer(image):
    """Detect SuperPoint keypoints in *image* and save a visualization.

    The image is run through the module-level ``processor``/``model`` pair and
    the detected keypoints are rescaled to the original image size. They are
    drawn over the image as cyan dots whose marker area is proportional to
    their detection score.

    Args:
        image: A PIL image (the Gradio input is declared with ``type="pil"``).

    Returns:
        Path (str) of the PNG file written to the current working directory.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        # Gradio passes None when the user submits without uploading an image.
        raise gr.Error("Please upload an image first.")

    inputs = processor(image, return_tensors="pt").to(model.device, model.dtype)
    # Pure inference: skip autograd bookkeeping (no gradients are ever needed).
    with torch.no_grad():
        model_outputs = model(**inputs)

    # Post-processing expects (height, width); PIL's `.size` is (width, height).
    image_sizes = [(image.size[1], image.size[0])]
    outputs = processor.post_process_keypoint_detection(model_outputs, image_sizes)
    # `.cpu()` first so the conversion to NumPy also works if the model is on a GPU.
    keypoints = outputs[0]["keypoints"].cpu().numpy()
    scores = outputs[0]["scores"].cpu().numpy()

    # Use a standalone Figure rather than pyplot's global "current figure".
    # With pyplot, every call drew onto the same never-closed figure, so
    # earlier requests' images/keypoints accumulated and memory grew unbounded.
    fig = Figure()
    ax = fig.add_subplot()
    ax.axis("off")
    ax.imshow(image)
    ax.scatter(
        keypoints[:, 0],
        keypoints[:, 1],
        s=scores * 100,
        c="cyan",
        alpha=0.4,
    )

    # Unique name so concurrent/consecutive requests never overwrite each other.
    path = "./" + uuid.uuid4().hex + ".png"
    fig.savefig(path)
    return path
# Shown above the demo. The Interface title is rendered as a plain heading,
# so it must not contain Markdown markers such as "##".
title = "SuperPoint"
# The description supports Markdown, so the model name links to its checkpoint.
description = (
    "Try [SuperPoint](https://huggingface.co/magic-leap-community/superpoint) "
    "in this demo, a foundation model for keypoint detection supported in "
    "🤗 transformers. Simply upload an image or try the example."
)

# Wire `infer` into a single image-in / image-out UI. `title` and `description`
# must be passed explicitly; otherwise they are never displayed.
iface = gr.Interface(
    fn=infer,
    inputs=gr.Image(type="pil"),
    outputs=gr.Image(),
    title=title,
    description=description,
    examples=["./bee.jpg"],
)
iface.launch()