Spaces:
Running
on
Zero
Running
on
Zero
File size: 9,415 Bytes
0fa94d0 d7a5aaa 0fa94d0 d49c62f 0fa94d0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 |
import functools
import json
import random

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import onnxruntime
import torch
import torchvision.transforms.functional as F
from huggingface_hub import hf_hub_download
from PIL import Image, ImageColor
from torchvision.io import read_image
from torchvision.models.detection import MaskRCNN_ResNet50_FPN_Weights
from torchvision.utils import draw_bounding_boxes, draw_segmentation_masks
# Load pre-trained model transformations.
# `transforms` is the pre-processing preset bundled with the default pre-trained Mask R-CNN ResNet-50 FPN weights.
# `inference` applies it to every input image before running the ONNX model.
# NOTE(review): presumably this matches the pre-processing used when the ONNX models were trained/exported — confirm.
weights = MaskRCNN_ResNet50_FPN_Weights.DEFAULT
transforms = weights.transforms()
def fix_category_id(cat_ids: list) -> list:
    """
    Map label IDs from the reduced label set back to the original category IDs.

    The reduced label set is the 27 original categories minus `excluded_indices`, numbered contiguously in
    ascending order of original ID. Both the input and the output IDs are 1-based, i.e. 0-based ID + 1.

    Args:
        cat_ids: 1-based label IDs in the reduced label set (valid range 1..22).

    Returns:
        list: The corresponding 1-based original category IDs, in the same order as `cat_ids`.

    Raises:
        KeyError: if an ID lies outside the reduced label set (e.g. 0 or 23).
    """
    # Define the excluded (0-based) category ids and the remaining ones.
    excluded_indices = {2, 12, 16, 19, 20}
    # `sorted` pins the new->original ordering explicitly instead of relying on set iteration order,
    # which Python does not guarantee.
    remaining_categories = sorted(set(range(27)) - excluded_indices)
    # Map new (0-based) IDs to old/original (0-based) IDs. A dict, not list indexing, so that out-of-range IDs
    # (including 0 -> -1) raise KeyError instead of silently wrapping around.
    new_id_to_org_id = dict(enumerate(remaining_categories))
    return [new_id_to_org_id[i - 1] + 1 for i in cat_ids]
def process_categories() -> tuple:
    """
    Load and process category information from a JSON file.

    Reads `categories.json` from the current working directory. The file is expected to be a JSON array of
    objects with at least "id" and "name" keys. Two lookup tables are built from it.

    Returns:
        tuple: A tuple containing two dictionaries:
            - `category_id_to_name`: a dictionary mapping category IDs to their names.
            - `category_id_to_color`: a dictionary mapping category IDs to an RGB color, sampled
              reproducibly (fixed seed 42) from PIL's named colors.
    """
    # Load raw categories from JSON file. The explicit encoding keeps the result independent of the platform locale.
    with open("categories.json", encoding="utf-8") as fp:
        categories = json.load(fp)

    # Map category IDs to names
    category_id_to_name = {d["id"]: d["name"] for d in categories}

    # A private RNG seeded with 42 yields the same colors on every call. Unlike `random.seed(42)`, it does not
    # reset the process-wide `random` state that other code may depend on.
    rng = random.Random(42)
    # Get a list of all the color names in the PIL colormap
    color_names = list(ImageColor.colormap.keys())
    # Sample 46 unique color names. `zip` below pairs them with categories in file order, so a categories file
    # with more than 46 entries would leave the extra categories without a color.
    sampled_colors = rng.sample(color_names, 46)
    # Convert the color names to RGB values
    rgb_colors = [ImageColor.getrgb(color_name) for color_name in sampled_colors]

    # Map category IDs to colors
    category_id_to_color = {
        category["id"]: color for category, color in zip(categories, rgb_colors)
    }
    return category_id_to_name, category_id_to_color
def draw_predictions(
    boxes, labels, scores, masks, img, model_name, score_threshold, proba_threshold
):
    """
    Draw predictions on the input image based on the provided boxes, labels, scores, and masks.

    Only predictions with scores above `score_threshold` are included. Within each kept mask, only pixels with
    probabilities above `proba_threshold` are displayed.

    Args:
    - boxes: numpy.ndarray - an array of bounding box coordinates, one row per detection.
    - labels: numpy.ndarray - an array of integers representing the predicted class for each bounding box.
    - scores: numpy.ndarray - an array of confidence scores for each bounding box.
    - masks: numpy.ndarray - per-detection pixel-wise mask probabilities. Axis 1 is squeezed below, so this is
      presumably shaped (N, 1, H, W) — confirm against the exported ONNX model.
    - img: torch.Tensor - the input image tensor as produced by `torchvision.io.read_image`. This is not a PIL
      image: the torchvision drawing utilities used here require a tensor.
    - model_name: str - name of the model given by the dropdown menu, either "facere" or "facere+".
    - score_threshold: float - a confidence score threshold for filtering out low-scoring bbox predictions.
    - proba_threshold: float - a threshold for filtering out low-probability (pixel-wise) mask predictions.

    Returns:
    - A list of file paths (PNG files written to the current working directory), in this order:
      - the image with all kept masks drawn;
      - the image with all kept boxes drawn;
      - one figure per kept detection, showing its mask titled with "<label>: <score>".
    """
    imgs_list = []

    # Map label IDs to names and colors
    label_id_to_name, label_id_to_color = process_categories()

    # Compute the detection filter once and apply it to *every* per-detection array (including `scores`),
    # so that boxes, labels, masks and scores stay index-aligned.
    keep = scores > score_threshold
    labels_id = labels[keep].tolist()
    if model_name == "facere+":
        # facere+ labels refer to a reduced label set; map them back to the original category IDs.
        labels_id = fix_category_id(labels_id)
    # models output is in range: [1,class_id+1], hence re-map to: [0,class_id]
    label_names = [label_id_to_name[int(i) - 1] for i in labels_id]
    label_colors = [label_id_to_color[int(i) - 1] for i in labels_id]
    kept_scores = scores[keep]
    masks = (masks[keep] > proba_threshold).astype(np.uint8)
    boxes = boxes[keep]

    # Draw masks to input image and save
    img_masks = draw_segmentation_masks(
        image=img,
        masks=torch.from_numpy(masks.squeeze(1).astype(bool)),
        alpha=0.9,
        colors=label_colors,
    )
    F.to_pil_image(img_masks).save("img_masks.png")
    imgs_list.append("img_masks.png")

    # Draw bboxes to input image and save
    img_bbox = draw_bounding_boxes(img, boxes=torch.from_numpy(boxes), width=4)
    F.to_pil_image(img_bbox).save("img_bbox.png")
    imgs_list.append("img_bbox.png")

    # Save each mask as its own figure, titled with the label name and the score of the *same* detection.
    for col, (mask, label, score) in enumerate(zip(masks, label_names, kept_scores)):
        plt.imshow(Image.fromarray(mask.squeeze()))
        plt.axis("off")
        plt.title(f"{label}: {score:.2f}", fontsize=9)
        plt.savefig(f"mask-{col}.png")
        plt.close()
        imgs_list.append(f"mask-{col}.png")

    return imgs_list
@functools.lru_cache(maxsize=2)
def _load_onnx_session(filename):
    """
    Download `filename` from the FashionFail model repository and open an ONNX Runtime session for it.

    The session is memoized per filename, so each model is downloaded and loaded once per process rather than on
    every request. `maxsize=2` covers the two model files this app serves.
    """
    path_onnx = hf_hub_download(repo_id="rizavelioglu/fashionfail", filename=filename)
    # Providers are listed in priority order: CUDA when available, otherwise CPU.
    return onnxruntime.InferenceSession(
        path_onnx, providers=["CUDAExecutionProvider", "CPUExecutionProvider"]
    )


def inference(image, model_name, mask_threshold, bbox_threshold):
    """
    Run the selected ONNX model on `image` and visualize its predictions for the Gradio gallery.

    Args:
        image: path to the uploaded image file (the input component uses ``type="filepath"``), or ``None`` when
            nothing was uploaded.
        model_name: "facere+" selects ``facere_plus.onnx``; any other value selects ``facere_base.onnx``.
        mask_threshold: pixel-wise mask probability threshold, forwarded as ``proba_threshold``.
        bbox_threshold: detection score threshold, forwarded as ``score_threshold``.

    Returns:
        list[str]: paths of the images produced by ``draw_predictions``.

    Raises:
        gr.Error: if no image was provided.
    """
    from torchvision.io import ImageReadMode

    if image is None:
        # Show a readable message in the UI instead of a traceback from `read_image(None)`.
        raise gr.Error("Please upload an image first.")

    # Force 3-channel RGB. Grayscale or RGBA uploads would otherwise give 1- or 4-channel tensors, which the model
    # and `draw_segmentation_masks` (3-channel uint8 only) do not accept.
    img = read_image(image, mode=ImageReadMode.RGB)
    # Apply original transformation to the image.
    img_transformed = transforms(img)

    filename = "facere_plus.onnx" if model_name == "facere+" else "facere_base.onnx"
    ort_session = _load_onnx_session(filename)

    # Add a batch dimension and compute the ONNX Runtime prediction.
    ort_inputs = {
        ort_session.get_inputs()[0].name: img_transformed.unsqueeze(dim=0).numpy()
    }
    # NOTE(review): assumes the exported model emits outputs in (boxes, labels, scores, masks) order — verify.
    boxes, labels, scores, masks = ort_session.run(None, ort_inputs)

    return draw_predictions(
        boxes, labels, scores, masks, img, model_name,
        score_threshold=bbox_threshold, proba_threshold=mask_threshold,
    )
# Texts shown by the Gradio interface (HTML/Markdown is rendered by Gradio).
title = "Facere - Demo"
description = r"""This is the demo of the paper <a href="https://arxiv.org/abs/2404.08582">FashionFail: Addressing
Failure Cases in Fashion Object Detection and Segmentation</a>. <br>Upload your image and choose the model for inference
from the dropdown menu—either `Facere` or `Facere+` <br> Check out the <a
href="https://rizavelioglu.github.io/fashionfail/">project page</a> for more information."""
# The citation uses `booktitle`, the standard BibTeX field for the venue of an @inproceedings entry;
# `journal` is not used by standard styles for this entry type, so the venue would be dropped.
article = r"""
Example images are sampled from the `FashionFail-test` set, which the models did not see during training.
<br>**Citation** <br>If you find our work useful in your research, please consider giving a star ⭐ and
a citation:
```
@inproceedings{velioglu2024fashionfail,
author = {Velioglu, Riza and Chan, Robin and Hammer, Barbara},
title = {FashionFail: Addressing Failure Cases in Fashion Object Detection and Segmentation},
booktitle = {IJCNN},
eprint = {2404.08582},
year = {2024},
}
```
"""
# Each example image is shown once per model, in the order of the `inference` parameters:
# [image path, model name, mask threshold, bbox threshold].
examples = [
    [image_name, model_name, 0.5, 0.7]
    for image_name in (
        "adi_103_6.jpg",
        "adi_1201_2.jpg",
        "adi_2149_5.jpg",
        "adi_5476_3.jpg",
        "adi_5641_4.jpg",
    )
    for model_name in ("facere", "facere+")
]
# Help texts displayed under the two threshold sliders.
_MASK_THRESHOLD_INFO = "a threshold for filtering out low-probability (pixel-wise) mask predictions"
_BBOX_THRESHOLD_INFO = "a threshold for filtering out low-scoring bbox predictions"

# Input components, listed in the same order as the parameters of `inference`:
# (image, model_name, mask_threshold, bbox_threshold).
_input_components = [
    gr.Image(type="filepath", label="input"),
    gr.Dropdown(["facere", "facere+"], value="facere", label="Models"),
    gr.Slider(
        value=0.5, minimum=0.0, maximum=0.9, step=0.05,
        label="Mask threshold", info=_MASK_THRESHOLD_INFO,
    ),
    gr.Slider(
        value=0.7, minimum=0.0, maximum=0.9, step=0.05,
        label="BBox threshold", info=_BBOX_THRESHOLD_INFO,
    ),
]

# The gallery displays the list of image paths returned by `inference`.
demo = gr.Interface(
    fn=inference,
    inputs=_input_components,
    outputs=gr.Gallery(label="output", preview=True, height=500),
    title=title,
    description=description,
    article=article,
    examples=examples,
    # Precompute the outputs of the examples (Gradio `cache_examples`).
    cache_examples=True,
    examples_per_page=6,
)

if __name__ == "__main__":
    demo.launch()
|