akhaliq HF Staff commited on
Commit
559c82e
·
1 Parent(s): 716d27f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -18
app.py CHANGED
@@ -7,7 +7,7 @@ import os
7
  import gradio as gr
8
 
9
  import mxnet
10
- from mxnet.gluon.data.vision import transforms
11
 
12
  os.system("wget https://s3.amazonaws.com/onnx-model-zoo/synset.txt")
13
 
@@ -30,28 +30,21 @@ def get_image(path):
30
  img = np.array(img.convert('RGB'))
31
  return img
32
 
33
- def preprocess(img):
34
- '''
35
- Preprocessing required on the images for inference with mxnet gluon
36
- The function takes path to an image and returns processed tensor
37
- '''
38
- transform_fn = transforms.Compose([
39
- transforms.Resize(224),
40
- transforms.CenterCrop(224),
41
- transforms.ToTensor(),
42
- transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
43
- ])
44
- img = mxnet.ndarray.array(img)
45
- img = transform_fn(img)
46
- img = img.expand_dims(axis=0) # batchify
47
 
48
- return img.asnumpy()
49
 
50
 
51
  def predict(path):
52
  img = get_image(path)
53
- img = preprocess(img)
54
- ort_inputs = {session.get_inputs()[0].name: img}
 
55
  preds = session.run(None, ort_inputs)[0]
56
  preds = np.squeeze(preds)
57
  a = np.argsort(preds)
 
7
  import gradio as gr
8
 
9
  import mxnet
10
+ from torchvision import transforms
11
 
12
  os.system("wget https://s3.amazonaws.com/onnx-model-zoo/synset.txt")
13
 
 
30
  img = np.array(img.convert('RGB'))
31
  return img
32
 
33
+ preprocess = transforms.Compose([
34
+ transforms.Resize(256),
35
+ transforms.CenterCrop(224),
36
+ transforms.ToTensor(),
37
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
38
+ ])
39
+
 
 
 
 
 
 
 
40
 
 
41
 
42
 
43
  def predict(path):
44
  img = get_image(path)
45
+ input_tensor = preprocess(img)
46
+ img = input_tensor.unsqueeze(0)
47
+ ort_inputs = {session.get_inputs()[0].name: img.cpu().detach().numpy()}
48
  preds = session.run(None, ort_inputs)[0]
49
  preds = np.squeeze(preds)
50
  a = np.argsort(preds)