akhaliq HF Staff commited on
Commit
4866e2c
·
1 Parent(s): fd8271e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -31
app.py CHANGED
@@ -1,50 +1,57 @@
1
- import mxnet as mx
2
- import matplotlib.pyplot as plt
3
  import numpy as np
4
- from collections import namedtuple
5
- from mxnet.gluon.data.vision import transforms
6
- import os
7
- import gradio as gr
8
-
9
- from PIL import Image
10
- import imageio
11
  import onnxruntime as ort
12
- from torchvision import transforms
13
-
14
 
15
- preprocess = transforms.Compose([
16
- transforms.Resize(256),
17
- transforms.CenterCrop(224),
18
- transforms.ToTensor(),
19
- transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
20
- ])
21
 
22
- mx.test_utils.download('https://s3.amazonaws.com/model-server/inputs/kitten.jpg')
23
 
24
- mx.test_utils.download('https://s3.amazonaws.com/onnx-model-zoo/synset.txt')
25
  with open('synset.txt', 'r') as f:
26
  labels = [l.rstrip() for l in f]
27
 
28
  os.system("wget https://github.com/AK391/models/raw/main/vision/classification/densenet-121/model/densenet-9.onnx")
29
 
30
- ort_session = ort.InferenceSession("densenet-9.onnx")
31
 
32
 
33
- def predict(pil):
34
- input_tensor = preprocess(pil)
35
- img_batch = input_tensor.unsqueeze(0) # create a mini-batch as expected by the model
36
- img_batch_np = img_batch.cpu().detach().numpy()
37
 
38
- outputs = ort_session.run(
39
- None,
40
- {"data_0": img_batch_np.astype(np.float32)},
41
- )
42
 
43
- a = np.argsort(-outputs[0].flatten())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  results = {}
45
- for i in a[0:5]:
46
- results[labels[i]]=float(outputs[0][0][i])
47
  return results
 
48
 
49
 
50
  title="DenseNet-121"
 
1
+ import onnx
 
2
  import numpy as np
 
 
 
 
 
 
 
3
  import onnxruntime as ort
4
+ from PIL import Image
5
+ import cv2
6
 
7
+ os.system("wget https://s3.amazonaws.com/onnx-model-zoo/synset.txt")
 
 
 
 
 
8
 
 
9
 
 
10
  with open('synset.txt', 'r') as f:
11
  labels = [l.rstrip() for l in f]
12
 
13
  os.system("wget https://github.com/AK391/models/raw/main/vision/classification/densenet-121/model/densenet-9.onnx")
14
 
15
+ os.system("wget https://s3.amazonaws.com/model-server/inputs/kitten.jpg")
16
 
17
 
 
 
 
 
18
 
19
+ model_path = 'resnet50-v1-12.onnx'
20
+ model = onnx.load(model_path)
21
+ session = ort.InferenceSession(model.SerializeToString())
 
22
 
23
+ def get_image(path, show=False):
24
+ with Image.open(path) as img:
25
+ img = np.array(img.convert('RGB'))
26
+ if show:
27
+ plt.imshow(img)
28
+ plt.axis('off')
29
+ return img
30
+
31
+ def preprocess(img):
32
+ img = img / 255.
33
+ img = cv2.resize(img, (256, 256))
34
+ h, w = img.shape[0], img.shape[1]
35
+ y0 = (h - 224) // 2
36
+ x0 = (w - 224) // 2
37
+ img = img[y0 : y0+224, x0 : x0+224, :]
38
+ img = (img - [0.485, 0.456, 0.406]) / [0.229, 0.224, 0.225]
39
+ img = np.transpose(img, axes=[2, 0, 1])
40
+ img = img.astype(np.float32)
41
+ img = np.expand_dims(img, axis=0)
42
+ return img
43
+
44
+ def predict(path):
45
+ img = get_image(path, show=True)
46
+ img = preprocess(img)
47
+ ort_inputs = {session.get_inputs()[0].name: img}
48
+ preds = session.run(None, ort_inputs)[0]
49
+ preds = np.squeeze(preds)
50
+ a = np.argsort(preds)[::-1]
51
  results = {}
52
+ results[labels[a[0]]] = preds[a[0]]
 
53
  return results
54
+
55
 
56
 
57
  title="DenseNet-121"