import io
import os
import pathlib

import gradio as gr
import matplotlib.pyplot as plt
import numpy  # used below when decoding the DETR panoptic segmentation PNG
import requests
import torch
import validators
from PIL import Image
from transformers import (
    DetrFeatureExtractor,
    DetrForSegmentation,
    MaskFormerImageProcessor,
    MaskFormerForInstanceSegmentation,
)
from transformers.models.detr.feature_extraction_detr import rgb_to_id


def detect_objects(model_name, url_input, image_input, threshold):
    if 'maskformer' in model_name:
        if validators.url(url_input):
            image = Image.open(requests.get(url_input, stream=True).raw)
            tb_label = "Confidence Values URL"
        elif image_input:
            image = image_input
            tb_label = "Confidence Values Upload"

        # NOTE: Pulling from the example on https://huggingface.co/facebook/maskformer-swin-large-coco
        # and https://huggingface.co/spaces/ajcdp/Image-Segmentation-Gradio/blob/main/app.py
        processor = MaskFormerImageProcessor.from_pretrained(model_name)
        model = MaskFormerForInstanceSegmentation.from_pretrained(model_name)

        # PIL images expose (width, height) via .size; post-processing expects (height, width)
        target_size = (image.size[1], image.size[0])
        inputs = processor(images=image, return_tensors="pt")
        with torch.no_grad():
            outputs = model(**inputs)
        outputs.class_queries_logits = outputs.class_queries_logits.cpu()
        outputs.masks_queries_logits = outputs.masks_queries_logits.cpu()

        results = processor.post_process_segmentation(outputs=outputs, target_size=target_size)[0].cpu().detach()
        results = torch.argmax(results, dim=0).numpy()
        results = visualize_instance_seg_mask(results)
        return results, "EMPTY"

        # for result in results:
        #     boxes = result.boxes.cpu().numpy()
        #     for i, box in enumerate(boxes):
        #         # r = box.xyxy[0].astype(int)
        #         coordinates = box.xyxy[0].astype(int)
        #         try:
        #             label = YOLOV8_LABELS[int(box.cls)]
        #         except:
        #             label = "ERROR"
        #         try:
        #             confi = float(box.conf)
        #         except:
        #             confi = 0.0
        #         # final_str_abv += str() + "__" + str(box.cls) + "__" + str(box.conf) + "__" + str(box) + "\n"
        #         if confi >= threshold:
        #             final_str_abv += f"Detected `{label}` with confidence `{confi}` at location `{coordinates}`\n"
        #         else:
        #             final_str_else += f"Detected `{label}` with confidence `{confi}` at location `{coordinates}`\n"
        # final_str = "{:*^50}\n".format("ABOVE THRESHOLD OR EQUAL") + final_str_abv + "\n{:*^50}\n".format("BELOW THRESHOLD") + final_str_else
        # return render, final_str
    elif "detr" in model_name:
        # NOTE: Using the example on https://huggingface.co/facebook/detr-resnet-50-panoptic
        if validators.url(url_input):
            image = Image.open(requests.get(url_input, stream=True).raw)
            tb_label = "Confidence Values URL"
        elif image_input:
            image = image_input
            tb_label = "Confidence Values Upload"

        feature_extractor = DetrFeatureExtractor.from_pretrained(model_name)
        model = DetrForSegmentation.from_pretrained(model_name)

        inputs = feature_extractor(images=image, return_tensors="pt")
        outputs = model(**inputs)

        # use the `post_process_panoptic` method of `DetrFeatureExtractor` to convert to COCO format
        processed_sizes = torch.as_tensor(inputs["pixel_values"].shape[-2:]).unsqueeze(0)
        result = feature_extractor.post_process_panoptic(outputs, processed_sizes)[0]

        # the segmentation is stored in a special-format png
        panoptic_seg = Image.open(io.BytesIO(result["png_string"]))
        panoptic_seg = numpy.array(panoptic_seg, dtype=numpy.uint8)
        # retrieve the ids corresponding to each mask
        panoptic_seg_id = rgb_to_id(panoptic_seg)

        return gr.Image.update(), "EMPTY"

        # NOTE: The code below is unreachable until the early return above is removed;
        # it still expects `processed_outputs` and `visualize_prediction`, which are not
        # defined in this branch.

        # Visualize prediction
        viz_img = visualize_prediction(image, processed_outputs, threshold, model.config.id2label)
        # return [viz_img, processed_outputs]
        # print(type(viz_img))

        final_str_abv = ""
        final_str_else = ""
        for score, label, box in sorted(
            zip(processed_outputs["scores"], processed_outputs["labels"], processed_outputs["boxes"]),
            key=lambda x: x[0].item(),
            reverse=True,
        ):
            box = [round(i, 2) for i in box.tolist()]
            if score.item() >= threshold:
                final_str_abv += f"Detected `{model.config.id2label[label.item()]}` with confidence `{round(score.item(), 3)}` at location `{box}`\n"
            else:
                final_str_else += f"Detected `{model.config.id2label[label.item()]}` with confidence `{round(score.item(), 3)}` at location `{box}`\n"

        # https://docs.python.org/3/library/string.html#format-examples
        final_str = "{:*^50}\n".format("ABOVE THRESHOLD OR EQUAL") + final_str_abv + "\n{:*^50}\n".format("BELOW THRESHOLD") + final_str_else
        return viz_img, final_str
    else:
        raise NameError(f"Model name {model_name} not prepared")


def set_example_image(example: list) -> dict:
    return gr.Image.update(value=example[0])


def set_example_url(example: list) -> dict:
    return gr.Textbox.update(value=example[0])


title = """