import os
import subprocess

import cv2
import gradio as gr
import tensorflow as tf

# Directory holding the classifier weights and labels.txt.
model_folder = 'model'
destination = model_folder
repo_url = "https://huggingface.co/RandomCatLover/plants_disease"


def _clone_repo(url, target=None):
    """Clone the git repository at *url*, printing the outcome.

    Args:
        url: Repository URL passed to ``git clone``. A falsy value (for
            example an unset environment variable) is reported and skipped
            instead of being handed to git.
        target: Optional directory to clone into. When omitted, git derives
            the directory name from the URL.

    Failures are printed rather than raised, so start-up continues either way.
    """
    if not url:
        print('Error cloning repository: no repository URL configured.')
        return

    # Argument list without a shell: the URL is never parsed by a shell.
    command = ['git', 'clone', url]
    if target is not None:
        command.append(target)

    try:
        subprocess.check_output(command, stderr=subprocess.STDOUT)
    except subprocess.CalledProcessError as e:
        print(f"Error cloning repository: {e.output.decode(errors='replace')}")
    except OSError as e:
        # Without a shell, a missing git executable surfaces as OSError.
        print(f'Error cloning repository: {e}')
    else:
        print('Repository cloned successfully.')


# Fetch the classifier repository on first start.
if not os.path.exists(destination):
    _clone_repo(repo_url, destination)

# Fetch the explainer repository; its URL comes from the GIT_CORE env var.
destination = 'explainer_tf_mobilenetv2'
if not os.path.exists(destination):
    repo_url = os.getenv("GIT_CORE")
    _clone_repo(repo_url)

# NOTE(review): `explainer` is bound to the Keras MobileNetV2 constructor,
# yet explainer_wrapper() calls it as explainer(inp, model). This was
# presumably meant to come from the repository cloned above — confirm.
from tensorflow.keras.applications.mobilenet_v2 import MobileNetV2 as explainer  # noqa: E402


# Read the class names, one per line. Trailing newlines are stripped first so
# a file ending in "\n" does not produce a phantom empty label.
with open(f'{model_folder}/labels.txt', 'r', encoding='utf-8') as f:
    labels = f.read().rstrip('\n').split('\n')

# Load the classifier that classify_image() queries.
model = tf.keras.models.load_model(f'{model_folder}/last_layer.hdf5')


def classify_image(inp):
    """Return the model's score for every known label of one image.

    Args:
        inp: Image array from the UI image input (presumably H x W x 3 —
            TODO confirm against the input component's configuration).

    Returns:
        dict mapping each entry of the module-level ``labels`` to the
        corresponding model output as a Python float.
    """
    inp = cv2.resize(inp, (224, 224))
    # Add the batch dimension expected by model.predict.
    inp = inp.reshape((-1, 224, 224, 3))
    inp = tf.keras.applications.mobilenet_v2.preprocess_input(inp)
    prediction = model.predict(inp).flatten()
    print(prediction)
    # zip stops at the shorter sequence, so a surplus label cannot index
    # past the end of the prediction vector.
    return {name: float(score) for name, score in zip(labels, prediction)}


def explainer_wrapper(inp):
    """Call the module-level ``explainer`` on *inp* with the loaded ``model``.

    Adapter for the UI callback, which only supplies the image.
    """
    # NOTE(review): confirm that `explainer` accepts (image, model); the
    # name is bound by an import elsewhere in this file.
    explanation = explainer(inp, model)
    return explanation


# NOTE(review): the nesting below was reconstructed from a copy of this file
# whose indentation was lost. In particular, the placement of the plot and
# of the examples should be confirmed against the intended layout.
with gr.Blocks() as demo:
    with gr.Column():
        with gr.Row():
            # Left side: image input with the two action buttons beneath it.
            with gr.Column():
                image = gr.inputs.Image(shape=(224, 224))
                with gr.Row():
                    classify = gr.Button("Classify")
                    interpret = gr.Button("Interpret")
            # Right side: classification result and interpretation plot.
            with gr.Column():
                label = gr.outputs.Label(num_top_classes=3)
                interpretation = gr.Plot(label="Interpretation")

        gr.Examples(
            examples=[
                "TomatoHealthy2.jpg",
                "TomatoYellowCurlVirus3.jpg",
                "AppleCedarRust3.jpg",
            ],
            inputs=[image],
        )

    # Wire each button to its callback; both go through the request queue.
    classify.click(fn=classify_image, inputs=image, outputs=label, queue=True)
    interpret.click(
        fn=explainer_wrapper,
        inputs=image,
        outputs=interpretation,
        queue=True,
    )


demo.queue(concurrency_count=3).launch()