### 1. Imports and class names setup ###
import io
import os
from timeit import default_timer as timer
from typing import Dict, Tuple

import gradio as gr
import PIL
import torch
from matplotlib import pyplot as plt
from PIL import Image

from models import get_detr, get_maskformer

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

### 2. Model and transforms preparation ###

# Map each selectable model name to the factory that builds it.
# Each factory is unpacked in `predict` as a (model, processor) pair.
model_name_to_fn = {
    "detr": get_detr,
    "maskformer": get_maskformer,
}

### 3. Predict function ###


def fig2img(fig):
    """Convert a Matplotlib figure to a PIL Image and return it.

    Args:
        fig: Matplotlib figure to render.

    Returns:
        PIL.Image.Image: The figure rendered as a PNG and decoded into memory.
    """
    buf = io.BytesIO()
    fig.savefig(buf, format="png")
    buf.seek(0)
    img = Image.open(buf)
    # Image.open is lazy; decode now so the image no longer depends on `buf`.
    img.load()
    return img


# Create predict function
def predict(image, model_name: str = "detr") -> Tuple[Image.Image, float]:
    """Run panoptic segmentation on an image and time the whole call.

    Args:
        image (PIL.Image): Image to perform prediction on.
        model_name (str): Key into ``model_name_to_fn`` selecting the model.

    Returns:
        Tuple[Image, float]: The segmentation map rendered as an image, and
        the elapsed time in seconds (rounded to 5 decimals). The model is
        re-created on every call, so that time includes model creation.

    Raises:
        KeyError: If ``model_name`` is not a key of ``model_name_to_fn``.
    """
    # Start the timer
    start_time = timer()

    # Get the model function based on the model name
    model_fn = model_name_to_fn[model_name]

    # Create the model and its processor, then move the model to the device
    model, processor = model_fn()
    model = model.to(device)

    # Put model into evaluation mode
    model.eval()

    # Inputs must live on the same device as the model for every model type.
    inputs = processor(images=image, return_tensors="pt")
    inputs = inputs.to(device)

    # No gradients are needed for inference; skipping autograd saves memory.
    with torch.no_grad():
        outputs = model(**inputs)
        print("Output Generated!")

        # Post-processed panoptic segmentation maps are returned as a list of
        # dictionaries, one per image; target size is (height, width).
        result = processor.post_process_panoptic_segmentation(
            outputs, target_sizes=[(image.height, image.width)]
        )[0]
    print("Output Post Processing Done!")

    # Tensor of segment ids, one value per pixel
    panoptic_seg = result["segmentation"]

    # Render on a fresh figure and always close it, so repeated calls neither
    # draw on top of earlier results nor leak figures.
    fig, ax = plt.subplots()
    try:
        # imshow (not plot) renders a 2-D map with a colormap; the tensor is
        # moved to CPU first because Matplotlib cannot read CUDA tensors.
        ax.imshow(panoptic_seg.cpu().numpy(), cmap="viridis")
        output = fig2img(fig)
    finally:
        plt.close(fig)

    # Calculate the prediction time
    pred_time = round(timer() - start_time, 5)

    # Return the rendered segmentation and prediction time
    print("Returning Results!")
    return output, pred_time


### 4. Gradio app ###

# Create title, description and article strings
title = "Segmentation Demo"
description = "An Mutimodel Segmentation Demo"
article = ""

# Create examples list from "examples/" directory
example_list = [["examples/" + example] for example in os.listdir("examples")]

# Create the Gradio demo
model_selection_dropdown = gr.components.Dropdown(
    choices=list(model_name_to_fn.keys()),
    label="Select a model",
    value="detr",
)

demo = gr.Interface(
    fn=predict,  # mapping function from input to output
    inputs=[gr.Image(type="pil"), model_selection_dropdown],  # what are the inputs?
    outputs=[
        gr.Image(label="Mask"),  # what are the outputs?
        gr.Number(label="Prediction time (s)"),
    ],  # our fn has two outputs, therefore we have two outputs
    # Create examples list from "examples/" directory
    examples=example_list,
    title=title,
    description=description,
    article=article,
)

# Launch the demo!
demo.launch(
    # debug=True,
    # server_port=7860,
    # server_name="0.0.0.0"
)