RandomCatLover commited on
Commit
e092e8d
·
1 Parent(s): 712c19c
Files changed (1) hide show
  1. app.py +41 -6
app.py CHANGED
@@ -1,7 +1,6 @@
1
  # %%
2
  import gradio as gr
3
  import tensorflow as tf
4
- import numpy as np
5
  import cv2
6
  import os
7
 
@@ -18,12 +17,25 @@ if not os.path.exists(destination):
18
  print('Repository cloned successfully.')
19
  except subprocess.CalledProcessError as e:
20
  print(f'Error cloning repository: {e.output.decode()}')
 
 
 
 
 
 
 
 
 
 
 
 
21
  # %%
22
  with open(f'{model_folder}/labels.txt', 'r') as f:
23
  labels = f.read().split('\n')
24
 
25
  # model = tf.saved_model.load(f'{model_folder}/last_layer.hdf5')
26
- model = tf.keras.models.load_model(f'{model_folder}/last_layer.hdf5')
 
27
  # %%
28
  def classify_image(inp):
29
  inp = cv2.resize(inp, (224,224,))
@@ -34,7 +46,30 @@ def classify_image(inp):
34
  confidences = {labels[i]: float(prediction[i]) for i in range(len(labels))}
35
  return confidences
36
 
37
- gr.Interface(fn=classify_image,
38
- inputs=gr.Image(shape=(224, 224)),
39
- outputs=gr.Label(num_top_classes=3),
40
- examples=["TomatoHealthy2.jpg", "TomatoYellowCurlVirus3.jpg", "AppleCedarRust3.jpg"]).launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  # %%
2
  import gradio as gr
3
  import tensorflow as tf
 
4
  import cv2
5
  import os
6
 
 
17
  print('Repository cloned successfully.')
18
  except subprocess.CalledProcessError as e:
19
  print(f'Error cloning repository: {e.output.decode()}')
20
+
21
+ if not os.path.exists(destination):
22
+ import subprocess
23
+ repo_url = os.getenv("GIT_CORE")
24
+ command = f'git clone {repo_url}'
25
+ try:
26
+ subprocess.check_output(command, stderr=subprocess.STDOUT, shell=True)#, env=env)
27
+ print('Repository cloned successfully.')
28
+ except subprocess.CalledProcessError as e:
29
+ print(f'Error cloning repository: {e.output.decode()}')
30
+
31
+ from explainer_tf_mobilenetv2.explainer import explainer
32
  # %%
33
  with open(f'{model_folder}/labels.txt', 'r') as f:
34
  labels = f.read().split('\n')
35
 
36
  # model = tf.saved_model.load(f'{model_folder}/last_layer.hdf5')
37
+ # model = tf.keras.models.load_model(f'{model_folder}/last_layer.hdf5')
38
+ model = tf.keras.models.load_model(f'{model_folder}/MobileNetV2_last_layer.hdf5')
39
  # %%
40
  def classify_image(inp):
41
  inp = cv2.resize(inp, (224,224,))
 
46
  confidences = {labels[i]: float(prediction[i]) for i in range(len(labels))}
47
  return confidences
48
 
49
+ def explainer_wrapper(inp):
50
+ return explainer(inp, model)
51
+
52
+ with gr.Blocks() as demo:
53
+ with gr.Column():
54
+ with gr.Row():
55
+ with gr.Column():
56
+ image = gr.inputs.Image(shape=(224, 224))
57
+ with gr.Row():
58
+ classify = gr.Button("Classify")
59
+ interpret = gr.Button("Interpret")
60
+
61
+ label = gr.outputs.Label(num_top_classes=3)
62
+ interpretation = gr.Plot(label="Interpretation")
63
+ # interpretation = gr.outputs.Image(type="numpy", label="Interpretation")
64
+ gr.Examples(["TomatoHealthy2.jpg", "TomatoYellowCurlVirus3.jpg", "AppleCedarRust3.jpg"],
65
+ inputs=[image],)
66
+ classify.click(classify_image, image, label, queue=True)
67
+ interpret.click(explainer_wrapper, image, interpretation, queue=True)
68
+
69
+
70
+ demo.queue(concurrency_count=3).launch()
71
+ #%%
72
+ # gr.Interface(fn=classify_image,
73
+ # inputs=gr.Image(shape=(224, 224)),
74
+ # outputs=gr.Label(num_top_classes=3),
75
+ # examples=["TomatoHealthy2.jpg", "TomatoYellowCurlVirus3.jpg", "AppleCedarRust3.jpg"]).launch()