DanielPFlorian commited on
Commit
5dc340d
·
1 Parent(s): 7e7870c

forgot comma

Browse files
Files changed (1) hide show
  1. app.py +15 -6
app.py CHANGED
@@ -19,6 +19,7 @@ HF_TOKEN = os.environ.get("HF_TOKEN")
19
 
20
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
21
 
 
22
  def load_checkpoint(filepath):
23
  """Builds PyTorch Model from saved model
24
  Returns built model
@@ -88,9 +89,11 @@ class Network(nn.Module):
88
  x = self.output(x)
89
 
90
  return F.log_softmax(x, dim=1)
91
-
 
92
  model = load_checkpoint("flower_inference_model.pth")
93
-
 
94
  def process_image(img_path):
95
  """Scales, crops, and normalizes a PIL image for a PyTorch model,
96
  returns a Numpy array
@@ -154,6 +157,11 @@ def process_image(img_path):
154
 
155
  return np_image
156
 
 
 
 
 
 
157
  def predict(image_path, model=model, category_names=cat_to_name, topk=5):
158
  """Predict the class (or classes) of an image using a trained deep learning model.
159
  Arguments
@@ -216,10 +224,11 @@ def predict(image_path, model=model, category_names=cat_to_name, topk=5):
216
  plt.show()
217
 
218
  return fig
219
-
 
220
  gr.Interface(
221
  predict,
222
  inputs=gr.inputs.Image(label="Upload a flower image", type="filepath"),
223
- outputs=gr.outputs.Label(num_top_classes=5),
224
- title="What kind of flower is this?"
225
- ).launch()
 
19
 
20
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
21
 
22
+
23
  def load_checkpoint(filepath):
24
  """Builds PyTorch Model from saved model
25
  Returns built model
 
89
  x = self.output(x)
90
 
91
  return F.log_softmax(x, dim=1)
92
+
93
+
94
  model = load_checkpoint("flower_inference_model.pth")
95
+
96
+
97
  def process_image(img_path):
98
  """Scales, crops, and normalizes a PIL image for a PyTorch model,
99
  returns a Numpy array
 
157
 
158
  return np_image
159
 
160
+
161
+ with open("cat_to_name.json", "r") as f:
162
+ cat_to_name = json.load(f)
163
+
164
+
165
  def predict(image_path, model=model, category_names=cat_to_name, topk=5):
166
  """Predict the class (or classes) of an image using a trained deep learning model.
167
  Arguments
 
224
  plt.show()
225
 
226
  return fig
227
+
228
+
229
  gr.Interface(
230
  predict,
231
  inputs=gr.inputs.Image(label="Upload a flower image", type="filepath"),
232
+ outputs=gr.Plot(label="Plot"),
233
+ title="What kind of flower is this?",
234
+ ).launch()