# ------------ tackle some noisy warning import os import warnings def warn(*args, **kwargs): pass warnings.warn = warn warnings.filterwarnings("ignore") os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" import random import gdown import gradio as gr import matplotlib.pyplot as plt import numpy as np import tensorflow as tf from PIL import Image import mrcnn.model as modellib from config import WheatDetectorConfig from config import WheatInferenceConfig from mrcnn import utils from mrcnn import visualize from mrcnn.model import log from utils import get_ax # for reproducibility def seed_all(SEED): random.seed(SEED) np.random.seed(SEED) os.environ["PYTHONHASHSEED"] = str(SEED) ORIG_SIZE = 1024 seed_all(42) config = WheatDetectorConfig() inference_config = WheatInferenceConfig() def get_model_weight(model_id): """Get the trained weights.""" if not os.path.exists("model.h5"): model_weight = gdown.download(id=model_id, quiet=False) else: model_weight = "model.h5" return model_weight def get_model(): """Get the model.""" model = modellib.MaskRCNN(mode="inference", config=inference_config, model_dir="./") return model def load_model(model_id): """Load trained model.""" weight = get_model_weight(model_id) model = get_model() model.load_weights(weight, by_name=True) return model def prepare_image(image): """Prepare incoming sample.""" image = image[:, :, ::-1] resize_factor = ORIG_SIZE / config.IMAGE_SHAPE[0] # If grayscale. Convert to RGB for consistency. if len(image.shape) != 3 or image.shape[2] != 3: image = np.stack((image,) * 3, -1) resized_image, window, scale, padding, crop = utils.resize_image( image, min_dim=config.IMAGE_MIN_DIM, min_scale=config.IMAGE_MIN_SCALE, max_dim=config.IMAGE_MAX_DIM, mode=config.IMAGE_RESIZE_MODE, ) return resized_image def predict_fn(image): image = prepare_image(image) model = load_model(model_id="1k4_WGBAUJCPbkkHkvtscX2jufTqETNYd") results = model.detect([image]) r = results[0] class_names = ["Wheat"] * len(r["rois"]) image = visualize.display_instances( image, r["rois"], r["masks"], r["class_ids"], class_names, r["scores"], ax=get_ax(), title="Predictions", ) return image[:, :, ::-1] title="Global Wheat Detection with Mask-RCNN Model" description="Model: Mask-RCNN. Backbone: ResNet-101. Trained on: Global Wheat Detection Dataset (Kaggle).
The code is written in Keras (TensorFlow 1.14). One can run the full code on Kaggle: [Keras]:Global Wheat Detection with Mask-RCNN" article = "

The model received 0.6449 and 0.5675 mAP (0.5:0.75:0.05) on the public and private test dataset respectively. The above examples are from test dataset without ground truth bounding box. Details: Global Wheat Dataset

" iface = gr.Interface( fn=predict_fn, inputs=gr.inputs.Image(label="Input Image"), outputs=gr.outputs.Image(label="Prediction"), title=title, description=description, article=article, examples=[ ["examples/2fd875eaa.jpg"], ["examples/51b3e36ab.jpg"], ["examples/51f1be19e.jpg"], ["examples/53f253011.jpg"], ["examples/348a992bb.jpg"], ["examples/796707dd7.jpg"], ["examples/aac893a91.jpg"], ["examples/cb8d261a3.jpg"], ["examples/cc3532ff6.jpg"], ["examples/f5a1f0358.jpg"], ], ) iface.launch()