import io
import os
import pathlib

import gradio as gr
import matplotlib.pyplot as plt
import numpy  # used by the DETR panoptic branch and the mask visualizer
import requests
import torch
import validators
from PIL import Image
from transformers import (
    DetrFeatureExtractor,
    DetrForSegmentation,
    MaskFormerForInstanceSegmentation,
    MaskFormerImageProcessor,
)
from transformers.models.detr.feature_extraction_detr import rgb_to_id
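
# NOTE: `visualize_instance_seg_mask` is called in detect_objects but is defined in
# the ajcdp Space referenced there, not in this file. The version below is a minimal
# sketch, assuming `mask` is a 2-D numpy array of integer segment ids: each id is
# mapped to a random RGB color and returned as a uint8 image that gr.Image can show.
def visualize_instance_seg_mask(mask):
    # Fixed seed so the same id gets the same color across calls
    rng = numpy.random.default_rng(0)
    color_map = {seg_id: rng.integers(0, 256, size=3, dtype=numpy.uint8)
                 for seg_id in numpy.unique(mask)}
    colored = numpy.zeros((*mask.shape, 3), dtype=numpy.uint8)
    for seg_id, color in color_map.items():
        colored[mask == seg_id] = color
    return colored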

def detect_objects(model_name, url_input, image_input, threshold):
    if 'maskformer' in model_name:
        if validators.url(url_input):
            image = Image.open(requests.get(url_input, stream=True).raw)
            tb_label = "Confidence Values URL"
        elif image_input:
            image = image_input
            tb_label = "Confidence Values Upload"

        # NOTE: Pulling from the example on https://huggingface.co/facebook/maskformer-swin-large-coco
        # and https://huggingface.co/spaces/ajcdp/Image-Segmentation-Gradio/blob/main/app.py
        processor = MaskFormerImageProcessor.from_pretrained(model_name)
        model = MaskFormerForInstanceSegmentation.from_pretrained(model_name)

        # PIL's Image.size is (width, height); post-processing expects (height, width)
        target_size = image.size[::-1]
        inputs = processor(images=image, return_tensors="pt")
        with torch.no_grad():
            outputs = model(**inputs)
        outputs.class_queries_logits = outputs.class_queries_logits.cpu()
        outputs.masks_queries_logits = outputs.masks_queries_logits.cpu()
        results = processor.post_process_segmentation(outputs=outputs, target_size=target_size)[0].cpu().detach()
        # Collapse the per-segment scores to one id per pixel, then color the id map
        results = torch.argmax(results, dim=0).numpy()
        results = visualize_instance_seg_mask(results)
        return results, "EMPTY"

    elif "detr" in model_name:
        # NOTE: Using the example on https://huggingface.co/facebook/detr-resnet-50-panoptic
        if validators.url(url_input):
            image = Image.open(requests.get(url_input, stream=True).raw)
            tb_label = "Confidence Values URL"
        elif image_input:
            image = image_input
            tb_label = "Confidence Values Upload"

        feature_extractor = DetrFeatureExtractor.from_pretrained(model_name)
        model = DetrForSegmentation.from_pretrained(model_name)

        inputs = feature_extractor(images=image, return_tensors="pt")
        outputs = model(**inputs)

        # Use the `post_process_panoptic` method of `DetrFeatureExtractor` to convert to COCO format
        processed_sizes = torch.as_tensor(inputs["pixel_values"].shape[-2:]).unsqueeze(0)
        result = feature_extractor.post_process_panoptic(outputs, processed_sizes)[0]

        # The segmentation is stored in a special-format PNG
        panoptic_seg = Image.open(io.BytesIO(result["png_string"]))
        panoptic_seg = numpy.array(panoptic_seg, dtype=numpy.uint8)
        # Retrieve the ids corresponding to each mask
        panoptic_seg_id = rgb_to_id(panoptic_seg)

        # Render the panoptic id map with the same visualizer used for MaskFormer
        return visualize_instance_seg_mask(panoptic_seg_id), "EMPTY"

        # TODO: per-box confidence output; unreachable until `visualize_prediction`
        # and `processed_outputs` exist for this branch.
        # viz_img = visualize_prediction(image, processed_outputs, threshold, model.config.id2label)
        # final_str_abv = ""
        # final_str_else = ""
        # for score, label, box in sorted(
        #         zip(processed_outputs["scores"], processed_outputs["labels"], processed_outputs["boxes"]),
        #         key=lambda x: x[0].item(), reverse=True):
        #     box = [round(i, 2) for i in box.tolist()]
        #     if score.item() >= threshold:
        #         final_str_abv += f"Detected `{model.config.id2label[label.item()]}` with confidence `{round(score.item(), 3)}` at location `{box}`\n"
        #     else:
        #         final_str_else += f"Detected `{model.config.id2label[label.item()]}` with confidence `{round(score.item(), 3)}` at location `{box}`\n"
        # # https://docs.python.org/3/library/string.html#format-examples
        # final_str = "{:*^50}\n".format("ABOVE THRESHOLD OR EQUAL") + final_str_abv + "\n{:*^50}\n".format("BELOW THRESHOLD") + final_str_else
        # return viz_img, final_str

    else:
        raise ValueError(f"Model name {model_name} not supported")


def set_example_image(example: list) -> dict:
    return gr.Image.update(value=example[0])


def set_example_url(example: list) -> dict:
    return gr.Textbox.update(value=example[0])


title = """

<h1 id="title">Image Segmentation with Various Models</h1>

""" description = """ Links to HuggingFace Models: - [facebook/detr-resnet-50-panoptic](https://huggingface.co/facebook/detr-resnet-50-panoptic) - [facebook/detr-resnet-101-panoptic](https://huggingface.co/facebook/detr-resnet-101-panoptic) - [facebook/maskformer-swin-large-coco](https://huggingface.co/facebook/maskformer-swin-large-coco) """ models = ["facebook/detr-resnet-50-panoptic","facebook/detr-resnet-101-panoptic","facebook/maskformer-swin-large-coco"] urls = ["https://c8.alamy.com/comp/J2AB4K/the-new-york-stock-exchange-on-the-wall-street-in-new-york-J2AB4K.jpg"] # twitter_link = """ # [![](https://img.shields.io/twitter/follow/nickmuchi?label=@nickmuchi&style=social)](https://twitter.com/nickmuchi) # """ css = ''' h1#title { text-align: center; } ''' demo = gr.Blocks(css=css) def changing(): # https://discuss.huggingface.co/t/how-to-programmatically-enable-or-disable-components/52350/4 return gr.Button.update(interactive=True), gr.Button.update(interactive=True) with demo: gr.Markdown(title) gr.Markdown(description) # gr.Markdown(twitter_link) options = gr.Dropdown(choices=models,label='Select Image Segmentation Model',show_label=True) slider_input = gr.Slider(minimum=0.2,maximum=1,value=0.7,label='Prediction Threshold') with gr.Tabs(): with gr.TabItem('Image URL'): with gr.Row(): url_input = gr.Textbox(lines=2,label='Enter valid image URL here..') img_output_from_url = gr.Image(shape=(650,650)) with gr.Row(): example_url = gr.Dataset(components=[url_input],samples=[[str(url)] for url in urls]) url_but = gr.Button('Detect', interactive=False) with gr.TabItem('Image Upload'): with gr.Row(): img_input = gr.Image(type='pil') img_output_from_upload= gr.Image(shape=(650,650)) with gr.Row(): example_images = gr.Dataset(components=[img_input], samples=[[path.as_posix()] for path in sorted(pathlib.Path('images').rglob('*.JPG'))]) # Can't get case_sensitive to work img_but = gr.Button('Detect', interactive=False) # output_text1 = gr.outputs.Textbox(label="Confidence Values") output_text1 = gr.components.Textbox(label="Confidence Values") # https://huggingface.co/spaces/vishnun/CLIPnCROP/blob/main/app.py -- Got .outputs. from this options.change(fn=changing, inputs=[], outputs=[img_but, url_but]) url_but.click(detect_objects,inputs=[options,url_input,img_input,slider_input],outputs=[img_output_from_url, output_text1],queue=True) img_but.click(detect_objects,inputs=[options,url_input,img_input,slider_input],outputs=[img_output_from_upload, output_text1],queue=True) # url_but.click(detect_objects,inputs=[options,url_input,img_input,slider_input],outputs=[img_output_from_url, _],queue=True) # img_but.click(detect_objects,inputs=[options,url_input,img_input,slider_input],outputs=[img_output_from_upload, _],queue=True) # url_but.click(detect_objects,inputs=[options,url_input,img_input,slider_input],outputs=img_output_from_url,queue=True) # img_but.click(detect_objects,inputs=[options,url_input,img_input,slider_input],outputs=img_output_from_upload,queue=True) example_images.click(fn=set_example_image,inputs=[example_images],outputs=[img_input]) example_url.click(fn=set_example_url,inputs=[example_url],outputs=[url_input]) # gr.Markdown("![visitor badge](https://visitor-badge.glitch.me/badge?page_id=nickmuchi-object-detection-with-detr-and-yolos)") # demo.launch(enable_queue=True) demo.launch() #removed (share=True)