innat commited on
Commit
5637560
1 Parent(s): 238391b
LICENSE ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Mask R-CNN
2
+
3
+ The MIT License (MIT)
4
+
5
+ Copyright (c) 2017 Matterport, Inc.
6
+
7
+ Permission is hereby granted, free of charge, to any person obtaining a copy
8
+ of this software and associated documentation files (the "Software"), to deal
9
+ in the Software without restriction, including without limitation the rights
10
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11
+ copies of the Software, and to permit persons to whom the Software is
12
+ furnished to do so, subject to the following conditions:
13
+
14
+ The above copyright notice and this permission notice shall be included in
15
+ all copies or substantial portions of the Software.
16
+
17
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
23
+ THE SOFTWARE.
README.md CHANGED
@@ -1,10 +1,11 @@
1
  ---
2
- title: Global.Wheat.Detection.MaskRCNN
3
- emoji: 📊
4
- colorFrom: yellow
5
- colorTo: purple
6
  sdk: gradio
7
  sdk_version: 3.0.20
 
8
  app_file: app.py
9
  pinned: false
10
  ---
 
1
  ---
2
+ title: Wheat Detect Demo
3
+ emoji: 🌾
4
+ colorFrom: green
5
+ colorTo: indigo
6
  sdk: gradio
7
  sdk_version: 3.0.20
8
+ python_version: 3.7
9
  app_file: app.py
10
  pinned: false
11
  ---
app.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------ tackle some noisy warning
2
+ import os
3
+ import warnings
4
+
5
+
6
+ def warn(*args, **kwargs):
7
+ pass
8
+
9
+
10
+ warnings.warn = warn
11
+ warnings.filterwarnings("ignore")
12
+ os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
13
+
14
+ import random
15
+
16
+ import gdown
17
+ import gradio as gr
18
+ import matplotlib.pyplot as plt
19
+ import numpy as np
20
+ import tensorflow as tf
21
+ from PIL import Image
22
+
23
+ import mrcnn.model as modellib
24
+ from config import WheatDetectorConfig
25
+ from config import WheatInferenceConfig
26
+ from mrcnn import utils
27
+ from mrcnn import visualize
28
+ from mrcnn.model import log
29
+ from utils import get_ax
30
+
31
+
32
+ # for reproducibility
33
+ def seed_all(SEED):
34
+ random.seed(SEED)
35
+ np.random.seed(SEED)
36
+ os.environ["PYTHONHASHSEED"] = str(SEED)
37
+
38
+
39
+ ORIG_SIZE = 1024
40
+ seed_all(42)
41
+
42
+ config = WheatDetectorConfig()
43
+ inference_config = WheatInferenceConfig()
44
+
45
+
46
+ def get_model_weight(model_id):
47
+ """Get the trained weights."""
48
+ if not os.path.exists("model.h5"):
49
+ model_weight = gdown.download(id=model_id, quiet=False)
50
+ else:
51
+ model_weight = "model.h5"
52
+ return model_weight
53
+
54
+
55
+ def get_model():
56
+ """Get the model."""
57
+ model = modellib.MaskRCNN(mode="inference", config=inference_config, model_dir="./")
58
+ return model
59
+
60
+
61
+ def load_model(model_id):
62
+ """Load trained model."""
63
+ weight = get_model_weight(model_id)
64
+ model = get_model()
65
+ model.load_weights(weight, by_name=True)
66
+ return model
67
+
68
+
69
+ def prepare_image(image):
70
+ """Prepare incoming sample."""
71
+ image = image[:, :, ::-1]
72
+ resize_factor = ORIG_SIZE / config.IMAGE_SHAPE[0]
73
+
74
+ # If grayscale. Convert to RGB for consistency.
75
+ if len(image.shape) != 3 or image.shape[2] != 3:
76
+ image = np.stack((image,) * 3, -1)
77
+
78
+ resized_image, window, scale, padding, crop = utils.resize_image(
79
+ image,
80
+ min_dim=config.IMAGE_MIN_DIM,
81
+ min_scale=config.IMAGE_MIN_SCALE,
82
+ max_dim=config.IMAGE_MAX_DIM,
83
+ mode=config.IMAGE_RESIZE_MODE,
84
+ )
85
+
86
+ return resized_image
87
+
88
+
89
+ def predict_fn(image):
90
+
91
+ image = prepare_image(image)
92
+
93
+ model = load_model(model_id="1k4_WGBAUJCPbkkHkvtscX2jufTqETNYd")
94
+ results = model.detect([image])
95
+ r = results[0]
96
+ class_names = ["Wheat"] * len(r["rois"])
97
+
98
+ image = visualize.display_instances(
99
+ image,
100
+ r["rois"],
101
+ r["masks"],
102
+ r["class_ids"],
103
+ class_names,
104
+ r["scores"],
105
+ ax=get_ax(),
106
+ title="Predictions",
107
+ )
108
+
109
+ return image[:, :, ::-1]
110
+
111
+ title="Global Wheat Detection with Mask-RCNN Model"
112
+ description="<strong>Model</strong>: Mask-RCNN. <strong>Backbone</strong>: ResNet-101. Trained on: <a href='https://www.kaggle.com/competitions/global-wheat-detection/overview'>Global Wheat Detection Dataset (Kaggle)</a>. </br>The code is written in <code>Keras (TensorFlow 1.14)</code>. One can run the full code on Kaggle: <a href='https://www.kaggle.com/code/ipythonx/keras-global-wheat-detection-with-mask-rcnn'>[Keras]:Global Wheat Detection with Mask-RCNN</a>"
113
+ article = "<p>The model received <strong>0.6449</strong> and <strong>0.5675</strong> mAP (0.5:0.75:0.05) on the public and private test dataset respectively. The above examples are from test dataset without ground truth bounding box. Details: <a href='https://www.kaggle.com/competitions/global-wheat-detection/data'>Global Wheat Dataset</a></p>"
114
+
115
+ iface = gr.Interface(
116
+ fn=predict_fn,
117
+ inputs=gr.inputs.Image(label="Input Image"),
118
+ outputs=gr.outputs.Image(label="Prediction"),
119
+ title=title,
120
+ description=description,
121
+ article=article,
122
+ examples=[
123
+ ["examples/2fd875eaa.jpg"],
124
+ ["examples/51b3e36ab.jpg"],
125
+ ["examples/51f1be19e.jpg"],
126
+ ["examples/53f253011.jpg"],
127
+ ["examples/348a992bb.jpg"],
128
+ ["examples/796707dd7.jpg"],
129
+ ["examples/aac893a91.jpg"],
130
+ ["examples/cb8d261a3.jpg"],
131
+ ["examples/cc3532ff6.jpg"],
132
+ ["examples/f5a1f0358.jpg"],
133
+ ],
134
+ )
135
+ iface.launch()
components.py ADDED
The diff for this file is too large to render. See raw diff
 
config.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from mrcnn.config import Config
2
+
3
+
4
+ class WheatDetectorConfig(Config):
5
+ # Give the configuration a recognizable name
6
+ NAME = "wheat"
7
+ GPU_COUNT = 1
8
+ IMAGES_PER_GPU = 2
9
+ BACKBONE = "resnet101"
10
+ NUM_CLASSES = 2
11
+ IMAGE_RESIZE_MODE = "square"
12
+ IMAGE_MIN_DIM = 1024
13
+ IMAGE_MAX_DIM = 1024
14
+ STEPS_PER_EPOCH = 120
15
+ BACKBONE_STRIDES = [4, 8, 16, 32, 64]
16
+ RPN_ANCHOR_SCALES = (16, 32, 64, 128, 256)
17
+ LEARNING_RATE = 0.005
18
+ WEIGHT_DECAY = 0.0005
19
+ TRAIN_ROIS_PER_IMAGE = 350
20
+ DETECTION_MIN_CONFIDENCE = 0.60
21
+ VALIDATION_STEPS = 60
22
+ MAX_GT_INSTANCES = 500
23
+ LOSS_WEIGHTS = {
24
+ "rpn_class_loss": 1.0,
25
+ "rpn_bbox_loss": 1.0,
26
+ "mrcnn_class_loss": 1.0,
27
+ "mrcnn_bbox_loss": 1.0,
28
+ "mrcnn_mask_loss": 1.0,
29
+ }
30
+
31
+
32
+ class WheatInferenceConfig(WheatDetectorConfig):
33
+ GPU_COUNT = 1
34
+ IMAGES_PER_GPU = 1
examples/2fd875eaa.jpg ADDED
examples/348a992bb.jpg ADDED
examples/51b3e36ab.jpg ADDED
examples/51f1be19e.jpg ADDED
examples/53f253011.jpg ADDED
examples/796707dd7.jpg ADDED
examples/aac893a91.jpg ADDED
examples/cb8d261a3.jpg ADDED
examples/cc3532ff6.jpg ADDED
examples/f5a1f0358.jpg ADDED
model.h5 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:637fb6450e1332ed6447088b8dc68a492c4a8d64782dabeaf6fc4819e3da03e3
3
+ size 255858144
mrcnn/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+
mrcnn/config.py ADDED
@@ -0,0 +1,239 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Mask R-CNN
3
+ Base Configurations class.
4
+
5
+ Copyright (c) 2017 Matterport, Inc.
6
+ Licensed under the MIT License (see LICENSE for details)
7
+ Written by Waleed Abdulla
8
+ """
9
+
10
+ import numpy as np
11
+
12
+ # Base Configuration Class
13
+ # Don't use this class directly. Instead, sub-class it and override
14
+ # the configurations you need to change.
15
+
16
+
17
+ class Config(object):
18
+ """Base configuration class. For custom configurations, create a
19
+ sub-class that inherits from this one and override properties
20
+ that need to be changed.
21
+ """
22
+
23
+ # Name the configurations. For example, 'COCO', 'Experiment 3', ...etc.
24
+ # Useful if your code needs to do things differently depending on which
25
+ # experiment is running.
26
+ NAME = None # Override in sub-classes
27
+
28
+ # NUMBER OF GPUs to use. When using only a CPU, this needs to be set to 1.
29
+ GPU_COUNT = 1
30
+
31
+ # Number of images to train with on each GPU. A 12GB GPU can typically
32
+ # handle 2 images of 1024x1024px.
33
+ # Adjust based on your GPU memory and image sizes. Use the highest
34
+ # number that your GPU can handle for best performance.
35
+ IMAGES_PER_GPU = 2
36
+
37
+ # Number of training steps per epoch
38
+ # This doesn't need to match the size of the training set. Tensorboard
39
+ # updates are saved at the end of each epoch, so setting this to a
40
+ # smaller number means getting more frequent TensorBoard updates.
41
+ # Validation stats are also calculated at each epoch end and they
42
+ # might take a while, so don't set this too small to avoid spending
43
+ # a lot of time on validation stats.
44
+ STEPS_PER_EPOCH = 1000
45
+
46
+ # Number of validation steps to run at the end of every training epoch.
47
+ # A bigger number improves accuracy of validation stats, but slows
48
+ # down the training.
49
+ VALIDATION_STEPS = 50
50
+
51
+ # Backbone network architecture
52
+ # Supported values are: resnet50, resnet101.
53
+ # You can also provide a callable that should have the signature
54
+ # of model.resnet_graph. If you do so, you need to supply a callable
55
+ # to COMPUTE_BACKBONE_SHAPE as well
56
+ BACKBONE = "resnet101"
57
+
58
+ # Only useful if you supply a callable to BACKBONE. Should compute
59
+ # the shape of each layer of the FPN Pyramid.
60
+ # See model.compute_backbone_shapes
61
+ COMPUTE_BACKBONE_SHAPE = None
62
+
63
+ # The strides of each layer of the FPN Pyramid. These values
64
+ # are based on a Resnet101 backbone.
65
+ BACKBONE_STRIDES = [4, 8, 16, 32, 64]
66
+
67
+ # Size of the fully-connected layers in the classification graph
68
+ FPN_CLASSIF_FC_LAYERS_SIZE = 1024
69
+
70
+ # Size of the top-down layers used to build the feature pyramid
71
+ TOP_DOWN_PYRAMID_SIZE = 256
72
+
73
+ # Number of classification classes (including background)
74
+ NUM_CLASSES = 1 # Override in sub-classes
75
+
76
+ # Length of square anchor side in pixels
77
+ RPN_ANCHOR_SCALES = (32, 64, 128, 256, 512)
78
+
79
+ # Ratios of anchors at each cell (width/height)
80
+ # A value of 1 represents a square anchor, and 0.5 is a wide anchor
81
+ RPN_ANCHOR_RATIOS = [0.5, 1, 2]
82
+
83
+ # Anchor stride
84
+ # If 1 then anchors are created for each cell in the backbone feature map.
85
+ # If 2, then anchors are created for every other cell, and so on.
86
+ RPN_ANCHOR_STRIDE = 1
87
+
88
+ # Non-max suppression threshold to filter RPN proposals.
89
+ # You can increase this during training to generate more propsals.
90
+ RPN_NMS_THRESHOLD = 0.7
91
+
92
+ # How many anchors per image to use for RPN training
93
+ RPN_TRAIN_ANCHORS_PER_IMAGE = 256
94
+
95
+ # ROIs kept after tf.nn.top_k and before non-maximum suppression
96
+ PRE_NMS_LIMIT = 6000
97
+
98
+ # ROIs kept after non-maximum suppression (training and inference)
99
+ POST_NMS_ROIS_TRAINING = 2000
100
+ POST_NMS_ROIS_INFERENCE = 1000
101
+
102
+ # If enabled, resizes instance masks to a smaller size to reduce
103
+ # memory load. Recommended when using high-resolution images.
104
+ USE_MINI_MASK = True
105
+ MINI_MASK_SHAPE = (56, 56) # (height, width) of the mini-mask
106
+
107
+ # Input image resizing
108
+ # Generally, use the "square" resizing mode for training and predicting
109
+ # and it should work well in most cases. In this mode, images are scaled
110
+ # up such that the small side is = IMAGE_MIN_DIM, but ensuring that the
111
+ # scaling doesn't make the long side > IMAGE_MAX_DIM. Then the image is
112
+ # padded with zeros to make it a square so multiple images can be put
113
+ # in one batch.
114
+ # Available resizing modes:
115
+ # none: No resizing or padding. Return the image unchanged.
116
+ # square: Resize and pad with zeros to get a square image
117
+ # of size [max_dim, max_dim].
118
+ # pad64: Pads width and height with zeros to make them multiples of 64.
119
+ # If IMAGE_MIN_DIM or IMAGE_MIN_SCALE are not None, then it scales
120
+ # up before padding. IMAGE_MAX_DIM is ignored in this mode.
121
+ # The multiple of 64 is needed to ensure smooth scaling of feature
122
+ # maps up and down the 6 levels of the FPN pyramid (2**6=64).
123
+ # crop: Picks random crops from the image. First, scales the image based
124
+ # on IMAGE_MIN_DIM and IMAGE_MIN_SCALE, then picks a random crop of
125
+ # size IMAGE_MIN_DIM x IMAGE_MIN_DIM. Can be used in training only.
126
+ # IMAGE_MAX_DIM is not used in this mode.
127
+ IMAGE_RESIZE_MODE = "square"
128
+ IMAGE_MIN_DIM = 800
129
+ IMAGE_MAX_DIM = 1024
130
+ # Minimum scaling ratio. Checked after MIN_IMAGE_DIM and can force further
131
+ # up scaling. For example, if set to 2 then images are scaled up to double
132
+ # the width and height, or more, even if MIN_IMAGE_DIM doesn't require it.
133
+ # However, in 'square' mode, it can be overruled by IMAGE_MAX_DIM.
134
+ IMAGE_MIN_SCALE = 0
135
+ # Number of color channels per image. RGB = 3, grayscale = 1, RGB-D = 4
136
+ # Changing this requires other changes in the code. See the WIKI for more
137
+ # details: https://github.com/matterport/Mask_RCNN/wiki
138
+ IMAGE_CHANNEL_COUNT = 3
139
+
140
+ # Image mean (RGB)
141
+ MEAN_PIXEL = np.array([123.7, 116.8, 103.9])
142
+
143
+ # Number of ROIs per image to feed to classifier/mask heads
144
+ # The Mask RCNN paper uses 512 but often the RPN doesn't generate
145
+ # enough positive proposals to fill this and keep a positive:negative
146
+ # ratio of 1:3. You can increase the number of proposals by adjusting
147
+ # the RPN NMS threshold.
148
+ TRAIN_ROIS_PER_IMAGE = 200
149
+
150
+ # Percent of positive ROIs used to train classifier/mask heads
151
+ ROI_POSITIVE_RATIO = 0.33
152
+
153
+ # Pooled ROIs
154
+ POOL_SIZE = 7
155
+ MASK_POOL_SIZE = 14
156
+
157
+ # Shape of output mask
158
+ # To change this you also need to change the neural network mask branch
159
+ MASK_SHAPE = [28, 28]
160
+
161
+ # Maximum number of ground truth instances to use in one image
162
+ MAX_GT_INSTANCES = 100
163
+
164
+ # Bounding box refinement standard deviation for RPN and final detections.
165
+ RPN_BBOX_STD_DEV = np.array([0.1, 0.1, 0.2, 0.2])
166
+ BBOX_STD_DEV = np.array([0.1, 0.1, 0.2, 0.2])
167
+
168
+ # Max number of final detections
169
+ DETECTION_MAX_INSTANCES = 100
170
+
171
+ # Minimum probability value to accept a detected instance
172
+ # ROIs below this threshold are skipped
173
+ DETECTION_MIN_CONFIDENCE = 0.7
174
+
175
+ # Non-maximum suppression threshold for detection
176
+ DETECTION_NMS_THRESHOLD = 0.3
177
+
178
+ # Learning rate and momentum
179
+ # The Mask RCNN paper uses lr=0.02, but on TensorFlow it causes
180
+ # weights to explode. Likely due to differences in optimizer
181
+ # implementation.
182
+ LEARNING_RATE = 0.001
183
+ LEARNING_MOMENTUM = 0.9
184
+
185
+ # Weight decay regularization
186
+ WEIGHT_DECAY = 0.0001
187
+
188
+ # Loss weights for more precise optimization.
189
+ # Can be used for R-CNN training setup.
190
+ LOSS_WEIGHTS = {
191
+ "rpn_class_loss": 1.0,
192
+ "rpn_bbox_loss": 1.0,
193
+ "mrcnn_class_loss": 1.0,
194
+ "mrcnn_bbox_loss": 1.0,
195
+ "mrcnn_mask_loss": 1.0,
196
+ }
197
+
198
+ # Use RPN ROIs or externally generated ROIs for training
199
+ # Keep this True for most situations. Set to False if you want to train
200
+ # the head branches on ROI generated by code rather than the ROIs from
201
+ # the RPN. For example, to debug the classifier head without having to
202
+ # train the RPN.
203
+ USE_RPN_ROIS = True
204
+
205
+ # Train or freeze batch normalization layers
206
+ # None: Train BN layers. This is the normal mode
207
+ # False: Freeze BN layers. Good when using a small batch size
208
+ # True: (don't use). Set layer in training mode even when predicting
209
+ TRAIN_BN = False # Defaulting to False since batch size is often small
210
+
211
+ # Gradient norm clipping
212
+ GRADIENT_CLIP_NORM = 5.0
213
+
214
+ def __init__(self):
215
+ """Set values of computed attributes."""
216
+ # Effective batch size
217
+ self.BATCH_SIZE = self.IMAGES_PER_GPU * self.GPU_COUNT
218
+
219
+ # Input image size
220
+ if self.IMAGE_RESIZE_MODE == "crop":
221
+ self.IMAGE_SHAPE = np.array(
222
+ [self.IMAGE_MIN_DIM, self.IMAGE_MIN_DIM, self.IMAGE_CHANNEL_COUNT]
223
+ )
224
+ else:
225
+ self.IMAGE_SHAPE = np.array(
226
+ [self.IMAGE_MAX_DIM, self.IMAGE_MAX_DIM, self.IMAGE_CHANNEL_COUNT]
227
+ )
228
+
229
+ # Image meta data length
230
+ # See compose_image_meta() for details
231
+ self.IMAGE_META_SIZE = 1 + 3 + 3 + 4 + 1 + self.NUM_CLASSES
232
+
233
+ def display(self):
234
+ """Display Configuration values."""
235
+ print("\nConfigurations:")
236
+ for a in dir(self):
237
+ if not a.startswith("__") and not callable(getattr(self, a)):
238
+ print("{:30} {}".format(a, getattr(self, a)))
239
+ print("\n")
mrcnn/model.py ADDED
The diff for this file is too large to render. See raw diff
 
mrcnn/parallel_model.py ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Mask R-CNN
3
+ Multi-GPU Support for Keras.
4
+
5
+ Copyright (c) 2017 Matterport, Inc.
6
+ Licensed under the MIT License (see LICENSE for details)
7
+ Written by Waleed Abdulla
8
+
9
+ Ideas and a small code snippets from these sources:
10
+ https://github.com/fchollet/keras/issues/2436
11
+ https://medium.com/@kuza55/transparent-multi-gpu-training-on-tensorflow-with-keras-8b0016fd9012
12
+ https://github.com/avolkov1/keras_experiments/blob/master/keras_exp/multigpu/
13
+ https://github.com/fchollet/keras/blob/master/keras/utils/training_utils.py
14
+ """
15
+
16
+ import keras.backend as K
17
+ import keras.layers as KL
18
+ import keras.models as KM
19
+ import tensorflow as tf
20
+
21
+
22
+ class ParallelModel(KM.Model):
23
+ """Subclasses the standard Keras Model and adds multi-GPU support.
24
+ It works by creating a copy of the model on each GPU. Then it slices
25
+ the inputs and sends a slice to each copy of the model, and then
26
+ merges the outputs together and applies the loss on the combined
27
+ outputs.
28
+ """
29
+
30
+ def __init__(self, keras_model, gpu_count):
31
+ """Class constructor.
32
+ keras_model: The Keras model to parallelize
33
+ gpu_count: Number of GPUs. Must be > 1
34
+ """
35
+ self.inner_model = keras_model
36
+ self.gpu_count = gpu_count
37
+ merged_outputs = self.make_parallel()
38
+ super(ParallelModel, self).__init__(
39
+ inputs=self.inner_model.inputs, outputs=merged_outputs
40
+ )
41
+
42
+ def __getattribute__(self, attrname):
43
+ """Redirect loading and saving methods to the inner model. That's where
44
+ the weights are stored."""
45
+ if "load" in attrname or "save" in attrname:
46
+ return getattr(self.inner_model, attrname)
47
+ return super(ParallelModel, self).__getattribute__(attrname)
48
+
49
+ def summary(self, *args, **kwargs):
50
+ """Override summary() to display summaries of both, the wrapper
51
+ and inner models."""
52
+ super(ParallelModel, self).summary(*args, **kwargs)
53
+ self.inner_model.summary(*args, **kwargs)
54
+
55
+ def make_parallel(self):
56
+ """Creates a new wrapper model that consists of multiple replicas of
57
+ the original model placed on different GPUs.
58
+ """
59
+ # Slice inputs. Slice inputs on the CPU to avoid sending a copy
60
+ # of the full inputs to all GPUs. Saves on bandwidth and memory.
61
+ input_slices = {
62
+ name: tf.split(x, self.gpu_count)
63
+ for name, x in zip(self.inner_model.input_names, self.inner_model.inputs)
64
+ }
65
+
66
+ output_names = self.inner_model.output_names
67
+ outputs_all = []
68
+ for i in range(len(self.inner_model.outputs)):
69
+ outputs_all.append([])
70
+
71
+ # Run the model call() on each GPU to place the ops there
72
+ for i in range(self.gpu_count):
73
+ with tf.device("/gpu:%d" % i):
74
+ with tf.name_scope("tower_%d" % i):
75
+ # Run a slice of inputs through this replica
76
+ zipped_inputs = zip(
77
+ self.inner_model.input_names, self.inner_model.inputs
78
+ )
79
+ inputs = [
80
+ KL.Lambda(
81
+ lambda s: input_slices[name][i],
82
+ output_shape=lambda s: (None,) + s[1:],
83
+ )(tensor)
84
+ for name, tensor in zipped_inputs
85
+ ]
86
+ # Create the model replica and get the outputs
87
+ outputs = self.inner_model(inputs)
88
+ if not isinstance(outputs, list):
89
+ outputs = [outputs]
90
+ # Save the outputs for merging back together later
91
+ for l, o in enumerate(outputs):
92
+ outputs_all[l].append(o)
93
+
94
+ # Merge outputs on CPU
95
+ with tf.device("/cpu:0"):
96
+ merged = []
97
+ for outputs, name in zip(outputs_all, output_names):
98
+ # Concatenate or average outputs?
99
+ # Outputs usually have a batch dimension and we concatenate
100
+ # across it. If they don't, then the output is likely a loss
101
+ # or a metric value that gets averaged across the batch.
102
+ # Keras expects losses and metrics to be scalars.
103
+ if K.int_shape(outputs[0]) == ():
104
+ # Average
105
+ m = KL.Lambda(lambda o: tf.add_n(o) / len(outputs), name=name)(
106
+ outputs
107
+ )
108
+ else:
109
+ # Concatenate
110
+ m = KL.Concatenate(axis=0, name=name)(outputs)
111
+ merged.append(m)
112
+ return merged
113
+
114
+
115
+ if __name__ == "__main__":
116
+ # Testing code below. It creates a simple model to train on MNIST and
117
+ # tries to run it on 2 GPUs. It saves the graph so it can be viewed
118
+ # in TensorBoard. Run it as:
119
+ #
120
+ # python3 parallel_model.py
121
+
122
+ import os
123
+
124
+ import keras.optimizers
125
+ import numpy as np
126
+ from keras.datasets import mnist
127
+ from keras.preprocessing.image import ImageDataGenerator
128
+
129
+ GPU_COUNT = 2
130
+
131
+ # Root directory of the project
132
+ ROOT_DIR = os.path.abspath("../")
133
+
134
+ # Directory to save logs and trained model
135
+ MODEL_DIR = os.path.join(ROOT_DIR, "logs")
136
+
137
+ def build_model(x_train, num_classes):
138
+ # Reset default graph. Keras leaves old ops in the graph,
139
+ # which are ignored for execution but clutter graph
140
+ # visualization in TensorBoard.
141
+ tf.reset_default_graph()
142
+
143
+ inputs = KL.Input(shape=x_train.shape[1:], name="input_image")
144
+ x = KL.Conv2D(32, (3, 3), activation="relu", padding="same", name="conv1")(
145
+ inputs
146
+ )
147
+ x = KL.Conv2D(64, (3, 3), activation="relu", padding="same", name="conv2")(x)
148
+ x = KL.MaxPooling2D(pool_size=(2, 2), name="pool1")(x)
149
+ x = KL.Flatten(name="flat1")(x)
150
+ x = KL.Dense(128, activation="relu", name="dense1")(x)
151
+ x = KL.Dense(num_classes, activation="softmax", name="dense2")(x)
152
+
153
+ return KM.Model(inputs, x, "digit_classifier_model")
154
+
155
+ # Load MNIST Data
156
+ (x_train, y_train), (x_test, y_test) = mnist.load_data()
157
+ x_train = np.expand_dims(x_train, -1).astype("float32") / 255
158
+ x_test = np.expand_dims(x_test, -1).astype("float32") / 255
159
+
160
+ print("x_train shape:", x_train.shape)
161
+ print("x_test shape:", x_test.shape)
162
+
163
+ # Build data generator and model
164
+ datagen = ImageDataGenerator()
165
+ model = build_model(x_train, 10)
166
+
167
+ # Add multi-GPU support.
168
+ model = ParallelModel(model, GPU_COUNT)
169
+
170
+ optimizer = keras.optimizers.SGD(lr=0.01, momentum=0.9, clipnorm=5.0)
171
+
172
+ model.compile(
173
+ loss="sparse_categorical_crossentropy",
174
+ optimizer=optimizer,
175
+ metrics=["accuracy"],
176
+ )
177
+
178
+ model.summary()
179
+
180
+ # Train
181
+ model.fit_generator(
182
+ datagen.flow(x_train, y_train, batch_size=64),
183
+ steps_per_epoch=50,
184
+ epochs=10,
185
+ verbose=1,
186
+ validation_data=(x_test, y_test),
187
+ callbacks=[keras.callbacks.TensorBoard(log_dir=MODEL_DIR, write_graph=True)],
188
+ )
mrcnn/utils.py ADDED
@@ -0,0 +1,984 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Mask R-CNN
3
+ Common utility functions and classes.
4
+
5
+ Copyright (c) 2017 Matterport, Inc.
6
+ Licensed under the MIT License (see LICENSE for details)
7
+ Written by Waleed Abdulla
8
+ """
9
+
10
+ import logging
11
+ import math
12
+ import os
13
+ import random
14
+ import shutil
15
+ import sys
16
+ import urllib.request
17
+ import warnings
18
+ from distutils.version import LooseVersion
19
+
20
+ import numpy as np
21
+ import scipy
22
+ import skimage.color
23
+ import skimage.io
24
+ import skimage.transform
25
+ import tensorflow as tf
26
+
27
+ # URL from which to download the latest COCO trained weights
28
+ COCO_MODEL_URL = (
29
+ "https://github.com/matterport/Mask_RCNN/releases/download/v2.0/mask_rcnn_coco.h5"
30
+ )
31
+
32
+
33
+ ############################################################
34
+ # Bounding Boxes
35
+ ############################################################
36
+
37
+
38
+ def extract_bboxes(mask):
39
+ """Compute bounding boxes from masks.
40
+ mask: [height, width, num_instances]. Mask pixels are either 1 or 0.
41
+
42
+ Returns: bbox array [num_instances, (y1, x1, y2, x2)].
43
+ """
44
+ boxes = np.zeros([mask.shape[-1], 4], dtype=np.int32)
45
+ for i in range(mask.shape[-1]):
46
+ m = mask[:, :, i]
47
+ # Bounding box.
48
+ horizontal_indicies = np.where(np.any(m, axis=0))[0]
49
+ vertical_indicies = np.where(np.any(m, axis=1))[0]
50
+ if horizontal_indicies.shape[0]:
51
+ x1, x2 = horizontal_indicies[[0, -1]]
52
+ y1, y2 = vertical_indicies[[0, -1]]
53
+ # x2 and y2 should not be part of the box. Increment by 1.
54
+ x2 += 1
55
+ y2 += 1
56
+ else:
57
+ # No mask for this instance. Might happen due to
58
+ # resizing or cropping. Set bbox to zeros
59
+ x1, x2, y1, y2 = 0, 0, 0, 0
60
+ boxes[i] = np.array([y1, x1, y2, x2])
61
+ return boxes.astype(np.int32)
62
+
63
+
64
+ def compute_iou(box, boxes, box_area, boxes_area):
65
+ """Calculates IoU of the given box with the array of the given boxes.
66
+ box: 1D vector [y1, x1, y2, x2]
67
+ boxes: [boxes_count, (y1, x1, y2, x2)]
68
+ box_area: float. the area of 'box'
69
+ boxes_area: array of length boxes_count.
70
+
71
+ Note: the areas are passed in rather than calculated here for
72
+ efficiency. Calculate once in the caller to avoid duplicate work.
73
+ """
74
+ # Calculate intersection areas
75
+ y1 = np.maximum(box[0], boxes[:, 0])
76
+ y2 = np.minimum(box[2], boxes[:, 2])
77
+ x1 = np.maximum(box[1], boxes[:, 1])
78
+ x2 = np.minimum(box[3], boxes[:, 3])
79
+ intersection = np.maximum(x2 - x1, 0) * np.maximum(y2 - y1, 0)
80
+ union = box_area + boxes_area[:] - intersection[:]
81
+ iou = intersection / union
82
+ return iou
83
+
84
+
85
+ def compute_overlaps(boxes1, boxes2):
86
+ """Computes IoU overlaps between two sets of boxes.
87
+ boxes1, boxes2: [N, (y1, x1, y2, x2)].
88
+
89
+ For better performance, pass the largest set first and the smaller second.
90
+ """
91
+ # Areas of anchors and GT boxes
92
+ area1 = (boxes1[:, 2] - boxes1[:, 0]) * (boxes1[:, 3] - boxes1[:, 1])
93
+ area2 = (boxes2[:, 2] - boxes2[:, 0]) * (boxes2[:, 3] - boxes2[:, 1])
94
+
95
+ # Compute overlaps to generate matrix [boxes1 count, boxes2 count]
96
+ # Each cell contains the IoU value.
97
+ overlaps = np.zeros((boxes1.shape[0], boxes2.shape[0]))
98
+ for i in range(overlaps.shape[1]):
99
+ box2 = boxes2[i]
100
+ overlaps[:, i] = compute_iou(box2, boxes1, area2[i], area1)
101
+ return overlaps
102
+
103
+
104
+ def compute_overlaps_masks(masks1, masks2):
105
+ """Computes IoU overlaps between two sets of masks.
106
+ masks1, masks2: [Height, Width, instances]
107
+ """
108
+
109
+ # If either set of masks is empty return empty result
110
+ if masks1.shape[-1] == 0 or masks2.shape[-1] == 0:
111
+ return np.zeros((masks1.shape[-1], masks2.shape[-1]))
112
+ # flatten masks and compute their areas
113
+ masks1 = np.reshape(masks1 > 0.5, (-1, masks1.shape[-1])).astype(np.float32)
114
+ masks2 = np.reshape(masks2 > 0.5, (-1, masks2.shape[-1])).astype(np.float32)
115
+ area1 = np.sum(masks1, axis=0)
116
+ area2 = np.sum(masks2, axis=0)
117
+
118
+ # intersections and union
119
+ intersections = np.dot(masks1.T, masks2)
120
+ union = area1[:, None] + area2[None, :] - intersections
121
+ overlaps = intersections / union
122
+
123
+ return overlaps
124
+
125
+
126
+ def non_max_suppression(boxes, scores, threshold):
127
+ """Performs non-maximum suppression and returns indices of kept boxes.
128
+ boxes: [N, (y1, x1, y2, x2)]. Notice that (y2, x2) lays outside the box.
129
+ scores: 1-D array of box scores.
130
+ threshold: Float. IoU threshold to use for filtering.
131
+ """
132
+ assert boxes.shape[0] > 0
133
+ if boxes.dtype.kind != "f":
134
+ boxes = boxes.astype(np.float32)
135
+
136
+ # Compute box areas
137
+ y1 = boxes[:, 0]
138
+ x1 = boxes[:, 1]
139
+ y2 = boxes[:, 2]
140
+ x2 = boxes[:, 3]
141
+ area = (y2 - y1) * (x2 - x1)
142
+
143
+ # Get indicies of boxes sorted by scores (highest first)
144
+ ixs = scores.argsort()[::-1]
145
+
146
+ pick = []
147
+ while len(ixs) > 0:
148
+ # Pick top box and add its index to the list
149
+ i = ixs[0]
150
+ pick.append(i)
151
+ # Compute IoU of the picked box with the rest
152
+ iou = compute_iou(boxes[i], boxes[ixs[1:]], area[i], area[ixs[1:]])
153
+ # Identify boxes with IoU over the threshold. This
154
+ # returns indices into ixs[1:], so add 1 to get
155
+ # indices into ixs.
156
+ remove_ixs = np.where(iou > threshold)[0] + 1
157
+ # Remove indices of the picked and overlapped boxes.
158
+ ixs = np.delete(ixs, remove_ixs)
159
+ ixs = np.delete(ixs, 0)
160
+ return np.array(pick, dtype=np.int32)
161
+
162
+
163
+ def apply_box_deltas(boxes, deltas):
164
+ """Applies the given deltas to the given boxes.
165
+ boxes: [N, (y1, x1, y2, x2)]. Note that (y2, x2) is outside the box.
166
+ deltas: [N, (dy, dx, log(dh), log(dw))]
167
+ """
168
+ boxes = boxes.astype(np.float32)
169
+ # Convert to y, x, h, w
170
+ height = boxes[:, 2] - boxes[:, 0]
171
+ width = boxes[:, 3] - boxes[:, 1]
172
+ center_y = boxes[:, 0] + 0.5 * height
173
+ center_x = boxes[:, 1] + 0.5 * width
174
+ # Apply deltas
175
+ center_y += deltas[:, 0] * height
176
+ center_x += deltas[:, 1] * width
177
+ height *= np.exp(deltas[:, 2])
178
+ width *= np.exp(deltas[:, 3])
179
+ # Convert back to y1, x1, y2, x2
180
+ y1 = center_y - 0.5 * height
181
+ x1 = center_x - 0.5 * width
182
+ y2 = y1 + height
183
+ x2 = x1 + width
184
+ return np.stack([y1, x1, y2, x2], axis=1)
185
+
186
+
187
+ def box_refinement_graph(box, gt_box):
188
+ """Compute refinement needed to transform box to gt_box.
189
+ box and gt_box are [N, (y1, x1, y2, x2)]
190
+ """
191
+ box = tf.cast(box, tf.float32)
192
+ gt_box = tf.cast(gt_box, tf.float32)
193
+
194
+ height = box[:, 2] - box[:, 0]
195
+ width = box[:, 3] - box[:, 1]
196
+ center_y = box[:, 0] + 0.5 * height
197
+ center_x = box[:, 1] + 0.5 * width
198
+
199
+ gt_height = gt_box[:, 2] - gt_box[:, 0]
200
+ gt_width = gt_box[:, 3] - gt_box[:, 1]
201
+ gt_center_y = gt_box[:, 0] + 0.5 * gt_height
202
+ gt_center_x = gt_box[:, 1] + 0.5 * gt_width
203
+
204
+ dy = (gt_center_y - center_y) / height
205
+ dx = (gt_center_x - center_x) / width
206
+ dh = tf.log(gt_height / height)
207
+ dw = tf.log(gt_width / width)
208
+
209
+ result = tf.stack([dy, dx, dh, dw], axis=1)
210
+ return result
211
+
212
+
213
+ def box_refinement(box, gt_box):
214
+ """Compute refinement needed to transform box to gt_box.
215
+ box and gt_box are [N, (y1, x1, y2, x2)]. (y2, x2) is
216
+ assumed to be outside the box.
217
+ """
218
+ box = box.astype(np.float32)
219
+ gt_box = gt_box.astype(np.float32)
220
+
221
+ height = box[:, 2] - box[:, 0]
222
+ width = box[:, 3] - box[:, 1]
223
+ center_y = box[:, 0] + 0.5 * height
224
+ center_x = box[:, 1] + 0.5 * width
225
+
226
+ gt_height = gt_box[:, 2] - gt_box[:, 0]
227
+ gt_width = gt_box[:, 3] - gt_box[:, 1]
228
+ gt_center_y = gt_box[:, 0] + 0.5 * gt_height
229
+ gt_center_x = gt_box[:, 1] + 0.5 * gt_width
230
+
231
+ dy = (gt_center_y - center_y) / height
232
+ dx = (gt_center_x - center_x) / width
233
+ dh = np.log(gt_height / height)
234
+ dw = np.log(gt_width / width)
235
+
236
+ return np.stack([dy, dx, dh, dw], axis=1)
237
+
238
+
239
+ ############################################################
240
+ # Dataset
241
+ ############################################################
242
+
243
+
244
+ class Dataset(object):
245
+ """The base class for dataset classes.
246
+ To use it, create a new class that adds functions specific to the dataset
247
+ you want to use. For example:
248
+
249
+ class CatsAndDogsDataset(Dataset):
250
+ def load_cats_and_dogs(self):
251
+ ...
252
+ def load_mask(self, image_id):
253
+ ...
254
+ def image_reference(self, image_id):
255
+ ...
256
+
257
+ See COCODataset and ShapesDataset as examples.
258
+ """
259
+
260
+ def __init__(self, class_map=None):
261
+ self._image_ids = []
262
+ self.image_info = []
263
+ # Background is always the first class
264
+ self.class_info = [{"source": "", "id": 0, "name": "BG"}]
265
+ self.source_class_ids = {}
266
+
267
+ def add_class(self, source, class_id, class_name):
268
+ assert "." not in source, "Source name cannot contain a dot"
269
+ # Does the class exist already?
270
+ for info in self.class_info:
271
+ if info["source"] == source and info["id"] == class_id:
272
+ # source.class_id combination already available, skip
273
+ return
274
+ # Add the class
275
+ self.class_info.append(
276
+ {
277
+ "source": source,
278
+ "id": class_id,
279
+ "name": class_name,
280
+ }
281
+ )
282
+
283
+ def add_image(self, source, image_id, path, **kwargs):
284
+ image_info = {
285
+ "id": image_id,
286
+ "source": source,
287
+ "path": path,
288
+ }
289
+ image_info.update(kwargs)
290
+ self.image_info.append(image_info)
291
+
292
+ def image_reference(self, image_id):
293
+ """Return a link to the image in its source Website or details about
294
+ the image that help looking it up or debugging it.
295
+
296
+ Override for your dataset, but pass to this function
297
+ if you encounter images not in your dataset.
298
+ """
299
+ return ""
300
+
301
+ def prepare(self, class_map=None):
302
+ """Prepares the Dataset class for use.
303
+
304
+ TODO: class map is not supported yet. When done, it should handle mapping
305
+ classes from different datasets to the same class ID.
306
+ """
307
+
308
+ def clean_name(name):
309
+ """Returns a shorter version of object names for cleaner display."""
310
+ return ",".join(name.split(",")[:1])
311
+
312
+ # Build (or rebuild) everything else from the info dicts.
313
+ self.num_classes = len(self.class_info)
314
+ self.class_ids = np.arange(self.num_classes)
315
+ self.class_names = [clean_name(c["name"]) for c in self.class_info]
316
+ self.num_images = len(self.image_info)
317
+ self._image_ids = np.arange(self.num_images)
318
+
319
+ # Mapping from source class and image IDs to internal IDs
320
+ self.class_from_source_map = {
321
+ "{}.{}".format(info["source"], info["id"]): id
322
+ for info, id in zip(self.class_info, self.class_ids)
323
+ }
324
+ self.image_from_source_map = {
325
+ "{}.{}".format(info["source"], info["id"]): id
326
+ for info, id in zip(self.image_info, self.image_ids)
327
+ }
328
+
329
+ # Map sources to class_ids they support
330
+ self.sources = list(set([i["source"] for i in self.class_info]))
331
+ self.source_class_ids = {}
332
+ # Loop over datasets
333
+ for source in self.sources:
334
+ self.source_class_ids[source] = []
335
+ # Find classes that belong to this dataset
336
+ for i, info in enumerate(self.class_info):
337
+ # Include BG class in all datasets
338
+ if i == 0 or source == info["source"]:
339
+ self.source_class_ids[source].append(i)
340
+
341
+ def map_source_class_id(self, source_class_id):
342
+ """Takes a source class ID and returns the int class ID assigned to it.
343
+
344
+ For example:
345
+ dataset.map_source_class_id("coco.12") -> 23
346
+ """
347
+ return self.class_from_source_map[source_class_id]
348
+
349
+ def get_source_class_id(self, class_id, source):
350
+ """Map an internal class ID to the corresponding class ID in the source dataset."""
351
+ info = self.class_info[class_id]
352
+ assert info["source"] == source
353
+ return info["id"]
354
+
355
+ @property
356
+ def image_ids(self):
357
+ return self._image_ids
358
+
359
+ def source_image_link(self, image_id):
360
+ """Returns the path or URL to the image.
361
+ Override this to return a URL to the image if it's available online for easy
362
+ debugging.
363
+ """
364
+ return self.image_info[image_id]["path"]
365
+
366
+ def load_image(self, image_id):
367
+ """Load the specified image and return a [H,W,3] Numpy array."""
368
+ # Load image
369
+ image = skimage.io.imread(self.image_info[image_id]["path"])
370
+ # If grayscale. Convert to RGB for consistency.
371
+ if image.ndim != 3:
372
+ image = skimage.color.gray2rgb(image)
373
+ # If has an alpha channel, remove it for consistency
374
+ if image.shape[-1] == 4:
375
+ image = image[..., :3]
376
+ return image
377
+
378
+ def load_mask(self, image_id):
379
+ """Load instance masks for the given image.
380
+
381
+ Different datasets use different ways to store masks. Override this
382
+ method to load instance masks and return them in the form of am
383
+ array of binary masks of shape [height, width, instances].
384
+
385
+ Returns:
386
+ masks: A bool array of shape [height, width, instance count] with
387
+ a binary mask per instance.
388
+ class_ids: a 1D array of class IDs of the instance masks.
389
+ """
390
+ # Override this function to load a mask from your dataset.
391
+ # Otherwise, it returns an empty mask.
392
+ logging.warning(
393
+ "You are using the default load_mask(), maybe you need to define your own one."
394
+ )
395
+ mask = np.empty([0, 0, 0])
396
+ class_ids = np.empty([0], np.int32)
397
+ return mask, class_ids
398
+
399
+
400
+ def resize_image(image, min_dim=None, max_dim=None, min_scale=None, mode="square"):
401
+ """Resizes an image keeping the aspect ratio unchanged.
402
+
403
+ min_dim: if provided, resizes the image such that it's smaller
404
+ dimension == min_dim
405
+ max_dim: if provided, ensures that the image longest side doesn't
406
+ exceed this value.
407
+ min_scale: if provided, ensure that the image is scaled up by at least
408
+ this percent even if min_dim doesn't require it.
409
+ mode: Resizing mode.
410
+ none: No resizing. Return the image unchanged.
411
+ square: Resize and pad with zeros to get a square image
412
+ of size [max_dim, max_dim].
413
+ pad64: Pads width and height with zeros to make them multiples of 64.
414
+ If min_dim or min_scale are provided, it scales the image up
415
+ before padding. max_dim is ignored in this mode.
416
+ The multiple of 64 is needed to ensure smooth scaling of feature
417
+ maps up and down the 6 levels of the FPN pyramid (2**6=64).
418
+ crop: Picks random crops from the image. First, scales the image based
419
+ on min_dim and min_scale, then picks a random crop of
420
+ size min_dim x min_dim. Can be used in training only.
421
+ max_dim is not used in this mode.
422
+
423
+ Returns:
424
+ image: the resized image
425
+ window: (y1, x1, y2, x2). If max_dim is provided, padding might
426
+ be inserted in the returned image. If so, this window is the
427
+ coordinates of the image part of the full image (excluding
428
+ the padding). The x2, y2 pixels are not included.
429
+ scale: The scale factor used to resize the image
430
+ padding: Padding added to the image [(top, bottom), (left, right), (0, 0)]
431
+ """
432
+ # Keep track of image dtype and return results in the same dtype
433
+ image_dtype = image.dtype
434
+ # Default window (y1, x1, y2, x2) and default scale == 1.
435
+ h, w = image.shape[:2]
436
+ window = (0, 0, h, w)
437
+ scale = 1
438
+ padding = [(0, 0), (0, 0), (0, 0)]
439
+ crop = None
440
+
441
+ if mode == "none":
442
+ return image, window, scale, padding, crop
443
+
444
+ # Scale?
445
+ if min_dim:
446
+ # Scale up but not down
447
+ print(min_dim, min(h, w), type(min_dim), type(min(h, w)))
448
+ scale = max(1, min_dim / min(h, w))
449
+ if min_scale and scale < min_scale:
450
+ scale = min_scale
451
+
452
+ # Does it exceed max dim?
453
+ if max_dim and mode == "square":
454
+ image_max = max(h, w)
455
+ if round(image_max * scale) > max_dim:
456
+ scale = max_dim / image_max
457
+
458
+ # Resize image using bilinear interpolation
459
+ if scale != 1:
460
+ image = resize(image, (round(h * scale), round(w * scale)), preserve_range=True)
461
+
462
+ # Need padding or cropping?
463
+ if mode == "square":
464
+ # Get new height and width
465
+ h, w = image.shape[:2]
466
+ top_pad = (max_dim - h) // 2
467
+ bottom_pad = max_dim - h - top_pad
468
+ left_pad = (max_dim - w) // 2
469
+ right_pad = max_dim - w - left_pad
470
+ padding = [(top_pad, bottom_pad), (left_pad, right_pad), (0, 0)]
471
+ image = np.pad(image, padding, mode="constant", constant_values=0)
472
+ window = (top_pad, left_pad, h + top_pad, w + left_pad)
473
+ elif mode == "pad64":
474
+ h, w = image.shape[:2]
475
+ # Both sides must be divisible by 64
476
+ assert min_dim % 64 == 0, "Minimum dimension must be a multiple of 64"
477
+ # Height
478
+ if h % 64 > 0:
479
+ max_h = h - (h % 64) + 64
480
+ top_pad = (max_h - h) // 2
481
+ bottom_pad = max_h - h - top_pad
482
+ else:
483
+ top_pad = bottom_pad = 0
484
+ # Width
485
+ if w % 64 > 0:
486
+ max_w = w - (w % 64) + 64
487
+ left_pad = (max_w - w) // 2
488
+ right_pad = max_w - w - left_pad
489
+ else:
490
+ left_pad = right_pad = 0
491
+ padding = [(top_pad, bottom_pad), (left_pad, right_pad), (0, 0)]
492
+ image = np.pad(image, padding, mode="constant", constant_values=0)
493
+ window = (top_pad, left_pad, h + top_pad, w + left_pad)
494
+ elif mode == "crop":
495
+ # Pick a random crop
496
+ h, w = image.shape[:2]
497
+ y = random.randint(0, (h - min_dim))
498
+ x = random.randint(0, (w - min_dim))
499
+ crop = (y, x, min_dim, min_dim)
500
+ image = image[y : y + min_dim, x : x + min_dim]
501
+ window = (0, 0, min_dim, min_dim)
502
+ else:
503
+ raise Exception("Mode {} not supported".format(mode))
504
+ return image.astype(image_dtype), window, scale, padding, crop
505
+
506
+
507
+ def resize_mask(mask, scale, padding, crop=None):
508
+ """Resizes a mask using the given scale and padding.
509
+ Typically, you get the scale and padding from resize_image() to
510
+ ensure both, the image and the mask, are resized consistently.
511
+
512
+ scale: mask scaling factor
513
+ padding: Padding to add to the mask in the form
514
+ [(top, bottom), (left, right), (0, 0)]
515
+ """
516
+ # Suppress warning from scipy 0.13.0, the output shape of zoom() is
517
+ # calculated with round() instead of int()
518
+ with warnings.catch_warnings():
519
+ warnings.simplefilter("ignore")
520
+ mask = scipy.ndimage.zoom(mask, zoom=[scale, scale, 1], order=0)
521
+ if crop is not None:
522
+ y, x, h, w = crop
523
+ mask = mask[y : y + h, x : x + w]
524
+ else:
525
+ mask = np.pad(mask, padding, mode="constant", constant_values=0)
526
+ return mask
527
+
528
+
529
+ def minimize_mask(bbox, mask, mini_shape):
530
+ """Resize masks to a smaller version to reduce memory load.
531
+ Mini-masks can be resized back to image scale using expand_masks()
532
+
533
+ See inspect_data.ipynb notebook for more details.
534
+ """
535
+ mini_mask = np.zeros(mini_shape + (mask.shape[-1],), dtype=bool)
536
+ for i in range(mask.shape[-1]):
537
+ # Pick slice and cast to bool in case load_mask() returned wrong dtype
538
+ m = mask[:, :, i].astype(bool)
539
+ y1, x1, y2, x2 = bbox[i][:4]
540
+ m = m[y1:y2, x1:x2]
541
+ if m.size == 0:
542
+ raise Exception("Invalid bounding box with area of zero")
543
+ # Resize with bilinear interpolation
544
+ m = resize(m, mini_shape)
545
+ mini_mask[:, :, i] = np.around(m).astype(np.bool)
546
+ return mini_mask
547
+
548
+
549
+ def expand_mask(bbox, mini_mask, image_shape):
550
+ """Resizes mini masks back to image size. Reverses the change
551
+ of minimize_mask().
552
+
553
+ See inspect_data.ipynb notebook for more details.
554
+ """
555
+ mask = np.zeros(image_shape[:2] + (mini_mask.shape[-1],), dtype=bool)
556
+ for i in range(mask.shape[-1]):
557
+ m = mini_mask[:, :, i]
558
+ y1, x1, y2, x2 = bbox[i][:4]
559
+ h = y2 - y1
560
+ w = x2 - x1
561
+ # Resize with bilinear interpolation
562
+ m = resize(m, (h, w))
563
+ mask[y1:y2, x1:x2, i] = np.around(m).astype(np.bool)
564
+ return mask
565
+
566
+
567
+ # TODO: Build and use this function to reduce code duplication
568
+ def mold_mask(mask, config):
569
+ pass
570
+
571
+
572
+ def unmold_mask(mask, bbox, image_shape):
573
+ """Converts a mask generated by the neural network to a format similar
574
+ to its original shape.
575
+ mask: [height, width] of type float. A small, typically 28x28 mask.
576
+ bbox: [y1, x1, y2, x2]. The box to fit the mask in.
577
+
578
+ Returns a binary mask with the same size as the original image.
579
+ """
580
+ threshold = 0.5
581
+ y1, x1, y2, x2 = bbox
582
+ mask = resize(mask, (y2 - y1, x2 - x1))
583
+ mask = np.where(mask >= threshold, 1, 0).astype(np.bool)
584
+
585
+ # Put the mask in the right location.
586
+ full_mask = np.zeros(image_shape[:2], dtype=np.bool)
587
+ full_mask[y1:y2, x1:x2] = mask
588
+ return full_mask
589
+
590
+
591
+ ############################################################
592
+ # Anchors
593
+ ############################################################
594
+
595
+
596
+ def generate_anchors(scales, ratios, shape, feature_stride, anchor_stride):
597
+ """
598
+ scales: 1D array of anchor sizes in pixels. Example: [32, 64, 128]
599
+ ratios: 1D array of anchor ratios of width/height. Example: [0.5, 1, 2]
600
+ shape: [height, width] spatial shape of the feature map over which
601
+ to generate anchors.
602
+ feature_stride: Stride of the feature map relative to the image in pixels.
603
+ anchor_stride: Stride of anchors on the feature map. For example, if the
604
+ value is 2 then generate anchors for every other feature map pixel.
605
+ """
606
+ # Get all combinations of scales and ratios
607
+ scales, ratios = np.meshgrid(np.array(scales), np.array(ratios))
608
+ scales = scales.flatten()
609
+ ratios = ratios.flatten()
610
+
611
+ # Enumerate heights and widths from scales and ratios
612
+ heights = scales / np.sqrt(ratios)
613
+ widths = scales * np.sqrt(ratios)
614
+
615
+ # Enumerate shifts in feature space
616
+ shifts_y = np.arange(0, shape[0], anchor_stride) * feature_stride
617
+ shifts_x = np.arange(0, shape[1], anchor_stride) * feature_stride
618
+ shifts_x, shifts_y = np.meshgrid(shifts_x, shifts_y)
619
+
620
+ # Enumerate combinations of shifts, widths, and heights
621
+ box_widths, box_centers_x = np.meshgrid(widths, shifts_x)
622
+ box_heights, box_centers_y = np.meshgrid(heights, shifts_y)
623
+
624
+ # Reshape to get a list of (y, x) and a list of (h, w)
625
+ box_centers = np.stack([box_centers_y, box_centers_x], axis=2).reshape([-1, 2])
626
+ box_sizes = np.stack([box_heights, box_widths], axis=2).reshape([-1, 2])
627
+
628
+ # Convert to corner coordinates (y1, x1, y2, x2)
629
+ boxes = np.concatenate(
630
+ [box_centers - 0.5 * box_sizes, box_centers + 0.5 * box_sizes], axis=1
631
+ )
632
+ return boxes
633
+
634
+
635
+ def generate_pyramid_anchors(
636
+ scales, ratios, feature_shapes, feature_strides, anchor_stride
637
+ ):
638
+ """Generate anchors at different levels of a feature pyramid. Each scale
639
+ is associated with a level of the pyramid, but each ratio is used in
640
+ all levels of the pyramid.
641
+
642
+ Returns:
643
+ anchors: [N, (y1, x1, y2, x2)]. All generated anchors in one array. Sorted
644
+ with the same order of the given scales. So, anchors of scale[0] come
645
+ first, then anchors of scale[1], and so on.
646
+ """
647
+ # Anchors
648
+ # [anchor_count, (y1, x1, y2, x2)]
649
+ anchors = []
650
+ for i in range(len(scales)):
651
+ anchors.append(
652
+ generate_anchors(
653
+ scales[i], ratios, feature_shapes[i], feature_strides[i], anchor_stride
654
+ )
655
+ )
656
+ return np.concatenate(anchors, axis=0)
657
+
658
+
659
+ ############################################################
660
+ # Miscellaneous
661
+ ############################################################
662
+
663
+
664
+ def trim_zeros(x):
665
+ """It's common to have tensors larger than the available data and
666
+ pad with zeros. This function removes rows that are all zeros.
667
+
668
+ x: [rows, columns].
669
+ """
670
+ assert len(x.shape) == 2
671
+ return x[~np.all(x == 0, axis=1)]
672
+
673
+
674
+ def compute_matches(
675
+ gt_boxes,
676
+ gt_class_ids,
677
+ gt_masks,
678
+ pred_boxes,
679
+ pred_class_ids,
680
+ pred_scores,
681
+ pred_masks,
682
+ iou_threshold=0.5,
683
+ score_threshold=0.0,
684
+ ):
685
+ """Finds matches between prediction and ground truth instances.
686
+
687
+ Returns:
688
+ gt_match: 1-D array. For each GT box it has the index of the matched
689
+ predicted box.
690
+ pred_match: 1-D array. For each predicted box, it has the index of
691
+ the matched ground truth box.
692
+ overlaps: [pred_boxes, gt_boxes] IoU overlaps.
693
+ """
694
+ # Trim zero padding
695
+ # TODO: cleaner to do zero unpadding upstream
696
+ gt_boxes = trim_zeros(gt_boxes)
697
+ gt_masks = gt_masks[..., : gt_boxes.shape[0]]
698
+ pred_boxes = trim_zeros(pred_boxes)
699
+ pred_scores = pred_scores[: pred_boxes.shape[0]]
700
+ # Sort predictions by score from high to low
701
+ indices = np.argsort(pred_scores)[::-1]
702
+ pred_boxes = pred_boxes[indices]
703
+ pred_class_ids = pred_class_ids[indices]
704
+ pred_scores = pred_scores[indices]
705
+ pred_masks = pred_masks[..., indices]
706
+
707
+ # Compute IoU overlaps [pred_masks, gt_masks]
708
+ overlaps = compute_overlaps_masks(pred_masks, gt_masks)
709
+
710
+ # Loop through predictions and find matching ground truth boxes
711
+ match_count = 0
712
+ pred_match = -1 * np.ones([pred_boxes.shape[0]])
713
+ gt_match = -1 * np.ones([gt_boxes.shape[0]])
714
+ for i in range(len(pred_boxes)):
715
+ # Find best matching ground truth box
716
+ # 1. Sort matches by score
717
+ sorted_ixs = np.argsort(overlaps[i])[::-1]
718
+ # 2. Remove low scores
719
+ low_score_idx = np.where(overlaps[i, sorted_ixs] < score_threshold)[0]
720
+ if low_score_idx.size > 0:
721
+ sorted_ixs = sorted_ixs[: low_score_idx[0]]
722
+ # 3. Find the match
723
+ for j in sorted_ixs:
724
+ # If ground truth box is already matched, go to next one
725
+ if gt_match[j] > -1:
726
+ continue
727
+ # If we reach IoU smaller than the threshold, end the loop
728
+ iou = overlaps[i, j]
729
+ if iou < iou_threshold:
730
+ break
731
+ # Do we have a match?
732
+ if pred_class_ids[i] == gt_class_ids[j]:
733
+ match_count += 1
734
+ gt_match[j] = i
735
+ pred_match[i] = j
736
+ break
737
+
738
+ return gt_match, pred_match, overlaps
739
+
740
+
741
+ def compute_ap(
742
+ gt_boxes,
743
+ gt_class_ids,
744
+ gt_masks,
745
+ pred_boxes,
746
+ pred_class_ids,
747
+ pred_scores,
748
+ pred_masks,
749
+ iou_threshold=0.5,
750
+ ):
751
+ """Compute Average Precision at a set IoU threshold (default 0.5).
752
+
753
+ Returns:
754
+ mAP: Mean Average Precision
755
+ precisions: List of precisions at different class score thresholds.
756
+ recalls: List of recall values at different class score thresholds.
757
+ overlaps: [pred_boxes, gt_boxes] IoU overlaps.
758
+ """
759
+ # Get matches and overlaps
760
+ gt_match, pred_match, overlaps = compute_matches(
761
+ gt_boxes,
762
+ gt_class_ids,
763
+ gt_masks,
764
+ pred_boxes,
765
+ pred_class_ids,
766
+ pred_scores,
767
+ pred_masks,
768
+ iou_threshold,
769
+ )
770
+
771
+ # Compute precision and recall at each prediction box step
772
+ precisions = np.cumsum(pred_match > -1) / (np.arange(len(pred_match)) + 1)
773
+ recalls = np.cumsum(pred_match > -1).astype(np.float32) / len(gt_match)
774
+
775
+ # Pad with start and end values to simplify the math
776
+ precisions = np.concatenate([[0], precisions, [0]])
777
+ recalls = np.concatenate([[0], recalls, [1]])
778
+
779
+ # Ensure precision values decrease but don't increase. This way, the
780
+ # precision value at each recall threshold is the maximum it can be
781
+ # for all following recall thresholds, as specified by the VOC paper.
782
+ for i in range(len(precisions) - 2, -1, -1):
783
+ precisions[i] = np.maximum(precisions[i], precisions[i + 1])
784
+
785
+ # Compute mean AP over recall range
786
+ indices = np.where(recalls[:-1] != recalls[1:])[0] + 1
787
+ mAP = np.sum((recalls[indices] - recalls[indices - 1]) * precisions[indices])
788
+
789
+ return mAP, precisions, recalls, overlaps
790
+
791
+
792
+ def compute_ap_range(
793
+ gt_box,
794
+ gt_class_id,
795
+ gt_mask,
796
+ pred_box,
797
+ pred_class_id,
798
+ pred_score,
799
+ pred_mask,
800
+ iou_thresholds=None,
801
+ verbose=1,
802
+ ):
803
+ """Compute AP over a range or IoU thresholds. Default range is 0.5-0.95."""
804
+ # Default is 0.5 to 0.95 with increments of 0.05
805
+ iou_thresholds = iou_thresholds or np.arange(0.5, 1.0, 0.05)
806
+
807
+ # Compute AP over range of IoU thresholds
808
+ AP = []
809
+ for iou_threshold in iou_thresholds:
810
+ ap, precisions, recalls, overlaps = compute_ap(
811
+ gt_box,
812
+ gt_class_id,
813
+ gt_mask,
814
+ pred_box,
815
+ pred_class_id,
816
+ pred_score,
817
+ pred_mask,
818
+ iou_threshold=iou_threshold,
819
+ )
820
+ if verbose:
821
+ print("AP @{:.2f}:\t {:.3f}".format(iou_threshold, ap))
822
+ AP.append(ap)
823
+ AP = np.array(AP).mean()
824
+ if verbose:
825
+ print(
826
+ "AP @{:.2f}-{:.2f}:\t {:.3f}".format(
827
+ iou_thresholds[0], iou_thresholds[-1], AP
828
+ )
829
+ )
830
+ return AP
831
+
832
+
833
+ def compute_recall(pred_boxes, gt_boxes, iou):
834
+ """Compute the recall at the given IoU threshold. It's an indication
835
+ of how many GT boxes were found by the given prediction boxes.
836
+
837
+ pred_boxes: [N, (y1, x1, y2, x2)] in image coordinates
838
+ gt_boxes: [N, (y1, x1, y2, x2)] in image coordinates
839
+ """
840
+ # Measure overlaps
841
+ overlaps = compute_overlaps(pred_boxes, gt_boxes)
842
+ iou_max = np.max(overlaps, axis=1)
843
+ iou_argmax = np.argmax(overlaps, axis=1)
844
+ positive_ids = np.where(iou_max >= iou)[0]
845
+ matched_gt_boxes = iou_argmax[positive_ids]
846
+
847
+ recall = len(set(matched_gt_boxes)) / gt_boxes.shape[0]
848
+ return recall, positive_ids
849
+
850
+
851
+ # ## Batch Slicing
852
+ # Some custom layers support a batch size of 1 only, and require a lot of work
853
+ # to support batches greater than 1. This function slices an input tensor
854
+ # across the batch dimension and feeds batches of size 1. Effectively,
855
+ # an easy way to support batches > 1 quickly with little code modification.
856
+ # In the long run, it's more efficient to modify the code to support large
857
+ # batches and getting rid of this function. Consider this a temporary solution
858
+ def batch_slice(inputs, graph_fn, batch_size, names=None):
859
+ """Splits inputs into slices and feeds each slice to a copy of the given
860
+ computation graph and then combines the results. It allows you to run a
861
+ graph on a batch of inputs even if the graph is written to support one
862
+ instance only.
863
+
864
+ inputs: list of tensors. All must have the same first dimension length
865
+ graph_fn: A function that returns a TF tensor that's part of a graph.
866
+ batch_size: number of slices to divide the data into.
867
+ names: If provided, assigns names to the resulting tensors.
868
+ """
869
+ if not isinstance(inputs, list):
870
+ inputs = [inputs]
871
+
872
+ outputs = []
873
+ for i in range(batch_size):
874
+ inputs_slice = [x[i] for x in inputs]
875
+ output_slice = graph_fn(*inputs_slice)
876
+ if not isinstance(output_slice, (tuple, list)):
877
+ output_slice = [output_slice]
878
+ outputs.append(output_slice)
879
+ # Change outputs from a list of slices where each is
880
+ # a list of outputs to a list of outputs and each has
881
+ # a list of slices
882
+ outputs = list(zip(*outputs))
883
+
884
+ if names is None:
885
+ names = [None] * len(outputs)
886
+
887
+ result = [tf.stack(o, axis=0, name=n) for o, n in zip(outputs, names)]
888
+ if len(result) == 1:
889
+ result = result[0]
890
+
891
+ return result
892
+
893
+
894
+ def download_trained_weights(coco_model_path, verbose=1):
895
+ """Download COCO trained weights from Releases.
896
+
897
+ coco_model_path: local path of COCO trained weights
898
+ """
899
+ if verbose > 0:
900
+ print("Downloading pretrained model to " + coco_model_path + " ...")
901
+ with urllib.request.urlopen(COCO_MODEL_URL) as resp, open(
902
+ coco_model_path, "wb"
903
+ ) as out:
904
+ shutil.copyfileobj(resp, out)
905
+ if verbose > 0:
906
+ print("... done downloading pretrained model!")
907
+
908
+
909
+ def norm_boxes(boxes, shape):
910
+ """Converts boxes from pixel coordinates to normalized coordinates.
911
+ boxes: [N, (y1, x1, y2, x2)] in pixel coordinates
912
+ shape: [..., (height, width)] in pixels
913
+
914
+ Note: In pixel coordinates (y2, x2) is outside the box. But in normalized
915
+ coordinates it's inside the box.
916
+
917
+ Returns:
918
+ [N, (y1, x1, y2, x2)] in normalized coordinates
919
+ """
920
+ h, w = shape
921
+ scale = np.array([h - 1, w - 1, h - 1, w - 1])
922
+ shift = np.array([0, 0, 1, 1])
923
+ return np.divide((boxes - shift), scale).astype(np.float32)
924
+
925
+
926
+ def denorm_boxes(boxes, shape):
927
+ """Converts boxes from normalized coordinates to pixel coordinates.
928
+ boxes: [N, (y1, x1, y2, x2)] in normalized coordinates
929
+ shape: [..., (height, width)] in pixels
930
+
931
+ Note: In pixel coordinates (y2, x2) is outside the box. But in normalized
932
+ coordinates it's inside the box.
933
+
934
+ Returns:
935
+ [N, (y1, x1, y2, x2)] in pixel coordinates
936
+ """
937
+ h, w = shape
938
+ scale = np.array([h - 1, w - 1, h - 1, w - 1])
939
+ shift = np.array([0, 0, 1, 1])
940
+ return np.around(np.multiply(boxes, scale) + shift).astype(np.int32)
941
+
942
+
943
+ def resize(
944
+ image,
945
+ output_shape,
946
+ order=1,
947
+ mode="constant",
948
+ cval=0,
949
+ clip=True,
950
+ preserve_range=False,
951
+ anti_aliasing=False,
952
+ anti_aliasing_sigma=None,
953
+ ):
954
+ """A wrapper for Scikit-Image resize().
955
+
956
+ Scikit-Image generates warnings on every call to resize() if it doesn't
957
+ receive the right parameters. The right parameters depend on the version
958
+ of skimage. This solves the problem by using different parameters per
959
+ version. And it provides a central place to control resizing defaults.
960
+ """
961
+ if LooseVersion(skimage.__version__) >= LooseVersion("0.14"):
962
+ # New in 0.14: anti_aliasing. Default it to False for backward
963
+ # compatibility with skimage 0.13.
964
+ return skimage.transform.resize(
965
+ image,
966
+ output_shape,
967
+ order=order,
968
+ mode=mode,
969
+ cval=cval,
970
+ clip=clip,
971
+ preserve_range=preserve_range,
972
+ anti_aliasing=anti_aliasing,
973
+ anti_aliasing_sigma=anti_aliasing_sigma,
974
+ )
975
+ else:
976
+ return skimage.transform.resize(
977
+ image,
978
+ output_shape,
979
+ order=order,
980
+ mode=mode,
981
+ cval=cval,
982
+ clip=clip,
983
+ preserve_range=preserve_range,
984
+ )
mrcnn/visualize.py ADDED
@@ -0,0 +1,624 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Mask R-CNN
3
+ Display and Visualization Functions.
4
+
5
+ Copyright (c) 2017 Matterport, Inc.
6
+ Licensed under the MIT License (see LICENSE for details)
7
+ Written by Waleed Abdulla
8
+ """
9
+
10
+ import colorsys
11
+ import itertools
12
+ import os
13
+ import random
14
+ import sys
15
+
16
+ import IPython.display
17
+ import matplotlib.pyplot as plt
18
+ import numpy as np
19
+ from matplotlib import lines
20
+ from matplotlib import patches
21
+ from matplotlib.patches import Polygon
22
+ from skimage.measure import find_contours
23
+
24
+ # Root directory of the project
25
+ ROOT_DIR = os.path.abspath("../")
26
+
27
+ # Import Mask RCNN
28
+ sys.path.append(ROOT_DIR) # To find local version of the library
29
+ from mrcnn import utils
30
+
31
+ ############################################################
32
+ # Visualization
33
+ ############################################################
34
+
35
+
36
+ def display_images(
37
+ images, titles=None, cols=4, cmap=None, norm=None, interpolation=None
38
+ ):
39
+ """Display the given set of images, optionally with titles.
40
+ images: list or array of image tensors in HWC format.
41
+ titles: optional. A list of titles to display with each image.
42
+ cols: number of images per row
43
+ cmap: Optional. Color map to use. For example, "Blues".
44
+ norm: Optional. A Normalize instance to map values to colors.
45
+ interpolation: Optional. Image interpolation to use for display.
46
+ """
47
+ titles = titles if titles is not None else [""] * len(images)
48
+ rows = len(images) // cols + 1
49
+ plt.figure(figsize=(14, 14 * rows // cols))
50
+ i = 1
51
+ for image, title in zip(images, titles):
52
+ plt.subplot(rows, cols, i)
53
+ plt.title(title, fontsize=9)
54
+ plt.axis("off")
55
+ plt.imshow(
56
+ image.astype(np.uint8), cmap=cmap, norm=norm, interpolation=interpolation
57
+ )
58
+ i += 1
59
+ plt.show()
60
+
61
+
62
+ def random_colors(N, bright=True):
63
+ """
64
+ Generate random colors.
65
+ To get visually distinct colors, generate them in HSV space then
66
+ convert to RGB.
67
+ """
68
+ brightness = 1.0 if bright else 0.7
69
+ hsv = [(i / N, 1, brightness) for i in range(N)]
70
+ colors = list(map(lambda c: colorsys.hsv_to_rgb(*c), hsv))
71
+ random.shuffle(colors)
72
+ return colors
73
+
74
+
75
+ def apply_mask(image, mask, color, alpha=0.5):
76
+ """Apply the given mask to the image."""
77
+ for c in range(3):
78
+ image[:, :, c] = np.where(
79
+ mask == 1,
80
+ image[:, :, c] * (1 - alpha) + alpha * color[c] * 255,
81
+ image[:, :, c],
82
+ )
83
+ return image
84
+
85
+
86
+ def display_instances(
87
+ image,
88
+ boxes,
89
+ masks,
90
+ class_ids,
91
+ class_names,
92
+ scores=None,
93
+ title="",
94
+ figsize=(16, 16),
95
+ ax=None,
96
+ show_mask=True,
97
+ show_bbox=True,
98
+ colors=None,
99
+ captions=None,
100
+ ):
101
+ """
102
+ boxes: [num_instance, (y1, x1, y2, x2, class_id)] in image coordinates.
103
+ masks: [height, width, num_instances]
104
+ class_ids: [num_instances]
105
+ class_names: list of class names of the dataset
106
+ scores: (optional) confidence scores for each box
107
+ title: (optional) Figure title
108
+ show_mask, show_bbox: To show masks and bounding boxes or not
109
+ figsize: (optional) the size of the image
110
+ colors: (optional) An array or colors to use with each object
111
+ captions: (optional) A list of strings to use as captions for each object
112
+ """
113
+ # Number of instances
114
+ N = boxes.shape[0]
115
+ if not N:
116
+ print("\n*** No instances to display *** \n")
117
+ else:
118
+ assert boxes.shape[0] == masks.shape[-1] == class_ids.shape[0]
119
+
120
+ # If no axis is passed, create one and automatically call show()
121
+ auto_show = False
122
+ if not ax:
123
+ _, ax = plt.subplots(1, figsize=figsize)
124
+ auto_show = True
125
+
126
+ # Generate random colors
127
+ colors = colors or random_colors(N)
128
+
129
+ # Show area outside image boundaries.
130
+ height, width = image.shape[:2]
131
+ ax.set_ylim(height + 10, -10)
132
+ ax.set_xlim(-10, width + 10)
133
+ ax.axis("off")
134
+ ax.set_title(title)
135
+
136
+ masked_image = image.astype(np.uint32).copy()
137
+ for i in range(N):
138
+ color = colors[i]
139
+
140
+ # Bounding box
141
+ if not np.any(boxes[i]):
142
+ # Skip this instance. Has no bbox. Likely lost in image cropping.
143
+ continue
144
+ y1, x1, y2, x2 = boxes[i]
145
+ if show_bbox:
146
+ p = patches.Rectangle(
147
+ (x1, y1),
148
+ x2 - x1,
149
+ y2 - y1,
150
+ linewidth=2,
151
+ alpha=0.7,
152
+ linestyle="dashed",
153
+ edgecolor=color,
154
+ facecolor="none",
155
+ )
156
+ ax.add_patch(p)
157
+
158
+ # Label
159
+ if not captions:
160
+ class_id = class_ids[i]
161
+ score = scores[i] if scores is not None else None
162
+ label = class_names[class_id]
163
+ caption = "{} {:.3f}".format(label, score) if score else label
164
+ else:
165
+ caption = captions[i]
166
+ ax.text(x1, y1 + 8, caption, color="w", size=11, backgroundcolor="none")
167
+
168
+ # Mask
169
+ mask = masks[:, :, i]
170
+ if show_mask:
171
+ masked_image = apply_mask(masked_image, mask, color)
172
+
173
+ # Mask Polygon
174
+ # Pad to ensure proper polygons for masks that touch image edges.
175
+ padded_mask = np.zeros((mask.shape[0] + 2, mask.shape[1] + 2), dtype=np.uint8)
176
+ padded_mask[1:-1, 1:-1] = mask
177
+ contours = find_contours(padded_mask, 0.5)
178
+ for verts in contours:
179
+ # Subtract the padding and flip (y, x) to (x, y)
180
+ verts = np.fliplr(verts) - 1
181
+ p = Polygon(verts, facecolor="none", edgecolor=color)
182
+ ax.add_patch(p)
183
+
184
+ # ax.imshow(masked_image.astype(np.uint8))
185
+
186
+ if auto_show:
187
+ plt.show()
188
+
189
+ return masked_image.astype(np.uint8)
190
+
191
+
192
+ def display_differences(
193
+ image,
194
+ gt_box,
195
+ gt_class_id,
196
+ gt_mask,
197
+ pred_box,
198
+ pred_class_id,
199
+ pred_score,
200
+ pred_mask,
201
+ class_names,
202
+ title="",
203
+ ax=None,
204
+ show_mask=True,
205
+ show_box=True,
206
+ iou_threshold=0.5,
207
+ score_threshold=0.5,
208
+ ):
209
+ """Display ground truth and prediction instances on the same image."""
210
+ # Match predictions to ground truth
211
+ gt_match, pred_match, overlaps = utils.compute_matches(
212
+ gt_box,
213
+ gt_class_id,
214
+ gt_mask,
215
+ pred_box,
216
+ pred_class_id,
217
+ pred_score,
218
+ pred_mask,
219
+ iou_threshold=iou_threshold,
220
+ score_threshold=score_threshold,
221
+ )
222
+ # Ground truth = green. Predictions = red
223
+ colors = [(0, 1, 0, 0.8)] * len(gt_match) + [(1, 0, 0, 1)] * len(pred_match)
224
+ # Concatenate GT and predictions
225
+ class_ids = np.concatenate([gt_class_id, pred_class_id])
226
+ scores = np.concatenate([np.zeros([len(gt_match)]), pred_score])
227
+ boxes = np.concatenate([gt_box, pred_box])
228
+ masks = np.concatenate([gt_mask, pred_mask], axis=-1)
229
+ # Captions per instance show score/IoU
230
+ captions = ["" for m in gt_match] + [
231
+ "{:.2f} / {:.2f}".format(
232
+ pred_score[i],
233
+ (
234
+ overlaps[i, int(pred_match[i])]
235
+ if pred_match[i] > -1
236
+ else overlaps[i].max()
237
+ ),
238
+ )
239
+ for i in range(len(pred_match))
240
+ ]
241
+ # Set title if not provided
242
+ title = (
243
+ title or "Ground Truth and Detections\n GT=green, pred=red, captions: score/IoU"
244
+ )
245
+ # Display
246
+ display_instances(
247
+ image,
248
+ boxes,
249
+ masks,
250
+ class_ids,
251
+ class_names,
252
+ scores,
253
+ ax=ax,
254
+ show_bbox=show_box,
255
+ show_mask=show_mask,
256
+ colors=colors,
257
+ captions=captions,
258
+ title=title,
259
+ )
260
+
261
+
262
+ def draw_rois(image, rois, refined_rois, mask, class_ids, class_names, limit=10):
263
+ """
264
+ anchors: [n, (y1, x1, y2, x2)] list of anchors in image coordinates.
265
+ proposals: [n, 4] the same anchors but refined to fit objects better.
266
+ """
267
+ masked_image = image.copy()
268
+
269
+ # Pick random anchors in case there are too many.
270
+ ids = np.arange(rois.shape[0], dtype=np.int32)
271
+ ids = np.random.choice(ids, limit, replace=False) if ids.shape[0] > limit else ids
272
+
273
+ fig, ax = plt.subplots(1, figsize=(12, 12))
274
+ if rois.shape[0] > limit:
275
+ plt.title("Showing {} random ROIs out of {}".format(len(ids), rois.shape[0]))
276
+ else:
277
+ plt.title("{} ROIs".format(len(ids)))
278
+
279
+ # Show area outside image boundaries.
280
+ ax.set_ylim(image.shape[0] + 20, -20)
281
+ ax.set_xlim(-50, image.shape[1] + 20)
282
+ ax.axis("off")
283
+
284
+ for i, id in enumerate(ids):
285
+ color = np.random.rand(3)
286
+ class_id = class_ids[id]
287
+ # ROI
288
+ y1, x1, y2, x2 = rois[id]
289
+ p = patches.Rectangle(
290
+ (x1, y1),
291
+ x2 - x1,
292
+ y2 - y1,
293
+ linewidth=2,
294
+ edgecolor=color if class_id else "gray",
295
+ facecolor="none",
296
+ linestyle="dashed",
297
+ )
298
+ ax.add_patch(p)
299
+ # Refined ROI
300
+ if class_id:
301
+ ry1, rx1, ry2, rx2 = refined_rois[id]
302
+ p = patches.Rectangle(
303
+ (rx1, ry1),
304
+ rx2 - rx1,
305
+ ry2 - ry1,
306
+ linewidth=2,
307
+ edgecolor=color,
308
+ facecolor="none",
309
+ )
310
+ ax.add_patch(p)
311
+ # Connect the top-left corners of the anchor and proposal for easy visualization
312
+ ax.add_line(lines.Line2D([x1, rx1], [y1, ry1], color=color))
313
+
314
+ # Label
315
+ label = class_names[class_id]
316
+ ax.text(
317
+ rx1,
318
+ ry1 + 8,
319
+ "{}".format(label),
320
+ color="w",
321
+ size=11,
322
+ backgroundcolor="none",
323
+ )
324
+
325
+ # Mask
326
+ m = utils.unmold_mask(mask[id], rois[id][:4].astype(np.int32), image.shape)
327
+ masked_image = apply_mask(masked_image, m, color)
328
+
329
+ ax.imshow(masked_image)
330
+
331
+ # Print stats
332
+ print("Positive ROIs: ", class_ids[class_ids > 0].shape[0])
333
+ print("Negative ROIs: ", class_ids[class_ids == 0].shape[0])
334
+ print(
335
+ "Positive Ratio: {:.2f}".format(
336
+ class_ids[class_ids > 0].shape[0] / class_ids.shape[0]
337
+ )
338
+ )
339
+
340
+
341
+ # TODO: Replace with matplotlib equivalent?
342
+ def draw_box(image, box, color):
343
+ """Draw 3-pixel width bounding boxes on the given image array.
344
+ color: list of 3 int values for RGB.
345
+ """
346
+ y1, x1, y2, x2 = box
347
+ image[y1 : y1 + 2, x1:x2] = color
348
+ image[y2 : y2 + 2, x1:x2] = color
349
+ image[y1:y2, x1 : x1 + 2] = color
350
+ image[y1:y2, x2 : x2 + 2] = color
351
+ return image
352
+
353
+
354
+ def display_top_masks(image, mask, class_ids, class_names, limit=4):
355
+ """Display the given image and the top few class masks."""
356
+ to_display = []
357
+ titles = []
358
+ to_display.append(image)
359
+ titles.append("H x W={}x{}".format(image.shape[0], image.shape[1]))
360
+ # Pick top prominent classes in this image
361
+ unique_class_ids = np.unique(class_ids)
362
+ mask_area = [
363
+ np.sum(mask[:, :, np.where(class_ids == i)[0]]) for i in unique_class_ids
364
+ ]
365
+ top_ids = [
366
+ v[0]
367
+ for v in sorted(
368
+ zip(unique_class_ids, mask_area), key=lambda r: r[1], reverse=True
369
+ )
370
+ if v[1] > 0
371
+ ]
372
+ # Generate images and titles
373
+ for i in range(limit):
374
+ class_id = top_ids[i] if i < len(top_ids) else -1
375
+ # Pull masks of instances belonging to the same class.
376
+ m = mask[:, :, np.where(class_ids == class_id)[0]]
377
+ m = np.sum(m * np.arange(1, m.shape[-1] + 1), -1)
378
+ to_display.append(m)
379
+ titles.append(class_names[class_id] if class_id != -1 else "-")
380
+ display_images(to_display, titles=titles, cols=limit + 1, cmap="Blues_r")
381
+
382
+
383
+ def plot_precision_recall(AP, precisions, recalls):
384
+ """Draw the precision-recall curve.
385
+
386
+ AP: Average precision at IoU >= 0.5
387
+ precisions: list of precision values
388
+ recalls: list of recall values
389
+ """
390
+ # Plot the Precision-Recall curve
391
+ _, ax = plt.subplots(1)
392
+ ax.set_title("Precision-Recall Curve. AP@50 = {:.3f}".format(AP))
393
+ ax.set_ylim(0, 1.1)
394
+ ax.set_xlim(0, 1.1)
395
+ _ = ax.plot(recalls, precisions)
396
+
397
+
398
+ def plot_overlaps(
399
+ gt_class_ids, pred_class_ids, pred_scores, overlaps, class_names, threshold=0.5
400
+ ):
401
+ """Draw a grid showing how ground truth objects are classified.
402
+ gt_class_ids: [N] int. Ground truth class IDs
403
+ pred_class_id: [N] int. Predicted class IDs
404
+ pred_scores: [N] float. The probability scores of predicted classes
405
+ overlaps: [pred_boxes, gt_boxes] IoU overlaps of predictions and GT boxes.
406
+ class_names: list of all class names in the dataset
407
+ threshold: Float. The prediction probability required to predict a class
408
+ """
409
+ gt_class_ids = gt_class_ids[gt_class_ids != 0]
410
+ pred_class_ids = pred_class_ids[pred_class_ids != 0]
411
+
412
+ plt.figure(figsize=(12, 10))
413
+ plt.imshow(overlaps, interpolation="nearest", cmap=plt.cm.Blues)
414
+ plt.yticks(
415
+ np.arange(len(pred_class_ids)),
416
+ [
417
+ "{} ({:.2f})".format(class_names[int(id)], pred_scores[i])
418
+ for i, id in enumerate(pred_class_ids)
419
+ ],
420
+ )
421
+ plt.xticks(
422
+ np.arange(len(gt_class_ids)),
423
+ [class_names[int(id)] for id in gt_class_ids],
424
+ rotation=90,
425
+ )
426
+
427
+ thresh = overlaps.max() / 2.0
428
+ for i, j in itertools.product(range(overlaps.shape[0]), range(overlaps.shape[1])):
429
+ text = ""
430
+ if overlaps[i, j] > threshold:
431
+ text = "match" if gt_class_ids[j] == pred_class_ids[i] else "wrong"
432
+ color = (
433
+ "white"
434
+ if overlaps[i, j] > thresh
435
+ else "black"
436
+ if overlaps[i, j] > 0
437
+ else "grey"
438
+ )
439
+ plt.text(
440
+ j,
441
+ i,
442
+ "{:.3f}\n{}".format(overlaps[i, j], text),
443
+ horizontalalignment="center",
444
+ verticalalignment="center",
445
+ fontsize=9,
446
+ color=color,
447
+ )
448
+
449
+ plt.tight_layout()
450
+ plt.xlabel("Ground Truth")
451
+ plt.ylabel("Predictions")
452
+
453
+
454
+ def draw_boxes(
455
+ image,
456
+ boxes=None,
457
+ refined_boxes=None,
458
+ masks=None,
459
+ captions=None,
460
+ visibilities=None,
461
+ title="",
462
+ ax=None,
463
+ ):
464
+ """Draw bounding boxes and segmentation masks with different
465
+ customizations.
466
+
467
+ boxes: [N, (y1, x1, y2, x2, class_id)] in image coordinates.
468
+ refined_boxes: Like boxes, but draw with solid lines to show
469
+ that they're the result of refining 'boxes'.
470
+ masks: [N, height, width]
471
+ captions: List of N titles to display on each box
472
+ visibilities: (optional) List of values of 0, 1, or 2. Determine how
473
+ prominent each bounding box should be.
474
+ title: An optional title to show over the image
475
+ ax: (optional) Matplotlib axis to draw on.
476
+ """
477
+ # Number of boxes
478
+ assert boxes is not None or refined_boxes is not None
479
+ N = boxes.shape[0] if boxes is not None else refined_boxes.shape[0]
480
+
481
+ # Matplotlib Axis
482
+ if not ax:
483
+ _, ax = plt.subplots(1, figsize=(12, 12))
484
+
485
+ # Generate random colors
486
+ colors = random_colors(N)
487
+
488
+ # Show area outside image boundaries.
489
+ margin = image.shape[0] // 10
490
+ ax.set_ylim(image.shape[0] + margin, -margin)
491
+ ax.set_xlim(-margin, image.shape[1] + margin)
492
+ ax.axis("off")
493
+
494
+ ax.set_title(title)
495
+
496
+ masked_image = image.astype(np.uint32).copy()
497
+ for i in range(N):
498
+ # Box visibility
499
+ visibility = visibilities[i] if visibilities is not None else 1
500
+ if visibility == 0:
501
+ color = "gray"
502
+ style = "dotted"
503
+ alpha = 0.5
504
+ elif visibility == 1:
505
+ color = colors[i]
506
+ style = "dotted"
507
+ alpha = 1
508
+ elif visibility == 2:
509
+ color = colors[i]
510
+ style = "solid"
511
+ alpha = 1
512
+
513
+ # Boxes
514
+ if boxes is not None:
515
+ if not np.any(boxes[i]):
516
+ # Skip this instance. Has no bbox. Likely lost in cropping.
517
+ continue
518
+ y1, x1, y2, x2 = boxes[i]
519
+ p = patches.Rectangle(
520
+ (x1, y1),
521
+ x2 - x1,
522
+ y2 - y1,
523
+ linewidth=2,
524
+ alpha=alpha,
525
+ linestyle=style,
526
+ edgecolor=color,
527
+ facecolor="none",
528
+ )
529
+ ax.add_patch(p)
530
+
531
+ # Refined boxes
532
+ if refined_boxes is not None and visibility > 0:
533
+ ry1, rx1, ry2, rx2 = refined_boxes[i].astype(np.int32)
534
+ p = patches.Rectangle(
535
+ (rx1, ry1),
536
+ rx2 - rx1,
537
+ ry2 - ry1,
538
+ linewidth=2,
539
+ edgecolor=color,
540
+ facecolor="none",
541
+ )
542
+ ax.add_patch(p)
543
+ # Connect the top-left corners of the anchor and proposal
544
+ if boxes is not None:
545
+ ax.add_line(lines.Line2D([x1, rx1], [y1, ry1], color=color))
546
+
547
+ # Captions
548
+ if captions is not None:
549
+ caption = captions[i]
550
+ # If there are refined boxes, display captions on them
551
+ if refined_boxes is not None:
552
+ y1, x1, y2, x2 = ry1, rx1, ry2, rx2
553
+ ax.text(
554
+ x1,
555
+ y1,
556
+ caption,
557
+ size=11,
558
+ verticalalignment="top",
559
+ color="w",
560
+ backgroundcolor="none",
561
+ bbox={"facecolor": color, "alpha": 0.5, "pad": 2, "edgecolor": "none"},
562
+ )
563
+
564
+ # Masks
565
+ if masks is not None:
566
+ mask = masks[:, :, i]
567
+ masked_image = apply_mask(masked_image, mask, color)
568
+ # Mask Polygon
569
+ # Pad to ensure proper polygons for masks that touch image edges.
570
+ padded_mask = np.zeros(
571
+ (mask.shape[0] + 2, mask.shape[1] + 2), dtype=np.uint8
572
+ )
573
+ padded_mask[1:-1, 1:-1] = mask
574
+ contours = find_contours(padded_mask, 0.5)
575
+ for verts in contours:
576
+ # Subtract the padding and flip (y, x) to (x, y)
577
+ verts = np.fliplr(verts) - 1
578
+ p = Polygon(verts, facecolor="none", edgecolor=color)
579
+ ax.add_patch(p)
580
+ ax.imshow(masked_image.astype(np.uint8))
581
+
582
+
583
+ def display_table(table):
584
+ """Display values in a table format.
585
+ table: an iterable of rows, and each row is an iterable of values.
586
+ """
587
+ html = ""
588
+ for row in table:
589
+ row_html = ""
590
+ for col in row:
591
+ row_html += "<td>{:40}</td>".format(str(col))
592
+ html += "<tr>" + row_html + "</tr>"
593
+ html = "<table>" + html + "</table>"
594
+ IPython.display.display(IPython.display.HTML(html))
595
+
596
+
597
+ def display_weight_stats(model):
598
+ """Scans all the weights in the model and returns a list of tuples
599
+ that contain stats about each weight.
600
+ """
601
+ layers = model.get_trainable_layers()
602
+ table = [["WEIGHT NAME", "SHAPE", "MIN", "MAX", "STD"]]
603
+ for l in layers:
604
+ weight_values = l.get_weights() # list of Numpy arrays
605
+ weight_tensors = l.weights # list of TF tensors
606
+ for i, w in enumerate(weight_values):
607
+ weight_name = weight_tensors[i].name
608
+ # Detect problematic layers. Exclude biases of conv layers.
609
+ alert = ""
610
+ if w.min() == w.max() and not (l.__class__.__name__ == "Conv2D" and i == 1):
611
+ alert += "<span style='color:red'>*** dead?</span>"
612
+ if np.abs(w.min()) > 1000 or np.abs(w.max()) > 1000:
613
+ alert += "<span style='color:red'>*** Overflow?</span>"
614
+ # Add row
615
+ table.append(
616
+ [
617
+ weight_name + alert,
618
+ str(w.shape),
619
+ "{:+9.4f}".format(w.min()),
620
+ "{:+10.4f}".format(w.max()),
621
+ "{:+9.4f}".format(w.std()),
622
+ ]
623
+ )
624
+ display_table(table)
requirements.txt ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ tensorflow==1.14.0
2
+ keras==2.0.8
3
+ protobuf==3.20.1
4
+ gradio==3.0.15
5
+ gdown==4.4.0
6
+ numpy
7
+ scipy
8
+ Pillow
9
+ cython
10
+ matplotlib
11
+ scikit-image
12
+ opencv-python
13
+ h5py==2.10.0
14
+ imgaug
15
+ IPython[all]
setup.cfg ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ [metadata]
2
+ description-file = README.md
3
+ license-file = LICENSE
4
+ requirements-file = requirements.txt
setup.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ The build/compilations setup
3
+
4
+ >> pip install -r requirements.txt
5
+ >> python setup.py install
6
+ """
7
+ import logging
8
+
9
+ import pip
10
+ import pkg_resources
11
+
12
+ try:
13
+ from setuptools import setup
14
+ except ImportError:
15
+ from distutils.core import setup
16
+
17
+
18
+ def _parse_requirements(file_path):
19
+ pip_ver = pkg_resources.get_distribution("pip").version
20
+ pip_version = list(map(int, pip_ver.split(".")[:2]))
21
+ if pip_version >= [6, 0]:
22
+ raw = pip.req.parse_requirements(file_path, session=pip.download.PipSession())
23
+ else:
24
+ raw = pip.req.parse_requirements(file_path)
25
+ return [str(i.req) for i in raw]
26
+
27
+
28
+ # parse_requirements() returns generator of pip.req.InstallRequirement objects
29
+ try:
30
+ install_reqs = _parse_requirements("requirements.txt")
31
+ except Exception:
32
+ logging.warning("Fail load requirements file, so using default ones.")
33
+ install_reqs = []
34
+
35
+ setup(
36
+ name="mask-rcnn",
37
+ version="2.1",
38
+ url="https://github.com/matterport/Mask_RCNN",
39
+ author="Matterport",
40
+ author_email="[email protected]",
41
+ license="MIT",
42
+ description="Mask R-CNN for object detection and instance segmentation",
43
+ packages=["mrcnn"],
44
+ install_requires=install_reqs,
45
+ include_package_data=True,
46
+ python_requires=">=3.4",
47
+ long_description="""This is an implementation of Mask R-CNN on Python 3, Keras, and TensorFlow.
48
+ The model generates bounding boxes and segmentation masks for each instance of an object in the image.
49
+ It's based on Feature Pyramid Network (FPN) and a ResNet101 backbone.""",
50
+ classifiers=[
51
+ "Development Status :: 5 - Production/Stable",
52
+ "Environment :: Console",
53
+ "Intended Audience :: Developers",
54
+ "Intended Audience :: Information Technology",
55
+ "Intended Audience :: Education",
56
+ "Intended Audience :: Science/Research",
57
+ "License :: OSI Approved :: MIT License",
58
+ "Natural Language :: English",
59
+ "Operating System :: OS Independent",
60
+ "Topic :: Scientific/Engineering :: Artificial Intelligence",
61
+ "Topic :: Scientific/Engineering :: Image Recognition",
62
+ "Topic :: Scientific/Engineering :: Visualization",
63
+ "Topic :: Scientific/Engineering :: Image Segmentation",
64
+ "Programming Language :: Python :: 3.4",
65
+ "Programming Language :: Python :: 3.5",
66
+ "Programming Language :: Python :: 3.6",
67
+ ],
68
+ keywords="image instance segmentation object detection mask rcnn r-cnn tensorflow keras",
69
+ )
utils.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import matplotlib.pyplot as plt
2
+
3
+
4
+ # from official repo
5
+ def get_ax(rows=1, cols=1, size=7):
6
+ """Return a Matplotlib Axes array to be used in
7
+ all visualizations in the notebook. Provide a
8
+ central point to control graph sizes.
9
+
10
+ Adjust the size attribute to control how big to render images
11
+ """
12
+ _, ax = plt.subplots(rows, cols, figsize=(size * cols, size * rows))
13
+ return ax