akhaliq HF Staff commited on
Commit
d76175b
·
1 Parent(s): 3f8df5b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -26
app.py CHANGED
@@ -10,29 +10,6 @@ from PIL import Image
10
  import imageio
11
  import onnxruntime as ort
12
 
13
- def get_image(path):
14
- '''
15
- Using path to image, return the RGB load image
16
- '''
17
- img = imageio.imread(path, pilmode='RGB')
18
- return img
19
-
20
- # Pre-processing function for ImageNet models using numpy
21
- def preprocess(img):
22
- '''
23
- Preprocessing required on the images for inference with mxnet gluon
24
- The function takes loaded image and returns processed tensor
25
- '''
26
- img = np.array(Image.fromarray(img).resize((224, 224))).astype(np.float32)
27
- img[:, :, 0] -= 123.68
28
- img[:, :, 1] -= 116.779
29
- img[:, :, 2] -= 103.939
30
- img[:,:,[0,1,2]] = img[:,:,[2,1,0]]
31
- img = img.transpose((2, 0, 1))
32
- img = np.expand_dims(img, axis=0)
33
-
34
- return img
35
-
36
  mx.test_utils.download('https://s3.amazonaws.com/model-server/inputs/kitten.jpg')
37
 
38
  mx.test_utils.download('https://s3.amazonaws.com/onnx-model-zoo/synset.txt')
@@ -45,11 +22,18 @@ ort_session = ort.InferenceSession("shufflenet-v2-10.onnx")
45
 
46
 
47
  def predict(path):
48
- img_batch = preprocess(get_image(path))
49
-
 
 
 
 
 
 
 
50
  outputs = ort_session.run(
51
  None,
52
- {"input": img_batch.astype(np.float32)},
53
  )
54
 
55
  a = np.argsort(outputs[0].flatten())
 
10
  import imageio
11
  import onnxruntime as ort
12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  mx.test_utils.download('https://s3.amazonaws.com/model-server/inputs/kitten.jpg')
14
 
15
  mx.test_utils.download('https://s3.amazonaws.com/onnx-model-zoo/synset.txt')
 
22
 
23
 
24
  def predict(path):
25
+ input_image = Image.open(path)
26
+ preprocess = transforms.Compose([
27
+ transforms.Resize(256),
28
+ transforms.CenterCrop(224),
29
+ transforms.ToTensor(),
30
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
31
+ ])
32
+ input_tensor = preprocess(input_image)
33
+ input_batch = input_tensor.unsqueeze(0)
34
  outputs = ort_session.run(
35
  None,
36
+ {"input": input_batch.astype(np.float32)},
37
  )
38
 
39
  a = np.argsort(outputs[0].flatten())