akhaliq HF Staff commited on
Commit
a739b51
·
1 Parent(s): bdcf6dc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -44
app.py CHANGED
@@ -1,66 +1,60 @@
1
- import mxnet as mx
2
- import matplotlib.pyplot as plt
3
  import numpy as np
4
- from collections import namedtuple
5
- from mxnet.gluon.data.vision import transforms
 
6
  import os
7
  import gradio as gr
8
 
9
- from PIL import Image
10
- import imageio
11
- import onnxruntime as ort
12
-
13
- def get_image(path):
14
- '''
15
- Using path to image, return the RGB load image
16
- '''
17
- img = imageio.imread(path, pilmode='RGB')
18
- return img
19
-
20
- # Pre-processing function for ImageNet models using numpy
21
- def preprocess(img):
22
- '''
23
- Preprocessing required on the images for inference with mxnet gluon
24
- The function takes loaded image and returns processed tensor
25
- '''
26
- img = np.array(Image.fromarray(img).resize((224, 224))).astype(np.float32)
27
- img[:, :, 0] -= 123.68
28
- img[:, :, 1] -= 116.779
29
- img[:, :, 2] -= 103.939
30
- img[:,:,[0,1,2]] = img[:,:,[2,1,0]]
31
- img = img.transpose((2, 0, 1))
32
- img = np.expand_dims(img, axis=0)
33
-
34
- return img
35
 
36
- mx.test_utils.download('https://s3.amazonaws.com/model-server/inputs/kitten.jpg')
37
 
38
- mx.test_utils.download('https://s3.amazonaws.com/onnx-model-zoo/synset.txt')
39
  with open('synset.txt', 'r') as f:
40
  labels = [l.rstrip() for l in f]
41
 
42
  os.system("wget https://github.com/AK391/models/raw/main/vision/classification/inception_and_googlenet/inception_v2/model/inception-v2-9.onnx")
43
 
44
- ort_session = ort.InferenceSession("inception-v2-9.onnx")
 
45
 
46
-
47
- def predict(path):
48
- img_batch = preprocess(get_image(path))
49
 
50
- outputs = ort_session.run(
51
- None,
52
- {"data_0": img_batch.astype(np.float32)},
53
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
 
55
- a = np.argsort(outputs[0].flatten())[::-1]
 
 
 
 
 
 
56
  results = {}
57
- for i in a[0:5]:
58
- results[labels[i]]=float(outputs[0][0][i])
59
  return results
 
60
 
61
 
62
  title="Inception v2"
63
  description="Inception v2 is a deep convolutional networks for classification."
64
 
65
- examples=[['catonnx.jpg']]
66
  gr.Interface(predict,gr.inputs.Image(type='filepath'),"label",title=title,description=description,examples=examples).launch(enable_queue=True,debug=True)
 
1
+ import onnx
 
2
  import numpy as np
3
+ import onnxruntime as ort
4
+ from PIL import Image
5
+ import cv2
6
  import os
7
  import gradio as gr
8
 
9
+ os.system("wget https://s3.amazonaws.com/onnx-model-zoo/synset.txt")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
 
11
 
 
12
  with open('synset.txt', 'r') as f:
13
  labels = [l.rstrip() for l in f]
14
 
15
  os.system("wget https://github.com/AK391/models/raw/main/vision/classification/inception_and_googlenet/inception_v2/model/inception-v2-9.onnx")
16
 
17
+ os.system("wget https://s3.amazonaws.com/model-server/inputs/kitten.jpg")
18
+
19
 
 
 
 
20
 
21
+ model_path = 'inception-v2-9.onnx'
22
+ model = onnx.load(model_path)
23
+ session = ort.InferenceSession(model.SerializeToString())
24
+
25
+ def get_image(path):
26
+ with Image.open(path) as img:
27
+ img = np.array(img.convert('RGB'))
28
+ return img
29
+
30
+ def preprocess(img):
31
+ img = img / 255.
32
+ img = cv2.resize(img, (256, 256))
33
+ h, w = img.shape[0], img.shape[1]
34
+ y0 = (h - 224) // 2
35
+ x0 = (w - 224) // 2
36
+ img = img[y0 : y0+224, x0 : x0+224, :]
37
+ img = (img - [0.485, 0.456, 0.406]) / [0.229, 0.224, 0.225]
38
+ img = np.transpose(img, axes=[2, 0, 1])
39
+ img = img.astype(np.float32)
40
+ img = np.expand_dims(img, axis=0)
41
+ return img
42
 
43
+ def predict(path):
44
+ img = get_image(path)
45
+ img = preprocess(img)
46
+ ort_inputs = {session.get_inputs()[0].name: img}
47
+ preds = session.run(None, ort_inputs)[0]
48
+ preds = np.squeeze(preds)
49
+ a = np.argsort(preds)[::-1]
50
  results = {}
51
+ results[labels[a[0]]] = float(preds[a[0]]*0.1)
 
52
  return results
53
+
54
 
55
 
56
  title="Inception v2"
57
  description="Inception v2 is a deep convolutional networks for classification."
58
 
59
+ examples=[['kitten.jpg']]
60
  gr.Interface(predict,gr.inputs.Image(type='filepath'),"label",title=title,description=description,examples=examples).launch(enable_queue=True,debug=True)