File size: 5,149 Bytes
aa1f5e1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
019307d
 
 
 
 
 
 
 
aa1f5e1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19d6989
 
019307d
 
 
aa1f5e1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19d6989
 
019307d
 
 
aa1f5e1
 
 
 
 
 
 
 
 
 
 
 
 
019307d
 
aa1f5e1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0243979
 
 
aa1f5e1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
### 1. Imports and class names setup ###
import gradio as gr
import os
import torch
import PIL
from matplotlib import pyplot as plt

from timeit import default_timer as timer
from typing import Tuple, Dict

# Project-local factories that build (model, processor) pairs — see models.py
from models import get_detr, get_maskformer

# Set device: prefer GPU when available, otherwise fall back to CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


### 2. Model and transforms preparation ###

# Create model

# Dispatch table: UI model name -> factory returning (model, processor).
# Keys double as the choices shown in the Gradio dropdown below.
model_name_to_fn = {
    "detr": get_detr,
    "maskformer": get_maskformer,
}


### 3. Predict function ###
def fig2img(fig):
    """Convert a Matplotlib figure to a PIL Image and return it.

    Args:
        fig (matplotlib.figure.Figure): Figure to rasterize.

    Returns:
        PIL.Image.Image: The rendered figure as an image.
    """
    import io
    # Fix: the original referenced `Image` without importing it (the file only
    # does `import PIL`, which does not bind the Image submodule).
    from PIL import Image

    buf = io.BytesIO()
    # Explicit format: don't depend on matplotlib's savefig.format rcParam
    fig.savefig(buf, format="png")
    buf.seek(0)
    return Image.open(buf)


# Create predict function
def predict(image, model_name: str = "detr",) -> "Tuple[PIL.Image.Image, float]":
    """
    Run panoptic segmentation on an image and return the rendered mask.

    Args:
        image (PIL.Image): Image to perform prediction on.
        model_name (str): Name of the model to use ("detr" or "maskformer").

    Returns:
        Tuple[PIL.Image.Image, float]: The visualized panoptic segmentation
        map and the time taken (in seconds) to perform the prediction.

    Raises:
        KeyError: If `model_name` is not a key of `model_name_to_fn`.
    """
    # Start the timer
    start_time = timer()

    # Get the model factory based on the model name.
    # Fails fast with KeyError for unknown names (the original fell through
    # both branches and crashed later with an UnboundLocalError on `output`).
    model_fn = model_name_to_fn[model_name]

    # Create the model/processor pair and move the model to the target device
    model, processor = model_fn()
    model = model.to(device)

    # Put model into evaluation mode
    model.eval()

    # Preprocess and move inputs to the SAME device as the model.
    # (The original maskformer branch skipped this and would crash on CUDA.)
    inputs = processor(images=image, return_tensors="pt").to(device)

    # Disable autograd bookkeeping for inference, as the comment always promised
    with torch.inference_mode():
        outputs = model(**inputs)
    print("Output Generated!")

    # Post-process into a panoptic segmentation map sized to the input image.
    # target_sizes expects (height, width) pairs.
    result = processor.post_process_panoptic_segmentation(
        outputs, target_sizes=[(image.height, image.width)]
    )[0]
    print("Output Post Processing Done!")

    # Tensor of shape (height, width); each value is a segment id (-1 = none)
    panoptic_seg = result["segmentation"]

    # Render the 2D map. Fix: `plt.plot` draws line charts and has no `cmap`
    # parameter — a segmentation raster must be drawn with `imshow`.
    fig, ax = plt.subplots()
    ax.imshow(panoptic_seg.cpu().numpy(), cmap="viridis")
    ax.axis("off")
    output = fig2img(fig)
    # Close the figure so repeated requests don't leak pyplot figures
    plt.close(fig)

    # Calculate the prediction time
    pred_time = round(timer() - start_time, 5)

    # Return the rendered mask and prediction time
    print("Returning Results!")
    return output, pred_time


### 4. Gradio app ###

# Create title, description and article strings
title = "Segmentation Demo"
# Fix: the original description read "An Mutimodel" (typo + wrong article)
description = "A Multimodal Segmentation Demo"
article = ""

# Create examples list from "examples/" directory; each example is wrapped in
# a list because gr.Interface expects one list per input component row
example_list = [["examples/" + example] for example in os.listdir("examples")]

# Dropdown bound to the keys of the model dispatch table, so the UI choices
# always match the models `predict` can actually run
model_selection_dropdown = gr.components.Dropdown(
    choices=list(model_name_to_fn.keys()),
    label="Select a model",
    value="detr",  # default matches predict()'s default model
)

# Create the Gradio demo
demo = gr.Interface(
    fn=predict,  # mapping function from input to output
    # Input order mirrors predict(image, model_name)
    inputs=[gr.Image(type="pil"), model_selection_dropdown],
    outputs=[
        gr.Image(label="Mask"),  # rendered panoptic segmentation map
        gr.Number(label="Prediction time (s)"),
    ],  # our fn has two outputs, therefore we have two outputs
    # Create examples list from "examples/" directory
    examples=example_list,
    title=title,
    description=description,
    article=article,
)

# Launch the demo!
demo.launch(
    # debug=True,
    # server_port=7860,
    # server_name="0.0.0.0"
)