Spaces:
Runtime error
Runtime error
base app upload
Browse files- .gitattributes +3 -0
- app.py +135 -0
- examples/IMG_6093.JPG +3 -0
- examples/IMG_6111.JPG +3 -0
- examples/IMG_7047.JPG +3 -0
- models.py +14 -0
- requirements.txt +4 -0
.gitattributes
CHANGED
@@ -32,3 +32,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
32 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
33 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
34 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
32 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
33 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
34 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
35 |
+
examples/IMG_6093.JPG filter=lfs diff=lfs merge=lfs -text
|
36 |
+
examples/IMG_6111.JPG filter=lfs diff=lfs merge=lfs -text
|
37 |
+
examples/IMG_7047.JPG filter=lfs diff=lfs merge=lfs -text
|
app.py
ADDED
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
### 1. Imports and class names setup ###
|
2 |
+
import gradio as gr
|
3 |
+
import os
|
4 |
+
import torch
|
5 |
+
import PIL
|
6 |
+
from matplotlib import pyplot as plt
|
7 |
+
|
8 |
+
from timeit import default_timer as timer
|
9 |
+
from typing import Tuple, Dict
|
10 |
+
|
11 |
+
from models import get_detr, get_maskformer
|
12 |
+
|
13 |
+
# Select the compute device: use CUDA when torch reports it as available,
# otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
15 |
+
|
16 |
+
|
17 |
+
### 2. Model and transforms preparation ###
|
18 |
+
|
19 |
+
# Create model
|
20 |
+
|
21 |
+
# Registry mapping each selectable model name to the factory that builds it.
# Each factory returns a ``(model, processor)`` pair.
model_name_to_fn = dict(
    detr=get_detr,
    maskformer=get_maskformer,
)
|
25 |
+
|
26 |
+
|
27 |
+
### 3. Predict function ###
|
28 |
+
|
29 |
+
|
30 |
+
# Create predict function
|
31 |
+
# (model, processor) pairs that have already been loaded, keyed by model name.
# Loading weights is by far the slowest step, so it is done once per process
# instead of once per request.
_LOADED_MODELS = {}


def _get_model_and_processor(model_name: str):
    """Return the ``(model, processor)`` pair for ``model_name``, loading it on first use.

    The model is moved to ``device`` and switched to evaluation mode before it
    is cached.

    Raises:
        KeyError: If ``model_name`` is not a key of ``model_name_to_fn``.
    """
    if model_name not in _LOADED_MODELS:
        model, processor = model_name_to_fn[model_name]()
        model = model.to(device)
        model.eval()
        _LOADED_MODELS[model_name] = (model, processor)
    return _LOADED_MODELS[model_name]


def predict(image, model_name: str = "detr",) -> Tuple["PIL.Image.Image", float]:
    """Run panoptic segmentation on ``image`` and return the colour-mapped result.

    Args:
        image (PIL.Image): Image to segment.
        model_name (str): Key of ``model_name_to_fn`` selecting the model to use.

    Returns:
        Tuple[PIL.Image.Image, float]: The predicted panoptic map rendered with
        the "viridis" colour map, and the elapsed time in seconds (rounded to
        5 decimals). The first call for a given model also includes the time
        needed to load its weights.

    Raises:
        ValueError: If ``image`` is None (nothing was submitted).
        KeyError: If ``model_name`` is not a registered model.
    """
    import io

    if image is None:
        raise ValueError("No image was provided for segmentation.")

    # Start the timer
    start_time = timer()

    model, processor = _get_model_and_processor(model_name)

    # Inputs must live on the same device as the model; previously only the
    # "detr" branch moved them, so "maskformer" failed whenever CUDA was used.
    inputs = processor(images=image, return_tensors="pt").to(device)

    # No gradients are needed for prediction.
    with torch.inference_mode():
        outputs = model(**inputs)
        print("Output Generated!")

        # Post-processed panoptic results are returned as a list with one
        # dictionary per image; target size is (height, width).
        result = processor.post_process_panoptic_segmentation(
            outputs, target_sizes=[(image.height, image.width)]
        )[0]
        print("Output Post Processing Done!")

        # Segment-id map; moved to host memory so matplotlib can read it.
        panoptic_map = result["segmentation"].cpu().numpy()

    # Render to an in-memory PNG rather than a fixed file name so that
    # concurrent requests cannot overwrite each other's output.
    buffer = io.BytesIO()
    plt.imsave(buffer, panoptic_map, cmap="viridis", format="png")
    buffer.seek(0)
    output = PIL.Image.open(buffer)
    output.load()

    # Calculate the prediction time
    pred_time = round(timer() - start_time, 5)

    print("Returning Results!")
    return output, pred_time
|
97 |
+
|
98 |
+
|
99 |
+
### 4. Gradio app ###
|
100 |
+
|
101 |
+
# Create title, description and article strings
|
102 |
+
# Text shown on the Gradio page.
title = "Segmentation Demo"
description = "A Multi-Model Segmentation Demo"
article = ""

# Build one example row per file in the "examples/" directory.
# Names are sorted because os.listdir returns them in arbitrary order, and a
# missing directory yields None (no examples) instead of crashing at start-up.
_EXAMPLES_DIR = "examples"
example_list = (
    [[os.path.join(_EXAMPLES_DIR, name)] for name in sorted(os.listdir(_EXAMPLES_DIR))]
    if os.path.isdir(_EXAMPLES_DIR)
    else None
)
|
108 |
+
|
109 |
+
# Create the Gradio demo
|
110 |
+
# Dropdown listing every registered model; "detr" is preselected.
model_selection_dropdown = gr.components.Dropdown(
    label="Select a model",
    choices=list(model_name_to_fn),
    value="detr",
)

# Interface wiring: predict(image, model_name) -> (mask image, elapsed seconds).
demo = gr.Interface(
    fn=predict,
    # Two inputs: the picture (handed to predict as a PIL image) and the model choice.
    inputs=[gr.Image(type="pil"), model_selection_dropdown],
    # Two outputs, matching the two values predict returns.
    outputs=[
        gr.Image(label="Mask"),
        gr.Number(label="Prediction time (s)"),
    ],
    title=title,
    description=description,
    article=article,
    # Example rows built from the "examples/" directory.
    examples=example_list,
)
|
129 |
+
|
130 |
+
# Launch the demo!
|
131 |
+
# Start the web server: bind to all interfaces on port 7860 with debug enabled.
demo.launch(server_name="0.0.0.0", server_port=7860, debug=True)
|
examples/IMG_6093.JPG
ADDED
|
Git LFS Details
|
examples/IMG_6111.JPG
ADDED
|
Git LFS Details
|
examples/IMG_7047.JPG
ADDED
|
Git LFS Details
|
models.py
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import AutoImageProcessor, DetrForSegmentation
|
2 |
+
from transformers import MaskFormerFeatureExtractor, MaskFormerForInstanceSegmentation
|
3 |
+
|
4 |
+
|
5 |
+
def get_detr():
    """Load the DETR panoptic checkpoint.

    Returns:
        tuple: ``(model, image_processor)``, both created from the
        "facebook/detr-resnet-50-panoptic" checkpoint.
    """
    checkpoint = "facebook/detr-resnet-50-panoptic"
    processor = AutoImageProcessor.from_pretrained(checkpoint)
    detr = DetrForSegmentation.from_pretrained(checkpoint)
    return detr, processor
|
9 |
+
|
10 |
+
|
11 |
+
def get_maskformer():
    """Load the MaskFormer COCO checkpoint.

    Returns:
        tuple: ``(model, feature_extractor)``, both created from the
        "facebook/maskformer-swin-small-coco" checkpoint.
    """
    checkpoint = "facebook/maskformer-swin-small-coco"
    extractor = MaskFormerFeatureExtractor.from_pretrained(checkpoint)
    maskformer = MaskFormerForInstanceSegmentation.from_pretrained(checkpoint)
    return maskformer, extractor
|
requirements.txt
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch
|
2 |
+
torchvision
|
3 |
+
gradio
|
4 |
+
transformers[torch]
|