DanielPFlorian commited on
Commit
80b6d82
·
1 Parent(s): e3c1c26

change image to pil format

Browse files
Files changed (1) hide show
  1. app.py +9 -9
app.py CHANGED
@@ -94,7 +94,7 @@ class Network(nn.Module):
94
  model = load_checkpoint("flower_inference_model.pth")
95
 
96
 
97
- def process_image(img_path):
98
  """Scales, crops, and normalizes a PIL image for a PyTorch model,
99
  returns a Numpy array
100
 
@@ -102,7 +102,7 @@ def process_image(img_path):
102
  ---------
103
  image: path of the image to be processed
104
  """
105
- inp = Image.open(img_path)
106
  inp.load()
107
  img_exif = inp.getexif()
108
 
@@ -117,12 +117,12 @@ def process_image(img_path):
117
  image_path = re.sub(r"\W+", "_", image_path)
118
 
119
  # Join to directory path
120
- inf_image = os.path.join("inference", image_path)
121
 
122
  # Use repo for inference
123
- inp.save(inf_image, format="JPEG", quality=95, keep=True, exif=img_exif)
124
  HfApi().upload_file(
125
- path_or_fileobj=inf_image + ".JPG",
126
  path_in_repo=image_path,
127
  repo_id="DanielPFlorian/flower-image-classifier",
128
  repo_type="dataset",
@@ -164,7 +164,7 @@ with open("cat_to_name.json", "r") as f:
164
  cat_to_name = json.load(f)
165
 
166
 
167
- def predict(image_path, model=model, category_names=cat_to_name, topk=5):
168
  """Predict the class (or classes) of an image using a trained deep learning model.
169
  Arguments
170
  ---------
@@ -173,7 +173,7 @@ def predict(image_path, model=model, category_names=cat_to_name, topk=5):
173
  topk: number of top predicted classes to return
174
  """
175
  # Process image function
176
- image = process_image(image_path)
177
 
178
  # Convert image to float tensor with batch size of 1
179
  image = torch.as_tensor(image).view((1, 3, 224, 224)).float()
@@ -210,7 +210,7 @@ def predict(image_path, model=model, category_names=cat_to_name, topk=5):
210
 
211
  # Plot Functionality
212
 
213
- image = Image.open(image_path)
214
  fig, (ax1, ax2) = plt.subplots(ncols=2)
215
  ax1.imshow(image)
216
  ax1.axis("off")
@@ -231,7 +231,7 @@ def predict(image_path, model=model, category_names=cat_to_name, topk=5):
231
  # Gradio Interface
232
  gr.Interface(
233
  predict,
234
- inputs=gr.inputs.Image(label="Upload a flower image", type="filepath"),
235
  outputs=gr.Plot(label="Plot"),
236
  title="What kind of flower is this?",
237
  ).launch()
 
94
  model = load_checkpoint("flower_inference_model.pth")
95
 
96
 
97
+ def process_image(pil_image):
98
  """Scales, crops, and normalizes a PIL image for a PyTorch model,
99
  returns a Numpy array
100
 
 
102
  ---------
103
  image: path of the image to be processed
104
  """
105
+ inp = Image.open(pil_image)
106
  inp.load()
107
  img_exif = inp.getexif()
108
 
 
117
  image_path = re.sub(r"\W+", "_", image_path)
118
 
119
  # Join to directory path
120
+ inf_image = os.path.join("inference", image_path + ".jpg")
121
 
122
  # Use repo for inference
123
+ inp.save(inf_image, quality=95, keep=True, exif=img_exif)
124
  HfApi().upload_file(
125
+ path_or_fileobj=inf_image,
126
  path_in_repo=image_path,
127
  repo_id="DanielPFlorian/flower-image-classifier",
128
  repo_type="dataset",
 
164
  cat_to_name = json.load(f)
165
 
166
 
167
+ def predict(pil_image, model=model, category_names=cat_to_name, topk=5):
168
  """Predict the class (or classes) of an image using a trained deep learning model.
169
  Arguments
170
  ---------
 
173
  topk: number of top predicted classes to return
174
  """
175
  # Process image function
176
+ image = process_image(pil_image)
177
 
178
  # Convert image to float tensor with batch size of 1
179
  image = torch.as_tensor(image).view((1, 3, 224, 224)).float()
 
210
 
211
  # Plot Functionality
212
 
213
+ image = Image.open(pil_image)
214
  fig, (ax1, ax2) = plt.subplots(ncols=2)
215
  ax1.imshow(image)
216
  ax1.axis("off")
 
231
  # Gradio Interface
232
  gr.Interface(
233
  predict,
234
+ inputs=gr.inputs.Image(label="Upload a flower image", type="pil"),
235
  outputs=gr.Plot(label="Plot"),
236
  title="What kind of flower is this?",
237
  ).launch()