Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -7,7 +7,7 @@ import os
|
|
7 |
import gradio as gr
|
8 |
|
9 |
import mxnet
|
10 |
-
from
|
11 |
|
12 |
os.system("wget https://s3.amazonaws.com/onnx-model-zoo/synset.txt")
|
13 |
|
@@ -30,28 +30,21 @@ def get_image(path):
|
|
30 |
img = np.array(img.convert('RGB'))
|
31 |
return img
|
32 |
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
transforms.CenterCrop(224),
|
41 |
-
transforms.ToTensor(),
|
42 |
-
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
43 |
-
])
|
44 |
-
img = mxnet.ndarray.array(img)
|
45 |
-
img = transform_fn(img)
|
46 |
-
img = img.expand_dims(axis=0) # batchify
|
47 |
|
48 |
-
return img.asnumpy()
|
49 |
|
50 |
|
51 |
def predict(path):
|
52 |
img = get_image(path)
|
53 |
-
|
54 |
-
|
|
|
55 |
preds = session.run(None, ort_inputs)[0]
|
56 |
preds = np.squeeze(preds)
|
57 |
a = np.argsort(preds)
|
|
|
7 |
import gradio as gr
|
8 |
|
9 |
import mxnet
|
10 |
+
from torchvision import transforms
|
11 |
|
12 |
os.system("wget https://s3.amazonaws.com/onnx-model-zoo/synset.txt")
|
13 |
|
|
|
30 |
img = np.array(img.convert('RGB'))
|
31 |
return img
|
32 |
|
33 |
+
preprocess = transforms.Compose([
|
34 |
+
transforms.Resize(256),
|
35 |
+
transforms.CenterCrop(224),
|
36 |
+
transforms.ToTensor(),
|
37 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
38 |
+
])
|
39 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
40 |
|
|
|
41 |
|
42 |
|
43 |
def predict(path):
|
44 |
img = get_image(path)
|
45 |
+
input_tensor = preprocess(img)
|
46 |
+
img = input_tensor.unsqueeze(0)
|
47 |
+
ort_inputs = {session.get_inputs()[0].name: img.cpu().detach().numpy()}
|
48 |
preds = session.run(None, ort_inputs)[0]
|
49 |
preds = np.squeeze(preds)
|
50 |
a = np.argsort(preds)
|