# Source: Hugging Face Space file view (author: gatesla)
# Commit 3a9ef72 (verified): "Testing if maskformer is working"
# Original file size: 9.39 kB
import functools
import io
import os
import pathlib

import gradio as gr
import matplotlib.pyplot as plt
import numpy
import requests
import torch
import validators
from PIL import Image
from transformers import (
    DetrFeatureExtractor,
    DetrForSegmentation,
    MaskFormerForInstanceSegmentation,
    MaskFormerImageProcessor,
)
from transformers.models.detr.feature_extraction_detr import rgb_to_id
@functools.lru_cache(maxsize=2)
def _load_maskformer(model_name):
    """Load (once) and cache the MaskFormer processor/model pair for ``model_name``."""
    return (
        MaskFormerImageProcessor.from_pretrained(model_name),
        MaskFormerForInstanceSegmentation.from_pretrained(model_name),
    )


@functools.lru_cache(maxsize=2)
def _load_detr(model_name):
    """Load (once) and cache the DETR feature-extractor/model pair for ``model_name``."""
    return (
        DetrFeatureExtractor.from_pretrained(model_name),
        DetrForSegmentation.from_pretrained(model_name),
    )


def _load_image(url_input, image_input):
    """Return the image to segment as an RGB PIL image.

    A valid URL takes precedence over an uploaded image. Raises ``gr.Error``
    (shown to the user) when neither input is usable.
    """
    if validators.url(url_input):
        response = requests.get(url_input, timeout=30)
        response.raise_for_status()
        # Decode from the full body rather than ``response.raw`` so that
        # gzip/deflate transfer encodings are handled by requests.
        return Image.open(io.BytesIO(response.content)).convert("RGB")
    if image_input is not None:
        # Uploads may be RGBA/palette images; the models expect 3 channels.
        return image_input.convert("RGB")
    raise gr.Error("Please enter a valid image URL or upload an image.")


def _colorize_label_map(label_map):
    """Turn an integer (H, W) label map into an (H, W, 3) uint8 image, one colour per label."""
    ids = numpy.asarray(label_map)
    # Compact the labels first so arbitrarily large or negative ids still index a small palette.
    unique_ids, inverse = numpy.unique(ids, return_inverse=True)
    # Fixed seed: the same set of segments is coloured identically on every call.
    palette = numpy.random.default_rng(0).integers(
        0, 256, size=(len(unique_ids), 3), dtype=numpy.uint8
    )
    return palette[inverse.reshape(ids.shape)]


def _format_report(entries, threshold):
    """Build the text report from ``(label, score, location)`` entries.

    Entries are sorted by descending score and split into the sections
    "ABOVE THRESHOLD OR EQUAL" (score >= threshold) and "BELOW THRESHOLD".
    """
    above, below = [], []
    for label, score, location in sorted(entries, key=lambda entry: entry[1], reverse=True):
        line = f"Detected `{label}` with confidence `{round(score, 3)}` at location `{location}`\n"
        (above if score >= threshold else below).append(line)
    # https://docs.python.org/3/library/string.html#format-examples
    return (
        "{:*^50}\n".format("ABOVE THRESHOLD OR EQUAL")
        + "".join(above)
        + "\n{:*^50}\n".format("BELOW THRESHOLD")
        + "".join(below)
    )


def _segment_box(segmentation, segment_id):
    """Return ``[x_min, y_min, x_max, y_max]`` of the pixels equal to ``segment_id`` ([] if none)."""
    ys, xs = numpy.nonzero(segmentation == segment_id)
    if xs.size == 0:
        return []
    return [int(xs.min()), int(ys.min()), int(xs.max()), int(ys.max())]


def _run_maskformer(model_name, image, threshold):
    """Panoptic segmentation with MaskFormer; returns (coloured map, report)."""
    # NOTE: Following the example on https://huggingface.co/facebook/maskformer-swin-large-coco
    processor, model = _load_maskformer(model_name)
    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    # PIL's size is (width, height); target_sizes expects (height, width).
    result = processor.post_process_panoptic_segmentation(
        outputs, threshold=threshold, target_sizes=[image.size[::-1]]
    )[0]
    segmentation = result["segmentation"].cpu().numpy()
    id2label = model.config.id2label
    entries = [
        (
            id2label.get(info["label_id"], str(info["label_id"])),
            float(info["score"]),
            _segment_box(segmentation, info["id"]),
        )
        for info in result["segments_info"]
    ]
    return _colorize_label_map(segmentation), _format_report(entries, threshold)


def _run_detr(model_name, image, threshold):
    """Panoptic segmentation with DETR; returns (coloured map, report)."""
    # NOTE: Using the example on https://huggingface.co/facebook/detr-resnet-50-panoptic
    feature_extractor, model = _load_detr(model_name)
    inputs = feature_extractor(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    processed_sizes = torch.as_tensor(inputs["pixel_values"].shape[-2:]).unsqueeze(0)
    panoptic = feature_extractor.post_process_panoptic(
        outputs, processed_sizes, threshold=threshold
    )[0]
    # The segmentation is stored in a special-format PNG; rgb_to_id recovers the segment ids.
    panoptic_seg = numpy.array(Image.open(io.BytesIO(panoptic["png_string"])), dtype=numpy.uint8)
    panoptic_seg_id = rgb_to_id(panoptic_seg)
    # threshold=0.0 keeps every query so the report can also list those below the slider value.
    detections = feature_extractor.post_process_object_detection(
        outputs, threshold=0.0, target_sizes=[image.size[::-1]]
    )[0]
    id2label = model.config.id2label
    entries = [
        (
            id2label.get(label.item(), str(label.item())),
            score.item(),
            [round(coord, 2) for coord in box.tolist()],
        )
        for score, label, box in zip(detections["scores"], detections["labels"], detections["boxes"])
    ]
    return _colorize_label_map(panoptic_seg_id), _format_report(entries, threshold)


def detect_objects(model_name, url_input, image_input, threshold):
    """Segment an image with the selected model.

    Args:
        model_name: Hugging Face model id; must contain "maskformer" or "detr".
        url_input: Image URL; used when it is a valid URL.
        image_input: Uploaded PIL image; used when ``url_input`` is not a valid URL.
        threshold: Confidence threshold from the UI slider.

    Returns:
        ``(segmentation_image, report_text)`` where the image is an (H, W, 3)
        uint8 array with one colour per segment.

    Raises:
        NameError: ``model_name`` is not a supported model family.
        gr.Error: no model selected, or no usable image input.
    """
    if not model_name:
        raise gr.Error("Please select a model first.")
    # Resolve the model family before touching the network so bad names fail fast.
    if "maskformer" in model_name:
        runner = _run_maskformer
    elif "detr" in model_name:
        runner = _run_detr
    else:
        raise NameError(f"Model name {model_name} not prepared")
    image = _load_image(url_input, image_input)
    return runner(model_name, image, threshold)
def set_example_image(example: list) -> dict:
    """Load the clicked example's image path into the upload component.

    ``example`` is the sample row from the examples Dataset; its first
    element is the image path.
    """
    selected_path = example[0]
    return gr.Image.update(value=selected_path)
def set_example_url(example: list) -> dict:
    """Copy the clicked example's URL into the URL textbox.

    ``example`` is the sample row from the URL examples Dataset; its first
    element is the URL string.
    """
    chosen_url = example[0]
    return gr.Textbox.update(value=chosen_url)
# Model ids offered in the dropdown; the description links are generated from
# this list so the two can never drift apart.
models = ["facebook/detr-resnet-50-panoptic","facebook/detr-resnet-101-panoptic","facebook/maskformer-swin-large-coco"]
# Heading rendered at the top of the page (id="title" is targeted by `css`).
title = """<h1 id="title">Image Segmentation App with DETR and MaskFormer</h1>"""
description = "\nLinks to HuggingFace Models:\n" + "".join(
    f"- [{model}](https://huggingface.co/{model})\n" for model in models
)
# Example image URLs shown under the "Image URL" tab.
urls = ["https://c8.alamy.com/comp/J2AB4K/the-new-york-stock-exchange-on-the-wall-street-in-new-york-J2AB4K.jpg"]
css = '''
h1#title {
text-align: center;
}
'''
# Top-level Gradio Blocks container for the whole app; `css` is injected into
# the page (it styles the element with id "title").
demo = gr.Blocks(css=css)
def changing():
    """Enable both "Detect" buttons once a model has been chosen in the dropdown.

    Returns two independent button updates, one per output component wired
    to this callback.
    """
    # https://discuss.huggingface.co/t/how-to-programmatically-enable-or-disable-components/52350/4
    first_button = gr.Button.update(interactive=True)
    second_button = gr.Button.update(interactive=True)
    return first_button, second_button
with demo:
    gr.Markdown(title)
    gr.Markdown(description)
    options = gr.Dropdown(choices=models, label='Select Image Segmentation Model', show_label=True)
    slider_input = gr.Slider(minimum=0.2, maximum=1, value=0.7, label='Prediction Threshold')

    with gr.Tabs():
        with gr.TabItem('Image URL'):
            with gr.Row():
                url_input = gr.Textbox(lines=2, label='Enter valid image URL here..')
                img_output_from_url = gr.Image(shape=(650, 650))
            with gr.Row():
                example_url = gr.Dataset(components=[url_input], samples=[[str(url)] for url in urls])
            url_but = gr.Button('Detect', interactive=False)

        with gr.TabItem('Image Upload'):
            with gr.Row():
                img_input = gr.Image(type='pil')
                img_output_from_upload = gr.Image(shape=(650, 650))
            with gr.Row():
                # Match image extensions case-insensitively (.JPG, .jpg, .jpeg, .png);
                # pathlib's own case_sensitive glob flag only exists on Python 3.12+.
                example_paths = sorted(
                    path
                    for path in pathlib.Path('images').rglob('*')
                    if path.is_file() and path.suffix.lower() in {'.jpg', '.jpeg', '.png'}
                )
                example_images = gr.Dataset(
                    components=[img_input],
                    samples=[[path.as_posix()] for path in example_paths],
                )
            img_but = gr.Button('Detect', interactive=False)

    output_text1 = gr.components.Textbox(label="Confidence Values")

    # Both Detect buttons start disabled and are enabled once a model is picked.
    options.change(fn=changing, inputs=[], outputs=[img_but, url_but])

    # Each tab's button forwards only its own input. Previously both sent the URL
    # and the upload, and the URL always won, so the Upload tab could silently
    # process a URL typed into the other tab.
    url_but.click(
        lambda model_name, url, threshold: detect_objects(model_name, url, None, threshold),
        inputs=[options, url_input, slider_input],
        outputs=[img_output_from_url, output_text1],
        queue=True,
    )
    img_but.click(
        lambda model_name, image, threshold: detect_objects(model_name, "", image, threshold),
        inputs=[options, img_input, slider_input],
        outputs=[img_output_from_upload, output_text1],
        queue=True,
    )
    example_images.click(fn=set_example_image, inputs=[example_images], outputs=[img_input])
    example_url.click(fn=set_example_url, inputs=[example_url], outputs=[url_input])

# `queue=True` on the events above only takes effect once the queue is enabled;
# model inference is slow, so queue requests instead of risking request timeouts.
demo.queue()
demo.launch()  # removed (share=True)