File size: 9,415 Bytes
0fa94d0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d7a5aaa
0fa94d0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d49c62f
0fa94d0
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
import json
import random

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import onnxruntime
import torch
import torchvision.transforms.functional as F
from huggingface_hub import hf_hub_download
from PIL import Image, ImageColor
from torchvision.io import read_image
from torchvision.models.detection import MaskRCNN_ResNet50_FPN_Weights
from torchvision.utils import draw_bounding_boxes, draw_segmentation_masks

# Preprocessing pipeline bundled with torchvision's default Mask R-CNN (ResNet-50 FPN) weights.
# In this file only `weights.transforms()` is used: `inference` applies `transforms` to every input
# image before it is passed to the ONNX model.
weights = MaskRCNN_ResNet50_FPN_Weights.DEFAULT
transforms = weights.transforms()


# 0-based IDs of the original 27 categories that are excluded from the compacted label space.
_EXCLUDED_CATEGORY_IDS = frozenset({2, 12, 16, 19, 20})
# Compacted (new) 0-based ID -> original 0-based ID. `sorted` makes the ordering explicit instead of
# relying on set iteration order, which the language does not guarantee.
_NEW_ID_TO_ORG_ID = dict(enumerate(sorted(set(range(27)) - _EXCLUDED_CATEGORY_IDS)))


def fix_category_id(cat_ids: list) -> list:
    """
    Map category IDs from the compacted label space back to the original 27-category label space.

    Both the input and the output IDs are 1-based, matching the model's label output.

    Args:
        cat_ids: 1-based category IDs in the compacted label space (22 categories).

    Returns:
        list: the corresponding 1-based original category IDs, in the same order.

    Raises:
        KeyError: if an ID lies outside ``1..22``.
    """
    # Shift to 0-based for the lookup, then back to 1-based for the caller.
    return [_NEW_ID_TO_ORG_ID[i - 1] + 1 for i in cat_ids]


def process_categories() -> tuple:
    """
    Load category metadata from ``categories.json`` and assign every category a display color.

    Colors are sampled from PIL's named-color table with a private random generator seeded with 42.
    The mapping is therefore identical on every call, and the module-level ``random`` state is not touched.

    Returns:
        tuple: ``(category_id_to_name, category_id_to_color)`` where
            - ``category_id_to_name`` maps each category's ``"id"`` to its ``"name"``;
            - ``category_id_to_color`` maps each category's ``"id"`` to an ``(R, G, B)`` tuple.
    """
    with open("categories.json", encoding="utf-8") as fp:
        categories = json.load(fp)

    category_id_to_name = {d["id"]: d["name"] for d in categories}

    # A dedicated seeded generator yields the same colors that `random.seed(42)` would,
    # without reseeding the process-global RNG as a side effect.
    rng = random.Random(42)
    color_names = list(ImageColor.colormap.keys())

    # Keep the historical 46 samples, so the colors for existing categories do not change.
    # Sample more if the JSON lists more categories; otherwise the zip below would silently
    # leave the extra categories without a color.
    num_colors = max(46, len(categories))
    sampled_colors = rng.sample(color_names, num_colors)
    rgb_colors = [ImageColor.getrgb(color_name) for color_name in sampled_colors]

    category_id_to_color = {
        category["id"]: color for category, color in zip(categories, rgb_colors)
    }

    return category_id_to_name, category_id_to_color


def draw_predictions(
        boxes, labels, scores, masks, img, model_name, score_threshold, proba_threshold
):
    """
    Draw predictions on the input image and save each visualization as a PNG file in the working directory.

    Only detections whose score is strictly greater than `score_threshold` are kept. Within their masks,
    only pixels whose probability is strictly greater than `proba_threshold` are drawn as foreground.

    Args:
    - boxes: numpy.ndarray - bounding box coordinates, one row per detection.
    - labels: numpy.ndarray - 1-based predicted class ID for each detection.
    - scores: numpy.ndarray - confidence score for each detection.
    - masks: numpy.ndarray - per-pixel mask probabilities, shape (N, 1, H, W).
    - img: torch.Tensor - the input image as returned by `torchvision.io.read_image`.
    - model_name: str - name of the model given by the dropdown menu, either "facere" or "facere+".
    - score_threshold: float - a confidence score threshold for filtering out low-scoring bbox predictions.
    - proba_threshold: float - a threshold for filtering out low-probability (pixel-wise) mask predictions.

    Returns:
    - A list of image file paths, in this order:
      - "img_masks.png": all kept masks overlaid on the image;
      - "img_bbox.png": all kept boxes drawn on the image;
      - one "mask-{i}.png" per kept detection, showing its binary mask titled with the class name and score.
    """
    # One selector applied to every per-detection array, so that boxes, labels, scores and masks stay
    # aligned. This does not rely on the model returning detections sorted by score.
    keep = scores > score_threshold
    label_ids = labels[keep].tolist()
    kept_scores = scores[keep]
    boxes = boxes[keep]
    # Binarize the per-pixel probabilities of the kept masks.
    masks = (masks[keep] > proba_threshold).astype(np.uint8)

    if model_name == "facere+":
        # "facere+" predicts in a compacted label space; map its IDs back to the original categories.
        label_ids = fix_category_id(label_ids)

    label_id_to_name, label_id_to_color = process_categories()
    # Model outputs are in range [1, num_classes]; the category dictionaries are keyed from 0.
    category_ids = [int(i) - 1 for i in label_ids]
    label_names = [label_id_to_name[i] for i in category_ids]
    colors = [label_id_to_color[i] for i in category_ids]

    # Draw masks on the input image and save.
    img_masks = draw_segmentation_masks(
        image=img,
        masks=torch.from_numpy(masks.squeeze(1).astype(bool)),
        alpha=0.9,
        colors=colors,
    )
    F.to_pil_image(img_masks).save("img_masks.png")
    imgs_list = ["img_masks.png"]

    # Draw bboxes on the input image and save.
    img_bbox = draw_bounding_boxes(img, boxes=torch.from_numpy(boxes), width=4)
    F.to_pil_image(img_bbox).save("img_bbox.png")
    imgs_list.append("img_bbox.png")

    # Save each mask individually, titled with its label and score.
    for col, (mask, label, score) in enumerate(zip(masks, label_names, kept_scores)):
        # Drop the singleton channel axis -> (H, W).
        plt.imshow(Image.fromarray(mask.squeeze(0)))
        plt.axis("off")
        plt.title(f"{label}: {score:.2f}", fontsize=9)
        plt.savefig(f"mask-{col}.png")
        plt.close()
        imgs_list.append(f"mask-{col}.png")

    return imgs_list


# ONNX Runtime sessions keyed by model file name. Downloading the model file and building a session
# is expensive, so it happens at most once per model per process instead of on every request.
_ORT_SESSIONS = {}


def _get_ort_session(model_name):
    """Return the ONNX Runtime session for `model_name`, downloading and loading the model on first use."""
    filename = "facere_plus.onnx" if model_name == "facere+" else "facere_base.onnx"
    session = _ORT_SESSIONS.get(filename)
    if session is None:
        path_onnx = hf_hub_download(repo_id="rizavelioglu/fashionfail", filename=filename)
        # Providers are listed in priority order: CUDA first, then CPU.
        session = onnxruntime.InferenceSession(
            path_onnx, providers=["CUDAExecutionProvider", "CPUExecutionProvider"]
        )
        _ORT_SESSIONS[filename] = session
    return session


def inference(image, model_name, mask_threshold, bbox_threshold):
    """
    Run the selected ONNX model on the image at path `image` and visualize its predictions.

    Args:
        image: path to the input image file (the Gradio image input uses ``type="filepath"``).
        model_name: "facere" or "facere+"; any value other than "facere+" selects the base model.
        mask_threshold: pixel-probability threshold for the predicted masks.
        bbox_threshold: confidence-score threshold for the predicted detections.

    Returns:
        list: paths of the generated visualization images, as produced by `draw_predictions`.
    """
    from torchvision.io import ImageReadMode

    # Force three channels: grayscale or RGBA files would otherwise break the model input and
    # torchvision's drawing utilities, which require RGB images.
    img = read_image(image, mode=ImageReadMode.RGB)
    # Apply the original preprocessing to the image.
    img_transformed = transforms(img)

    ort_session = _get_ort_session(model_name)

    # Add a batch dimension and run the model.
    ort_inputs = {
        ort_session.get_inputs()[0].name: img_transformed.unsqueeze(dim=0).numpy()
    }
    boxes, labels, scores, masks = ort_session.run(None, ort_inputs)

    return draw_predictions(
        boxes, labels, scores, masks, img, model_name,
        score_threshold=bbox_threshold, proba_threshold=mask_threshold,
    )


# Texts passed to gr.Interface as its title, description and article.
title = "Facere - Demo"
description = r"""This is the demo of the paper <a href="https://arxiv.org/abs/2404.08582">FashionFail: Addressing 
Failure Cases in Fashion Object Detection and Segmentation</a>. <br>Upload your image and choose the model for inference
from the dropdown menu—either `Facere` or `Facere+`. <br> Check out the <a 
href="https://rizavelioglu.github.io/fashionfail/">project page</a> for more information."""
# BibTeX note: @inproceedings entries name the venue with `booktitle`, not `journal`.
article = r"""
Example images are sampled from the `FashionFail-test` set, which the models did not see during training.

<br>**Citation** <br>If you find our work useful in your research, please consider giving a star ⭐ and 
a citation: 
``` 
@inproceedings{velioglu2024fashionfail,
  author    = {Velioglu, Riza and Chan, Robin and Hammer, Barbara},
  title     = {FashionFail: Addressing Failure Cases in Fashion Object Detection and Segmentation},
  booktitle = {IJCNN},
  eprint    = {2404.08582},
  year      = {2024},
}
``` 
"""

# Every example image is offered once per model, using the sliders' default thresholds
# (mask threshold 0.5, bbox threshold 0.7) in the same column order as the interface inputs.
examples = [
    [image_file, model, 0.5, 0.7]
    for image_file in (
        "adi_103_6.jpg",
        "adi_1201_2.jpg",
        "adi_2149_5.jpg",
        "adi_5476_3.jpg",
        "adi_5641_4.jpg",
    )
    for model in ("facere", "facere+")
]

# The Gradio app. The input components are listed in the same order as the parameters of
# inference(image, model_name, mask_threshold, bbox_threshold).
demo = gr.Interface(
    fn=inference,
    inputs=[
        gr.Image(type="filepath", label="input"),
        gr.Dropdown(["facere", "facere+"], value="facere", label="Models"),
        gr.Slider(
            value=0.5,
            minimum=0.0,
            maximum=0.9,
            step=0.05,
            label="Mask threshold",
            info="a threshold for filtering out low-probability (pixel-wise) mask predictions",
        ),
        gr.Slider(
            value=0.7,
            minimum=0.0,
            maximum=0.9,
            step=0.05,
            label="BBox threshold",
            info="a threshold for filtering out low-scoring bbox predictions",
        ),
    ],
    outputs=gr.Gallery(label="output", preview=True, height=500),
    title=title,
    description=description,
    article=article,
    examples=examples,
    cache_examples=True,
    examples_per_page=6,
)

# Start the server only when run as a script, not on import.
if __name__ == "__main__":
    demo.launch()