akhaliq HF Staff commited on
Commit
1c25c76
·
1 Parent(s): d76175b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -31
app.py CHANGED
@@ -1,50 +1,60 @@
1
- import mxnet as mx
2
- import matplotlib.pyplot as plt
3
  import numpy as np
4
- from collections import namedtuple
5
- from mxnet.gluon.data.vision import transforms
 
6
  import os
7
  import gradio as gr
8
 
9
- from PIL import Image
10
- import imageio
11
- import onnxruntime as ort
12
 
13
- mx.test_utils.download('https://s3.amazonaws.com/model-server/inputs/kitten.jpg')
14
 
15
- mx.test_utils.download('https://s3.amazonaws.com/onnx-model-zoo/synset.txt')
16
  with open('synset.txt', 'r') as f:
17
  labels = [l.rstrip() for l in f]
18
 
19
- os.system("wget https://github.com/AK391/models/raw/main/vision/classification/shufflenet/model/shufflenet-v2-10.onnx")
20
 
21
- ort_session = ort.InferenceSession("shufflenet-v2-10.onnx")
22
 
 
 
 
 
 
 
 
 
 
 
23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  def predict(path):
25
- input_image = Image.open(path)
26
- preprocess = transforms.Compose([
27
- transforms.Resize(256),
28
- transforms.CenterCrop(224),
29
- transforms.ToTensor(),
30
- transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
31
- ])
32
- input_tensor = preprocess(input_image)
33
- input_batch = input_tensor.unsqueeze(0)
34
- outputs = ort_session.run(
35
- None,
36
- {"input": input_batch.astype(np.float32)},
37
- )
38
-
39
- a = np.argsort(outputs[0].flatten())
40
  results = {}
41
- for i in a[0:5]:
42
- results[labels[i]]=float(outputs[0][0][i])
43
  return results
44
 
45
 
46
- title="GoogleNet"
47
- description="GoogLeNet is the name of a convolutional neural network for classification, which competed in the ImageNet Large Scale Visual Recognition Challenge in 2014."
48
 
49
- examples=[['catonnx.jpg']]
50
  gr.Interface(predict,gr.inputs.Image(type='filepath'),"label",title=title,description=description,examples=examples).launch(enable_queue=True,debug=True)
 
1
+ import onnx
 
2
  import numpy as np
3
+ import onnxruntime as ort
4
+ from PIL import Image
5
+ import cv2
6
  import os
7
  import gradio as gr
8
 
9
+ os.system("wget https://s3.amazonaws.com/onnx-model-zoo/synset.txt")
 
 
10
 
 
11
 
 
12
  with open('synset.txt', 'r') as f:
13
  labels = [l.rstrip() for l in f]
14
 
15
+ os.system("wget https://github.com/onnx/models/raw/main/vision/classification/shufflenet/model/shufflenet-v2-12-int8.onnx")
16
 
17
+ os.system("wget https://s3.amazonaws.com/model-server/inputs/kitten.jpg")
18
 
19
+
20
+
21
+ model_path = 'shufflenet-v2-12-int8.onnx'
22
+ model = onnx.load(model_path)
23
+ session = ort.InferenceSession(model.SerializeToString())
24
+
25
+ def get_image(path):
26
+ with Image.open(path) as img:
27
+ img = np.array(img.convert('RGB'))
28
+ return img
29
 
30
+ def preprocess(img):
31
+ img = img / 255.
32
+ img = cv2.resize(img, (256, 256))
33
+ h, w = img.shape[0], img.shape[1]
34
+ y0 = (h - 224) // 2
35
+ x0 = (w - 224) // 2
36
+ img = img[y0 : y0+224, x0 : x0+224, :]
37
+ img = (img - [0.485, 0.456, 0.406]) / [0.229, 0.224, 0.225]
38
+ img = np.transpose(img, axes=[2, 0, 1])
39
+ img = img.astype(np.float32)
40
+ img = np.expand_dims(img, axis=0)
41
+ return img
42
+
43
  def predict(path):
44
+ img = get_image(path)
45
+ img = preprocess(img)
46
+ ort_inputs = {session.get_inputs()[0].name: img}
47
+ preds = session.run(None, ort_inputs)[0]
48
+ preds = np.squeeze(preds)
49
+ a = np.argsort(preds)
 
 
 
 
 
 
 
 
 
50
  results = {}
51
+ for i in a[0:5]:
52
+ results[labels[a[i]]] = float(preds[a[i]])
53
  return results
54
 
55
 
56
+ title="ShuffleNet-v2"
57
+ description="ShuffleNet is a deep convolutional network for image classification. ShuffleNetV2 is an improved architecture that is the state-of-the-art in terms of speed and accuracy tradeoff used for image classification."
58
 
59
+ examples=[['kitten.jpg']]
60
  gr.Interface(predict,gr.inputs.Image(type='filepath'),"label",title=title,description=description,examples=examples).launch(enable_queue=True,debug=True)