Spaces:
Sleeping
Sleeping
Commit
·
5dc340d
1
Parent(s):
7e7870c
forgot comma
Browse files
app.py
CHANGED
@@ -19,6 +19,7 @@ HF_TOKEN = os.environ.get("HF_TOKEN")
|
|
19 |
|
20 |
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
21 |
|
|
|
22 |
def load_checkpoint(filepath):
|
23 |
"""Builds PyTorch Model from saved model
|
24 |
Returns built model
|
@@ -88,9 +89,11 @@ class Network(nn.Module):
|
|
88 |
x = self.output(x)
|
89 |
|
90 |
return F.log_softmax(x, dim=1)
|
91 |
-
|
|
|
92 |
model = load_checkpoint("flower_inference_model.pth")
|
93 |
-
|
|
|
94 |
def process_image(img_path):
|
95 |
"""Scales, crops, and normalizes a PIL image for a PyTorch model,
|
96 |
returns a Numpy array
|
@@ -154,6 +157,11 @@ def process_image(img_path):
|
|
154 |
|
155 |
return np_image
|
156 |
|
|
|
|
|
|
|
|
|
|
|
157 |
def predict(image_path, model=model, category_names=cat_to_name, topk=5):
|
158 |
"""Predict the class (or classes) of an image using a trained deep learning model.
|
159 |
Arguments
|
@@ -216,10 +224,11 @@ def predict(image_path, model=model, category_names=cat_to_name, topk=5):
|
|
216 |
plt.show()
|
217 |
|
218 |
return fig
|
219 |
-
|
|
|
220 |
gr.Interface(
|
221 |
predict,
|
222 |
inputs=gr.inputs.Image(label="Upload a flower image", type="filepath"),
|
223 |
-
outputs=gr.
|
224 |
-
title="What kind of flower is this?"
|
225 |
-
).launch()
|
|
|
19 |
|
20 |
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
21 |
|
22 |
+
|
23 |
def load_checkpoint(filepath):
|
24 |
"""Builds PyTorch Model from saved model
|
25 |
Returns built model
|
|
|
89 |
x = self.output(x)
|
90 |
|
91 |
return F.log_softmax(x, dim=1)
|
92 |
+
|
93 |
+
|
94 |
model = load_checkpoint("flower_inference_model.pth")
|
95 |
+
|
96 |
+
|
97 |
def process_image(img_path):
|
98 |
"""Scales, crops, and normalizes a PIL image for a PyTorch model,
|
99 |
returns a Numpy array
|
|
|
157 |
|
158 |
return np_image
|
159 |
|
160 |
+
|
161 |
+
with open("cat_to_name.json", "r") as f:
|
162 |
+
cat_to_name = json.load(f)
|
163 |
+
|
164 |
+
|
165 |
def predict(image_path, model=model, category_names=cat_to_name, topk=5):
|
166 |
"""Predict the class (or classes) of an image using a trained deep learning model.
|
167 |
Arguments
|
|
|
224 |
plt.show()
|
225 |
|
226 |
return fig
|
227 |
+
|
228 |
+
|
229 |
gr.Interface(
|
230 |
predict,
|
231 |
inputs=gr.inputs.Image(label="Upload a flower image", type="filepath"),
|
232 |
+
outputs=gr.Plot(label="Plot"),
|
233 |
+
title="What kind of flower is this?",
|
234 |
+
).launch()
|