Spaces:
Runtime error
Runtime error
File size: 5,149 Bytes
aa1f5e1 019307d aa1f5e1 19d6989 019307d aa1f5e1 19d6989 019307d aa1f5e1 019307d aa1f5e1 0243979 aa1f5e1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 |
### 1. Imports and device setup ###
import gradio as gr
import os
import torch
import PIL
from matplotlib import pyplot as plt
from timeit import default_timer as timer
from typing import Tuple, Dict
from models import get_detr, get_maskformer
# Pick the compute device: CUDA when PyTorch reports it as available, CPU otherwise
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
### 2. Model and transforms preparation ###
# Registry of model loaders, keyed by the model name offered in the UI dropdown.
# Each loader is called with no arguments and yields a (model, processor) pair.
model_name_to_fn = dict(
    detr=get_detr,
    maskformer=get_maskformer,
)
### 3. Predict function ###
def fig2img(fig):
    """Convert a Matplotlib figure to a PIL Image and return it.

    Args:
        fig: Object exposing ``savefig(file_like, format=...)``; in this file
            it is the Matplotlib figure holding the segmentation plot.

    Returns:
        PIL.Image.Image: The rendered figure, decoded from an in-memory PNG.
    """
    import io

    # The file only does ``import PIL``, which does not bind the bare name
    # ``Image``; import it here so the lookup below cannot raise NameError.
    from PIL import Image

    buf = io.BytesIO()
    fig.savefig(buf, format="png")
    buf.seek(0)
    img = Image.open(buf)
    # Image.open is lazy; decode now so the result is complete on return.
    img.load()
    return img
# Create predict function
def predict(image, model_name: str = "detr",) -> Tuple[Dict, float]:
    """
    Desc: Runs panoptic segmentation on image and returns the rendered
        segmentation map together with the time taken.
    Args:
        image (PIL.Image): Image to perform prediction on. Its height and
            width are used as the target size of the segmentation map.
        model_name (str): Name of the model to use for prediction; must be a
            key of ``model_name_to_fn``.
    Returns:
        Tuple[Image, float]: The segmentation map rendered as a PIL image and
            the elapsed time in seconds, rounded to 5 decimals (model creation
            is included in the timing). NOTE: the ``Tuple[Dict, float]``
            annotation is kept only for interface compatibility; the first
            element is an image, not a dict.
    Raises:
        KeyError: If ``model_name`` is not a key of ``model_name_to_fn``.
    """
    # Start the timer
    start_time = timer()

    # Get the model function based on the model name
    model_fn = model_name_to_fn[model_name]

    # Create the model and its processor, then move the model to the device.
    # NOTE(review): the model is re-created on every call; consider caching the
    # (model, processor) pair per model name if latency matters.
    model, processor = model_fn()
    model = model.to(device)
    model.eval()

    # Both registered models share the same processor API, so one code path
    # serves them. Inputs must live on the same device as the model.
    inputs = processor(images=image, return_tensors="pt").to(device)

    # Forward pass without autograd bookkeeping
    with torch.no_grad():
        outputs = model(**inputs)
    print("Output Generated!")

    # Segmentation results are returned as a list of dictionaries, one per image
    result = processor.post_process_panoptic_segmentation(
        outputs,
        target_sizes=[(image.height, image.width)],
    )
    print("Output Post Processing Done!")

    # Tensor of shape (height, width) where each value denotes a segment id
    panoptic_seg = result[0]["segmentation"]

    # Render the label map as a colour image. ``plt.plot`` draws lines and
    # rejects ``cmap``; ``imshow`` is the call that maps a 2-D array to colours.
    # A dedicated figure is used and always closed so that repeated requests
    # neither leak figures nor draw on top of a previous result.
    fig, ax = plt.subplots()
    try:
        ax.imshow(panoptic_seg.cpu().numpy(), cmap="viridis")
        ax.axis("off")  # ticks and frame carry no meaning for a mask
        output = fig2img(fig)
    finally:
        plt.close(fig)

    # Calculate the prediction time
    pred_time = round(timer() - start_time, 5)

    # Return the rendered segmentation and prediction time
    print("Returning Results!")
    return output, pred_time
### 4. Gradio app ###
# Create title, description and article strings
title = "Segmentation Demo"
description = "An Mutimodel Segmentation Demo"
article = ""

# Create examples list from "examples/" directory.
# Each entry is a one-element list holding the path of one example file.
# A missing directory yields an empty list instead of crashing at import time;
# entries are sorted for a stable order and sub-directories are skipped.
examples_dir = "examples"
if os.path.isdir(examples_dir):
    example_list = [
        [examples_dir + "/" + name]
        for name in sorted(os.listdir(examples_dir))
        if os.path.isfile(os.path.join(examples_dir, name))
    ]
else:
    example_list = []
# Build the Gradio UI: a dropdown listing the registered model names ...
model_selection_dropdown = gr.components.Dropdown(
    label="Select a model",
    choices=[*model_name_to_fn],
    value="detr",
)

# ... and the interface wiring `predict` to its two inputs and two outputs
demo = gr.Interface(
    title=title,
    description=description,
    article=article,
    fn=predict,  # maps (image, model name) -> (mask image, elapsed seconds)
    inputs=[gr.Image(type="pil"), model_selection_dropdown],
    outputs=[
        gr.Image(label="Mask"),
        gr.Number(label="Prediction time (s)"),
    ],
    examples=example_list,  # example inputs collected from "examples/"
)
# Launch the demo!
# Guarded so that importing this module (e.g. from tests or tooling) does not
# start a blocking web server; running the file as a script still launches it.
if __name__ == "__main__":
    # Optional overrides: debug=True, server_port=7860, server_name="0.0.0.0"
    demo.launch()