# %%
"""Gradio demo that classifies plant-disease images with a Keras model.

On start-up the script makes a best-effort attempt to clone the model
repository (and an explainer repository whose URL is read from the
``GIT_CORE`` environment variable), loads the labels and the model, builds
the Gradio UI and launches it.
"""
import os
import shlex
import subprocess

import cv2
import gradio as gr
import tensorflow as tf

model_folder = 'model'
MODEL_REPO_URL = "https://huggingface.co/RandomCatLover/plants_disease"
EXPLAINER_FOLDER = 'explainer_tf_mobilenetv2'


def _clone_if_missing(destination, clone_args):
    """Run ``git clone <clone_args>`` unless *destination* already exists.

    The clone is best-effort: failures are printed, never raised, so the
    script keeps going exactly as it did before.

    Args:
        destination: Path whose existence means "already cloned".
        clone_args: List of arguments placed after ``git clone``.
    """
    if os.path.exists(destination):
        return
    try:
        # An argv list (no shell) keeps URLs/env values from being
        # interpreted as shell syntax.
        subprocess.check_output(
            ["git", "clone", *clone_args], stderr=subprocess.STDOUT
        )
    except subprocess.CalledProcessError as e:
        print(f'Error cloning repository: {e.output.decode(errors="replace")}')
    except OSError as e:
        # Without a shell, a missing ``git`` binary surfaces as OSError.
        print(f'Error cloning repository: {e}')
    else:
        print('Repository cloned successfully.')


_clone_if_missing(model_folder, [MODEL_REPO_URL, model_folder])

explainer_repo_url = os.getenv("GIT_CORE")
if explainer_repo_url:
    # shlex.split keeps the word-splitting the value previously got from the
    # shell (e.g. "--depth 1 <url>") without allowing command injection.
    _clone_if_missing(EXPLAINER_FOLDER, shlex.split(explainer_repo_url))
elif not os.path.exists(EXPLAINER_FOLDER):
    print('Error cloning repository: GIT_CORE environment variable is not set.')

# NOTE(review): `explainer` is an alias for the Keras MobileNetV2 model
# constructor, not an explanation function, so `explainer(inp, model)` in
# explainer_wrapper cannot work as intended. The real import presumably lives
# in the cloned `explainer_tf_mobilenetv2` repository -- confirm and replace.
from tensorflow.keras.applications.mobilenet_v2 import MobileNetV2 as explainer
# import mobilenetv2.explainer
# from tensorflow.keras.applications.mobilenet_v2.explainer import Explainer as explainer

# %%
with open(f'{model_folder}/labels.txt', 'r', encoding='utf-8') as f:
    # splitlines() does not produce a phantom empty label when the file ends
    # with a newline (split('\n') did).
    labels = f.read().splitlines()

# model = tf.saved_model.load(f'{model_folder}/last_layer.hdf5')
model = tf.keras.models.load_model(f'{model_folder}/last_layer.hdf5')
# model = tf.keras.models.load_model(f'{model_folder}/MobileNetV2_last_layer.hdf5')


# %%
def classify_image(inp):
    """Classify one image and return a confidence per label.

    Args:
        inp: Image array accepted by ``cv2.resize``; after resizing it is
            reshaped to ``(-1, 224, 224, 3)``, so three channels are
            expected (presumably RGB from the Gradio image component --
            confirm).

    Returns:
        Dict mapping each label to the model's score for it, as ``float``.
    """
    inp = cv2.resize(inp, (224, 224))
    inp = inp.reshape((-1, 224, 224, 3))
    inp = tf.keras.applications.mobilenet_v2.preprocess_input(inp)
    prediction = model.predict(inp).flatten()
    print(prediction)
    # zip() pairs labels with scores and cannot index past the end of the
    # prediction vector if the label list is longer than the model output.
    confidences = {
        label: float(score) for label, score in zip(labels, prediction)
    }
    return confidences


def explainer_wrapper(inp):
    """Call ``explainer`` with the image and the loaded model.

    See the NOTE(review) on the ``explainer`` import: as currently imported
    this is not an explanation routine.
    """
    return explainer(inp, model)


with gr.Blocks() as demo:
    with gr.Column():
        with gr.Row():
            with gr.Column():
                image = gr.inputs.Image(shape=(224, 224))
                with gr.Row():
                    classify = gr.Button("Classify")
                    interpret = gr.Button("Interpret")
            with gr.Column():
                label = gr.outputs.Label(num_top_classes=3)
                interpretation = gr.Plot(label="Interpretation")
                # interpretation = gr.outputs.Image(type="numpy", label="Interpretation")
        gr.Examples(
            ["TomatoHealthy2.jpg", "TomatoYellowCurlVirus3.jpg", "AppleCedarRust3.jpg"],
            inputs=[image],
        )
    classify.click(classify_image, image, label, queue=True)
    interpret.click(explainer_wrapper, image, interpretation, queue=True)

demo.queue(concurrency_count=3).launch()

# %%
# gr.Interface(fn=classify_image,
#              inputs=gr.Image(shape=(224, 224)),
#              outputs=gr.Label(num_top_classes=3),
#              examples=["TomatoHealthy2.jpg", "TomatoYellowCurlVirus3.jpg", "AppleCedarRust3.jpg"]).launch()