# Spaces: Runtime error
### 1. Imports and class names setup ### | |
import gradio as gr | |
import os | |
import torch | |
import PIL | |
from matplotlib import pyplot as plt | |
from timeit import default_timer as timer | |
from typing import Tuple, Dict | |
from models import get_detr, get_maskformer | |
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

### 2. Model and transforms preparation ###
# Registry of selectable models: maps the name offered in the UI dropdown to a
# factory function. Callers unpack a factory's result as ``(model, processor)``.
model_name_to_fn = dict(
    detr=get_detr,
    maskformer=get_maskformer,
)
### 3. Predict function ### | |
def fig2img(fig):
    """Rasterize a Matplotlib figure to PNG in memory and return it as a PIL Image.

    Args:
        fig: the Matplotlib figure to render.

    Returns:
        PIL.Image.Image: the decoded PNG image, already fully loaded.
    """
    import io

    # ``import PIL`` at module level does not bind the name ``Image``, so the
    # previous bare ``Image.open`` raised NameError on every call.
    from PIL import Image

    buf = io.BytesIO()
    # Pin the format instead of depending on the ``savefig.format`` rcParam.
    fig.savefig(buf, format="png")
    buf.seek(0)
    img = Image.open(buf)
    # PIL decodes lazily; decode now so the result is self-contained.
    img.load()
    return img
# Create predict function | |
# Already-constructed ``(model, processor)`` pairs keyed by model name, so each
# checkpoint is built and moved to ``device`` once per process rather than on
# every request.
_MODEL_CACHE = {}


def _load_model(model_name: str):
    """Return the cached ``(model, processor)`` pair for ``model_name``.

    On first use the factory from ``model_name_to_fn`` is called, the model is
    moved to ``device`` and switched to evaluation mode.

    Raises:
        KeyError: if ``model_name`` is not a key of ``model_name_to_fn``.
    """
    if model_name not in _MODEL_CACHE:
        model_fn = model_name_to_fn[model_name]
        model, processor = model_fn()
        model = model.to(device)
        model.eval()
        _MODEL_CACHE[model_name] = (model, processor)
    return _MODEL_CACHE[model_name]


def _render_segmentation(segmentation):
    """Draw a 2-D segment-id map with the viridis colormap and return it as a PIL image."""
    # Use a standalone Figure instead of pyplot: pyplot's "current figure" is
    # process-global, so reusing plt.gcf() stacked every request's drawing onto
    # the same never-closed figure.
    from matplotlib.figure import Figure

    # matplotlib cannot read CUDA tensors; copy to host memory as a NumPy array.
    if isinstance(segmentation, torch.Tensor):
        segmentation = segmentation.cpu().numpy()
    fig = Figure()
    ax = fig.subplots()
    # imshow renders a 2-D array through a colormap; plt.plot does not accept cmap.
    ax.imshow(segmentation, cmap="viridis")
    ax.axis("off")
    return fig2img(fig)


def predict(image, model_name: str = "detr") -> Tuple["PIL.Image.Image", float]:
    """Run panoptic segmentation on ``image`` and return a colored mask and the time taken.

    Args:
        image (PIL.Image.Image): Image to segment.
        model_name (str): Key of ``model_name_to_fn`` ("detr" or "maskformer").
            ``None`` or an empty string falls back to "detr".

    Returns:
        Tuple[PIL.Image.Image, float]: the rendered segmentation map and the
        elapsed wall-clock seconds rounded to 5 decimals (includes model loading
        on the first call for a given model).

    Raises:
        gr.Error: if no image was provided.
        KeyError: if ``model_name`` is not a registered model.
    """
    # Start the timer
    start_time = timer()

    if image is None:
        raise gr.Error("Please upload an image first.")
    # An explicit None bypasses the parameter default, so normalize it here.
    if not model_name:
        model_name = "detr"

    model, processor = _load_model(model_name)
    # The image processors expect 3-channel input; uploads may be RGBA/palette.
    image = image.convert("RGB")

    # inference_mode disables autograd bookkeeping for the forward pass.
    with torch.inference_mode():
        # Both models take the same processor call; inputs must live on the
        # model's device (previously only the DETR branch moved them).
        inputs = processor(images=image, return_tensors="pt").to(device)
        outputs = model(**inputs)
        print("Output Generated!")
        # Returns one dict per image; its "segmentation" entry is a
        # (height, width) tensor of segment ids, filled with -1 where no
        # segment is found, resized to ``target_sizes``.
        result = processor.post_process_panoptic_segmentation(
            outputs, target_sizes=[(image.height, image.width)]
        )[0]
    print("Output Post Processing Done!")

    output = _render_segmentation(result["segmentation"])

    # Calculate the prediction time
    pred_time = round(timer() - start_time, 5)
    print("Returning Results!")
    return output, pred_time
### 4. Gradio app ###
# Create title, description and article strings
title = "Segmentation Demo"
description = "A Multi-model Segmentation Demo"
article = ""

# Model preselected in the dropdown and used for every example row.
default_model_name = "detr"

# Build the examples from the "examples/" directory. Each row supplies a value
# for BOTH inputs (image path, model name) so an example never leaves the model
# dropdown unset. os.listdir order is arbitrary, so sort for a stable gallery;
# skip hidden files and sub-directories; a missing directory simply means no
# examples instead of crashing the app at startup.
examples_dir = "examples"
example_list = []
if os.path.isdir(examples_dir):
    example_list = [
        [os.path.join(examples_dir, file_name), default_model_name]
        for file_name in sorted(os.listdir(examples_dir))
        if not file_name.startswith(".")
        and os.path.isfile(os.path.join(examples_dir, file_name))
    ]

# Create the Gradio demo
model_selection_dropdown = gr.components.Dropdown(
    choices=list(model_name_to_fn.keys()),
    label="Select a model",
    value=default_model_name,
)
demo = gr.Interface(
    fn=predict,  # maps (image, model name) -> (mask image, seconds)
    inputs=[gr.Image(type="pil"), model_selection_dropdown],
    outputs=[
        gr.Image(label="Mask"),
        gr.Number(label="Prediction time (s)"),
    ],  # predict returns two values, so there are two output components
    # None (rather than an empty list) when there are no examples to show.
    examples=example_list or None,
    title=title,
    description=description,
    article=article,
)
# Launch the demo!
demo.launch(
    # debug=True,
    # server_port=7860,
    # server_name="0.0.0.0"
)