DanielPFlorian commited on
Commit
2b228e6
·
1 Parent(s): 0bce035

load model checkpoint

Browse files
Files changed (1) hide show
  1. app.py +3 -1
app.py CHANGED
@@ -88,6 +88,8 @@ class Network(nn.Module):
88
 
89
  return F.log_softmax(x, dim=1)
90
 
 
 
91
  def process_image(img_path):
92
  """Scales, crops, and normalizes a PIL image for a PyTorch model,
93
  returns a Numpy array
@@ -198,7 +200,7 @@ def predict(image_path, model=model, category_names=cat_to_name, topk=5):
198
  # Plot Functionality
199
 
200
  image = Image.open(image_path)
201
- fig, (ax1, ax2) = plt.subplots(ncols=2)
202
  ax1.imshow(image)
203
  ax1.axis("off")
204
  ax2.barh(np.arange(len(top_labels)), percentages)
 
88
 
89
  return F.log_softmax(x, dim=1)
90
 
91
+ model = load_checkpoint("flower_inference_model.pth")
92
+
93
  def process_image(img_path):
94
  """Scales, crops, and normalizes a PIL image for a PyTorch model,
95
  returns a Numpy array
 
200
  # Plot Functionality
201
 
202
  image = Image.open(image_path)
203
+ fig, (ax1, ax2) = plt.subplots(figsize=(8, 8), ncols=2)
204
  ax1.imshow(image)
205
  ax1.axis("off")
206
  ax2.barh(np.arange(len(top_labels)), percentages)