Spaces:
Sleeping
Sleeping
Commit
·
80b6d82
1
Parent(s):
e3c1c26
change image to pil format
Browse files
app.py
CHANGED
@@ -94,7 +94,7 @@ class Network(nn.Module):
|
|
94 |
model = load_checkpoint("flower_inference_model.pth")
|
95 |
|
96 |
|
97 |
-
def process_image(
|
98 |
"""Scales, crops, and normalizes a PIL image for a PyTorch model,
|
99 |
returns a Numpy array
|
100 |
|
@@ -102,7 +102,7 @@ def process_image(img_path):
|
|
102 |
---------
|
103 |
image: path of the image to be processed
|
104 |
"""
|
105 |
-
inp = Image.open(
|
106 |
inp.load()
|
107 |
img_exif = inp.getexif()
|
108 |
|
@@ -117,12 +117,12 @@ def process_image(img_path):
|
|
117 |
image_path = re.sub(r"\W+", "_", image_path)
|
118 |
|
119 |
# Join to directory path
|
120 |
-
inf_image = os.path.join("inference", image_path)
|
121 |
|
122 |
# Use repo for inference
|
123 |
-
inp.save(inf_image,
|
124 |
HfApi().upload_file(
|
125 |
-
path_or_fileobj=inf_image
|
126 |
path_in_repo=image_path,
|
127 |
repo_id="DanielPFlorian/flower-image-classifier",
|
128 |
repo_type="dataset",
|
@@ -164,7 +164,7 @@ with open("cat_to_name.json", "r") as f:
|
|
164 |
cat_to_name = json.load(f)
|
165 |
|
166 |
|
167 |
-
def predict(
|
168 |
"""Predict the class (or classes) of an image using a trained deep learning model.
|
169 |
Arguments
|
170 |
---------
|
@@ -173,7 +173,7 @@ def predict(image_path, model=model, category_names=cat_to_name, topk=5):
|
|
173 |
topk: number of top predicted classes to return
|
174 |
"""
|
175 |
# Process image function
|
176 |
-
image = process_image(
|
177 |
|
178 |
# Convert image to float tensor with batch size of 1
|
179 |
image = torch.as_tensor(image).view((1, 3, 224, 224)).float()
|
@@ -210,7 +210,7 @@ def predict(image_path, model=model, category_names=cat_to_name, topk=5):
|
|
210 |
|
211 |
# Plot Functionality
|
212 |
|
213 |
-
image = Image.open(
|
214 |
fig, (ax1, ax2) = plt.subplots(ncols=2)
|
215 |
ax1.imshow(image)
|
216 |
ax1.axis("off")
|
@@ -231,7 +231,7 @@ def predict(image_path, model=model, category_names=cat_to_name, topk=5):
|
|
231 |
# Gradio Interface
|
232 |
gr.Interface(
|
233 |
predict,
|
234 |
-
inputs=gr.inputs.Image(label="Upload a flower image", type="
|
235 |
outputs=gr.Plot(label="Plot"),
|
236 |
title="What kind of flower is this?",
|
237 |
).launch()
|
|
|
94 |
model = load_checkpoint("flower_inference_model.pth")
|
95 |
|
96 |
|
97 |
+
def process_image(pil_image):
|
98 |
"""Scales, crops, and normalizes a PIL image for a PyTorch model,
|
99 |
returns a Numpy array
|
100 |
|
|
|
102 |
---------
|
103 |
image: path of the image to be processed
|
104 |
"""
|
105 |
+
inp = Image.open(pil_image)
|
106 |
inp.load()
|
107 |
img_exif = inp.getexif()
|
108 |
|
|
|
117 |
image_path = re.sub(r"\W+", "_", image_path)
|
118 |
|
119 |
# Join to directory path
|
120 |
+
inf_image = os.path.join("inference", image_path + ".jpg")
|
121 |
|
122 |
# Use repo for inference
|
123 |
+
inp.save(inf_image, quality=95, keep=True, exif=img_exif)
|
124 |
HfApi().upload_file(
|
125 |
+
path_or_fileobj=inf_image,
|
126 |
path_in_repo=image_path,
|
127 |
repo_id="DanielPFlorian/flower-image-classifier",
|
128 |
repo_type="dataset",
|
|
|
164 |
cat_to_name = json.load(f)
|
165 |
|
166 |
|
167 |
+
def predict(pil_image, model=model, category_names=cat_to_name, topk=5):
|
168 |
"""Predict the class (or classes) of an image using a trained deep learning model.
|
169 |
Arguments
|
170 |
---------
|
|
|
173 |
topk: number of top predicted classes to return
|
174 |
"""
|
175 |
# Process image function
|
176 |
+
image = process_image(pil_image)
|
177 |
|
178 |
# Convert image to float tensor with batch size of 1
|
179 |
image = torch.as_tensor(image).view((1, 3, 224, 224)).float()
|
|
|
210 |
|
211 |
# Plot Functionality
|
212 |
|
213 |
+
image = Image.open(pil_image)
|
214 |
fig, (ax1, ax2) = plt.subplots(ncols=2)
|
215 |
ax1.imshow(image)
|
216 |
ax1.axis("off")
|
|
|
231 |
# Gradio Interface
|
232 |
gr.Interface(
|
233 |
predict,
|
234 |
+
inputs=gr.inputs.Image(label="Upload a flower image", type="pil"),
|
235 |
outputs=gr.Plot(label="Plot"),
|
236 |
title="What kind of flower is this?",
|
237 |
).launch()
|