d22cs051's picture
Update app.py
19d6989
### 1. Imports and class names setup ###
import gradio as gr
import os
import torch
import PIL
from matplotlib import pyplot as plt
from timeit import default_timer as timer
from typing import Tuple, Dict
from models import get_detr, get_maskformer
# Set device: prefer the GPU when CUDA is available, otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
### 2. Model and transforms preparation ###
# Map each selectable model name to the factory that builds its (model, processor) pair.
# The keys double as the choices offered in the UI's model dropdown.
model_name_to_fn = dict(
    detr=get_detr,
    maskformer=get_maskformer,
)
### 3. Predict function ###
def fig2img(fig):
    """Render a Matplotlib figure to an in-memory PNG and return it as a PIL image.

    Args:
        fig (matplotlib.figure.Figure): Figure to render. It is not closed here;
            the caller keeps ownership of it.

    Returns:
        PIL.Image.Image: The rendered figure, fully decoded into memory.
    """
    import io
    # `import PIL` at module level does not bind `Image` (nor reliably load the
    # `PIL.Image` submodule), so import it explicitly here.
    from PIL import Image

    buf = io.BytesIO()
    fig.savefig(buf, format="png")
    buf.seek(0)
    img = Image.open(buf)
    # Image.open is lazy; decode now so the returned image does not depend on `buf`.
    img.load()
    return img
# Cache of (model, processor) pairs keyed by model name, filled lazily by _load_model.
_model_cache = {}


def _load_model(model_name: str):
    """Return the (model, processor) pair for `model_name`, building it on first use.

    Each model is built once per process, moved to `device`, and switched to
    eval mode, so later requests skip the expensive construction step.

    Raises:
        KeyError: If `model_name` is not a key of `model_name_to_fn`.
    """
    if model_name not in _model_cache:
        model, processor = model_name_to_fn[model_name]()
        model = model.to(device)
        model.eval()
        _model_cache[model_name] = (model, processor)
    return _model_cache[model_name]


# Create predict function
def predict(image, model_name: str = "detr",) -> Tuple["PIL.Image.Image", float]:
    """
    Desc: Runs panoptic segmentation on `image` and returns a colour-mapped mask and the time taken.
    Args:
        image (PIL.Image.Image): Image to segment. Gradio passes None when no image was uploaded.
        model_name (str): Key of `model_name_to_fn` that selects the model ("detr" or "maskformer").
    Returns:
        Tuple[PIL.Image.Image, float]: The segmentation map rendered with the "viridis" colormap
        at the input's resolution, and the prediction time in seconds (rounded to 5 decimals).
    Raises:
        ValueError: If no image is given.
        KeyError: If `model_name` is not a known model.
    """
    import io
    # `import PIL` at module level does not bind `Image`, so import it explicitly.
    from PIL import Image

    if image is None:
        raise ValueError("Please upload an image before running the prediction.")

    # Start the timer
    start_time = timer()
    model, processor = _load_model(model_name)

    # DETR and MaskFormer share the same processor API, so one code path serves both.
    # The inputs must live on the same device as the model.
    inputs = processor(images=image, return_tensors="pt").to(device)
    with torch.inference_mode():
        outputs = model(**inputs)
    print("Output Generated!")

    # Returns one dict per image; target size is (height, width).
    result = processor.post_process_panoptic_segmentation(
        outputs, target_sizes=[(image.height, image.width)]
    )
    print("Output Post Processing Done!")
    # (height, width) tensor of segment ids, -1 where no segment was found; may be on GPU.
    panoptic_seg = result[0]["segmentation"].cpu().numpy()

    # plt.imsave maps the values through the colormap at native resolution and,
    # unlike plt.plot/plt.gcf, does not draw onto pyplot's shared "current figure"
    # (which would accumulate plots across requests).
    buf = io.BytesIO()
    plt.imsave(buf, panoptic_seg, cmap="viridis", format="png")
    buf.seek(0)
    # convert() forces decoding (Image.open is lazy) and drops the alpha channel.
    output = Image.open(buf).convert("RGB")

    # Calculate the prediction time
    pred_time = round(timer() - start_time, 5)
    print("Returning Results!")
    return output, pred_time
### 4. Gradio app ###
# Create title, description and article strings
title = "Segmentation Demo"
description = "A Multi-model Segmentation Demo"
article = ""

# Build the examples list from the "examples/" directory. The interface has two inputs
# (image, model), so each example pairs an image path with the default model.
# Non-image files (e.g. .gitattributes) are skipped, and a missing directory
# simply yields no examples instead of crashing at startup.
_EXAMPLE_DIR = "examples"
_IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".bmp", ".gif", ".webp"}
example_list = (
    [
        [os.path.join(_EXAMPLE_DIR, name), "detr"]
        for name in sorted(os.listdir(_EXAMPLE_DIR))
        if os.path.splitext(name)[1].lower() in _IMAGE_EXTENSIONS
    ]
    if os.path.isdir(_EXAMPLE_DIR)
    else []
)

# Create the Gradio demo
model_selection_dropdown = gr.components.Dropdown(
    choices=list(model_name_to_fn.keys()),
    label="Select a model",
    value="detr",
)
demo = gr.Interface(
    fn=predict,  # mapping function from (image, model name) to outputs
    inputs=[gr.Image(type="pil"), model_selection_dropdown],
    outputs=[
        gr.Image(label="Mask"),
        gr.Number(label="Prediction time (s)"),
    ],  # predict returns two values, so there are two output components
    examples=example_list or None,
    title=title,
    description=description,
    article=article,
)

# Launch the demo only when run as a script, so importing this module has no side effects.
if __name__ == "__main__":
    demo.launch(
        # debug=True,
        # server_port=7860,
        # server_name="0.0.0.0"
    )