# ------------ Silence noisy warnings (must run before TensorFlow is imported)
import os
import threading
import warnings
def warn(*args, **kwargs):
    """No-op stand-in for ``warnings.warn``: accepts any arguments, emits nothing."""
    return None


# Silence Python-level warnings twice over: ignore everything via the filter
# machinery, and additionally replace ``warnings.warn`` itself with the no-op.
warnings.filterwarnings("ignore")
warnings.warn = warn
# Quieten TensorFlow's native log output; keep this ahead of the
# ``import tensorflow`` below so the setting is in place when TF loads.
os.environ.update(TF_CPP_MIN_LOG_LEVEL="3")
import random
import gdown
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from PIL import Image
import mrcnn.model as modellib
from config import WheatDetectorConfig
from config import WheatInferenceConfig
from mrcnn import utils
from mrcnn import visualize
from mrcnn.model import log
from utils import get_ax
# for reproducibility
def seed_all(SEED):
    """Seed Python's and NumPy's global RNGs and export ``PYTHONHASHSEED``.

    Args:
        SEED: Integer seed applied to ``random`` and ``numpy.random``; its
            string form is written to the ``PYTHONHASHSEED`` environment
            variable.

    Note:
        Python reads ``PYTHONHASHSEED`` at interpreter start-up, so exporting
        it here only influences child processes started afterwards, not
        string hashing in the current process.
    """
    os.environ["PYTHONHASHSEED"] = str(SEED)
    for seeder in (random.seed, np.random.seed):
        seeder(SEED)
# Presumably the side length of the original dataset images (used as the
# numerator of a resize ratio in earlier preprocessing code) — confirm.
ORIG_SIZE = 1024

# Fix the RNG seeds before anything else runs.
seed_all(42)

# Detector config (supplies image-resize settings for preprocessing) and
# inference config (used to build the inference-mode model).
config, inference_config = WheatDetectorConfig(), WheatInferenceConfig()
def get_model_weight(model_id, weight_path="model.h5"):
    """Return a local path to the trained weights, downloading them only once.

    Args:
        model_id: Google Drive file id handed to ``gdown.download``.
        weight_path: Local file used as the download target and cache.
            Defaults to ``"model.h5"``.

    Returns:
        Path of the local weights file.

    Raises:
        RuntimeError: If ``gdown.download`` signals failure by returning None.
    """
    if os.path.exists(weight_path):
        return weight_path
    # Save under ``weight_path`` explicitly: without ``output=`` gdown picks the
    # filename itself, which need not be ``weight_path``, so the cache check
    # above would never hit and the weights would be fetched on every call.
    downloaded = gdown.download(id=model_id, output=weight_path, quiet=False)
    if downloaded is None:
        raise RuntimeError(
            f"Failed to download model weights (Google Drive id {model_id!r})."
        )
    return downloaded
def get_model():
    """Construct an inference-mode Mask R-CNN from ``inference_config``.

    No weights are loaded here; ``model_dir`` is the current directory.
    """
    return modellib.MaskRCNN(
        config=inference_config,
        mode="inference",
        model_dir="./",
    )
def load_model(model_id):
    """Build the inference model and load the trained weights into it.

    Args:
        model_id: Google Drive file id of the weights (see ``get_model_weight``).

    Returns:
        The model with weights loaded layer-by-name (``by_name=True``).
    """
    weights_file = get_model_weight(model_id)
    net = get_model()
    net.load_weights(weights_file, by_name=True)
    return net
def prepare_image(image):
    """Convert an incoming image into the resized 3-channel array the model takes.

    The image is first normalised to exactly three channels, then its channel
    order is reversed (presumably RGB -> BGR to match training-time
    preprocessing — confirm), and finally it is resized with
    ``utils.resize_image`` using the detector config's resize settings.

    Args:
        image: Image array. Accepted layouts are ``(H, W)`` grayscale,
            ``(H, W, 1)``, ``(H, W, 2)`` (gray + alpha), ``(H, W, 3)`` and
            ``(H, W, 4)`` (alpha is dropped).

    Returns:
        The resized image (first element of ``utils.resize_image``'s result).
    """
    image = np.asarray(image)
    # Normalise the channel layout *before* flipping channels, so 2-D
    # grayscale input does not raise IndexError on ``[:, :, ::-1]``.
    if image.ndim == 3 and image.shape[2] in (1, 2):
        # Single channel (or gray + alpha): keep the intensity channel only.
        image = image[:, :, 0]
    if image.ndim == 2:
        # Replicate grayscale into three identical channels.
        image = np.stack((image,) * 3, axis=-1)
    elif image.shape[2] > 3:
        # Drop alpha (or any extra) channels.
        image = image[:, :, :3]
    # Reverse channel order.
    image = image[:, :, ::-1]
    resized_image, _window, _scale, _padding, _crop = utils.resize_image(
        image,
        min_dim=config.IMAGE_MIN_DIM,
        min_scale=config.IMAGE_MIN_SCALE,
        max_dim=config.IMAGE_MAX_DIM,
        mode=config.IMAGE_RESIZE_MODE,
    )
    return resized_image
# Google Drive id of the trained Mask R-CNN weights.
WEIGHTS_DRIVE_ID = "1k4_WGBAUJCPbkkHkvtscX2jufTqETNYd"

# Per-thread storage for the loaded model (see ``_get_thread_model``).
_thread_state = threading.local()


def _get_thread_model():
    """Return this thread's model, building it and loading weights on first use.

    Building the network and loading weights is expensive and, done per
    request, keeps adding new models to the process. The cache is per
    thread (not global) so each model is only used from the thread that
    built it — TF1-era Keras may tie models to thread-specific graphs and
    sessions, so this keeps the same thread affinity as loading per call.
    """
    model = getattr(_thread_state, "model", None)
    if model is None:
        model = load_model(model_id=WEIGHTS_DRIVE_ID)
        _thread_state.model = model
    return model


def predict_fn(image):
    """Detect wheat heads in ``image`` and return the rendered predictions.

    Args:
        image: Image array from the Gradio ``Image`` input.

    Returns:
        The result of ``visualize.display_instances`` with its last axis
        (channel order) reversed.
    """
    image = prepare_image(image)
    model = _get_thread_model()
    r = model.detect([image])[0]
    # Upstream Matterport's display_instances labels each instance with
    # ``class_names[class_id]`` (class 0 = background), so index 1 must exist;
    # ``["Wheat"] * len(rois)`` was too short when exactly one instance was
    # detected. At least two entries fixes that while keeping one per instance.
    class_names = ["Wheat"] * max(len(r["rois"]), 2)
    ax = get_ax()
    try:
        rendered = visualize.display_instances(
            image,
            r["rois"],
            r["masks"],
            r["class_ids"],
            class_names,
            r["scores"],
            ax=ax,
            title="Predictions",
        )
    finally:
        # Close the figure behind ``ax`` so pyplot figures do not pile up
        # across requests (assumes get_ax returns a single matplotlib Axes).
        fig = getattr(ax, "figure", None)
        if fig is not None:
            plt.close(fig)
    # NOTE(review): upstream display_instances returns None; this relies on the
    # project's mrcnn fork returning the rendered image array — confirm.
    return rendered[:, :, ::-1]
# NOTE(review): these texts appear to have lost their HTML links (e.g. to the
# Kaggle notebook and dataset) upstream; restore them if the URLs are known.
title = "Global Wheat Detection with Mask-RCNN Model"
description = (
    "Model: Mask-RCNN. Backbone: ResNet-101. Trained on: Global Wheat "
    "Detection Dataset (Kaggle). The code is written in Keras (TensorFlow "
    "1.14). One can run the full code on Kaggle: [Keras]:Global Wheat "
    "Detection with Mask-RCNN"
)
article = (
    "The model received 0.6449 and 0.5675 mAP (0.5:0.75:0.05) on the public "
    "and private test dataset respectively. The above examples are from test "
    "dataset without ground truth bounding box. Details: Global Wheat Dataset"
)

# Gradio UI: one image in, one rendered prediction image out.
iface = gr.Interface(
    fn=predict_fn,
    inputs=gr.Image(label="Input Image"),
    outputs=gr.Image(label="Prediction"),
    title=title,
    description=description,
    article=article,
    examples=[
        ["examples/2fd875eaa.jpg"],
        ["examples/51b3e36ab.jpg"],
        ["examples/51f1be19e.jpg"],
        ["examples/53f253011.jpg"],
        ["examples/348a992bb.jpg"],
        ["examples/796707dd7.jpg"],
        ["examples/aac893a91.jpg"],
        ["examples/cb8d261a3.jpg"],
        ["examples/cc3532ff6.jpg"],
        ["examples/f5a1f0358.jpg"],
    ],
)
iface.launch(share=True)