# Spaces: Build error  (page header captured with the source; not part of the program)
# ------------ silence some noisy warnings
import os | |
import warnings | |
def warn(*args, **kwargs):
    """Accept any arguments and do nothing.

    Installed below as a replacement for ``warnings.warn`` so that code
    calling it produces no output.
    """
    return None


# Replace the warning entry point with the no-op above, and additionally
# install an "ignore" filter for anything that still reaches the filters.
warnings.warn = warn
warnings.filterwarnings("ignore")

# Quieten TensorFlow's native logging; this is set here because it has to be
# in the environment before TensorFlow is imported further down.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
import random | |
import gdown | |
import gradio as gr | |
import matplotlib.pyplot as plt | |
import numpy as np | |
import tensorflow as tf | |
from PIL import Image | |
import mrcnn.model as modellib | |
from config import WheatDetectorConfig | |
from config import WheatInferenceConfig | |
from mrcnn import utils | |
from mrcnn import visualize | |
from mrcnn.model import log | |
from utils import get_ax | |
# for reproducibility
def seed_all(SEED):
    """Seed the global Python and NumPy random number generators.

    Args:
        SEED: Seed value passed to ``random.seed`` and ``np.random.seed``;
            its string form is also stored in ``PYTHONHASHSEED``.

    NOTE(review): TensorFlow's own random state is not seeded here.
    """
    random.seed(SEED)
    np.random.seed(SEED)
    # Hash randomization is fixed when the interpreter starts, so this export
    # only influences Python processes launched afterwards, not this one.
    os.environ["PYTHONHASHSEED"] = str(SEED)
# Presumably the side length, in pixels, of the original dataset images
# -- TODO confirm.
ORIG_SIZE = 1024

# Seed the random number generators once, at import time.
seed_all(42)

# Project configuration objects: ``config`` supplies the resize settings used
# when preparing an image, ``inference_config`` is what the model is built with.
config = WheatDetectorConfig()
inference_config = WheatInferenceConfig()
def get_model_weight(model_id, weight_path="model.h5"):
    """Return the path of the trained weights, downloading them if needed.

    Args:
        model_id: File id handed to ``gdown.download`` when the weights are
            not on disk yet.
        weight_path: Local file to reuse if it exists, and to download into
            otherwise. Defaults to ``"model.h5"``.

    Returns:
        Path of the weights file on disk.

    Raises:
        RuntimeError: If the download did not produce a file.
    """
    if os.path.exists(weight_path):
        return weight_path

    # Download to a fixed name: without ``output`` the file keeps its remote
    # name, and the existence check above would never find it on a later call.
    downloaded = gdown.download(id=model_id, output=weight_path, quiet=False)
    if downloaded is None:
        # Fail here with a clear message instead of handing ``None`` to the
        # weight loader.
        raise RuntimeError(
            f"Could not download model weights for id {model_id!r}"
        )
    return downloaded
def get_model():
    """Build a Mask R-CNN model in inference mode.

    Returns:
        A ``modellib.MaskRCNN`` instance created with the module-level
        ``inference_config`` and ``"./"`` as its model directory. No weights
        are loaded here.
    """
    return modellib.MaskRCNN(
        mode="inference",
        config=inference_config,
        model_dir="./",
    )
def load_model(model_id):
    """Build the inference model and load the trained weights into it.

    Args:
        model_id: Identifier forwarded to ``get_model_weight`` to locate
            (and, if necessary, download) the weights file.

    Returns:
        The model from ``get_model`` after ``load_weights`` has been called
        on it.
    """
    # Resolve the weights first so a failed download aborts before the model
    # is built.
    weight_path = get_model_weight(model_id)
    model = get_model()
    model.load_weights(weight_path, by_name=True)
    return model
def prepare_image(image):
    """Convert an incoming image into the array that is fed to the detector.

    The image is brought to exactly three channels, its channel order is
    reversed, and it is resized with ``utils.resize_image`` using the settings
    from the module-level ``config``.

    Args:
        image: NumPy array of shape (H, W) or (H, W, C).

    Returns:
        The resized image (first element of the ``utils.resize_image`` result).

    Raises:
        ValueError: If ``image`` is neither 2-D nor 3-D.
    """
    if image.ndim not in (2, 3):
        raise ValueError(
            f"Expected a 2-D or 3-D image array, got shape {image.shape}"
        )

    # Normalise the channel count BEFORE touching the channel axis: a 2-D
    # grayscale array has no third axis to index.
    if image.ndim == 2:
        image = image[:, :, np.newaxis]
    channels = image.shape[2]
    if channels < 3:
        # Grayscale (optionally with alpha): replicate the first channel.
        image = np.repeat(image[:, :, :1], 3, axis=2)
    elif channels > 3:
        # Extra channels such as alpha are dropped.
        image = image[:, :, :3]

    # Reverse the channel order; predict_fn reverses it back on the way out.
    # Presumably the detector expects BGR input -- TODO confirm.
    image = image[:, :, ::-1]

    # Only the resized image is needed; the remaining values returned by
    # utils.resize_image are ignored.
    resized_image, *_ = utils.resize_image(
        image,
        min_dim=config.IMAGE_MIN_DIM,
        min_scale=config.IMAGE_MIN_SCALE,
        max_dim=config.IMAGE_MAX_DIM,
        mode=config.IMAGE_RESIZE_MODE,
    )
    return resized_image
# File id of the trained weights, passed to load_model.
_MODEL_ID = "1k4_WGBAUJCPbkkHkvtscX2jufTqETNYd"


def predict_fn(image):
    """Run wheat detection on one image and return the rendered prediction.

    Args:
        image: Input image array, as accepted by ``prepare_image``.

    Returns:
        The image returned by ``visualize.display_instances`` with its channel
        order reversed, which undoes the reversal done in ``prepare_image``.
    """
    image = prepare_image(image)

    # NOTE(review): the model is rebuilt and its weights reloaded on every
    # call. Caching it would be far cheaper, but this may be deliberate
    # (TF 1.x graphs/sessions versus worker threads) -- confirm before hoisting.
    model = load_model(model_id=_MODEL_ID)
    r = model.detect([image])[0]

    # One "Wheat" label per detection, as before, but padded so the list also
    # covers every class id. The upstream Matterport display_instances looks
    # labels up by class id, which would overrun a one-element list for a
    # single detection with class id 1. This project's copy of visualize is
    # not visible here; the padded list is valid for either lookup scheme.
    class_ids = r["class_ids"]
    label_count = len(r["rois"])
    if len(class_ids):
        label_count = max(label_count, int(max(class_ids)) + 1)
    class_names = ["Wheat"] * label_count

    rendered = visualize.display_instances(
        image,
        r["rois"],
        r["masks"],
        class_ids,
        class_names,
        r["scores"],
        ax=get_ax(),
        title="Predictions",
    )
    return rendered[:, :, ::-1]
# Texts shown on the demo page (HTML is passed through as written).
title = "Global Wheat Detection with Mask-RCNN Model"
description = "<strong>Model</strong>: Mask-RCNN. <strong>Backbone</strong>: ResNet-101. Trained on: <a href='https://www.kaggle.com/competitions/global-wheat-detection/overview'>Global Wheat Detection Dataset (Kaggle)</a>. </br>The code is written in <code>Keras (TensorFlow 1.14)</code>. One can run the full code on Kaggle: <a href='https://www.kaggle.com/code/ipythonx/keras-global-wheat-detection-with-mask-rcnn'>[Keras]:Global Wheat Detection with Mask-RCNN</a>"
article = "<p>The model received <strong>0.6449</strong> and <strong>0.5675</strong> mAP (0.5:0.75:0.05) on the public and private test dataset respectively. The above examples are from test dataset without ground truth bounding box. Details: <a href='https://www.kaggle.com/competitions/global-wheat-detection/data'>Global Wheat Dataset</a></p>"

# One image in, one image out; predict_fn does the actual work.
iface = gr.Interface(
    fn=predict_fn,
    inputs=gr.Image(label="Input Image"),
    outputs=gr.Image(label="Prediction"),
    title=title,
    description=description,
    article=article,
    # Sample inputs offered in the UI, one single-element row per image.
    examples=[
        ["examples/2fd875eaa.jpg"],
        ["examples/51b3e36ab.jpg"],
        ["examples/51f1be19e.jpg"],
        ["examples/53f253011.jpg"],
        ["examples/348a992bb.jpg"],
        ["examples/796707dd7.jpg"],
        ["examples/aac893a91.jpg"],
        ["examples/cb8d261a3.jpg"],
        ["examples/cc3532ff6.jpg"],
        ["examples/f5a1f0358.jpg"],
    ],
)

# NOTE(review): share=True asks Gradio for a public share link; confirm it is
# wanted in the environment this app is deployed to.
iface.launch(share=True)