akhaliq HF Staff commited on
Commit
e1a7177
·
1 Parent(s): 24f0936

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -24
app.py CHANGED
@@ -9,29 +9,15 @@ import gradio as gr
9
  from PIL import Image
10
  import imageio
11
  import onnxruntime as ort
 
12
 
13
- def get_image(path):
14
- '''
15
- Using path to image, return the RGB load image
16
- '''
17
- img = imageio.imread(path, pilmode='RGB')
18
- return img
19
 
20
- # Pre-processing function for ImageNet models using numpy
21
- def preprocess(img):
22
- '''
23
- Preprocessing required on the images for inference with mxnet gluon
24
- The function takes loaded image and returns processed tensor
25
- '''
26
- img = np.array(Image.fromarray(img).resize((224, 224))).astype(np.float32)
27
- img[:, :, 0] -= 123.68
28
- img[:, :, 1] -= 116.779
29
- img[:, :, 2] -= 103.939
30
- img[:,:,[0,1,2]] = img[:,:,[2,1,0]]
31
- img = img.transpose((2, 0, 1))
32
- img = np.expand_dims(img, axis=0)
33
-
34
- return img
35
 
36
  mx.test_utils.download('https://s3.amazonaws.com/model-server/inputs/kitten.jpg')
37
 
@@ -44,8 +30,9 @@ os.system("wget https://github.com/AK391/models/raw/main/vision/classification/d
44
  ort_session = ort.InferenceSession("densenet-9.onnx")
45
 
46
 
47
- def predict(path):
48
- img_batch = preprocess(get_image(path))
 
49
 
50
  outputs = ort_session.run(
51
  None,
@@ -63,4 +50,4 @@ title="DenseNet-121"
63
  description="DenseNet-121 is a convolutional neural network for classification."
64
 
65
  examples=[['apple.jpg']]
66
- gr.Interface(predict,gr.inputs.Image(type='filepath'),"label",title=title,description=description,examples=examples).launch(enable_queue=True,debug=True)
 
9
  from PIL import Image
10
  import imageio
11
  import onnxruntime as ort
12
+ from torchvision import transforms
13
 
 
 
 
 
 
 
14
 
15
+ preprocess = transforms.Compose([
16
+ transforms.Resize(256),
17
+ transforms.CenterCrop(224),
18
+ transforms.ToTensor(),
19
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
20
+ ])
 
 
 
 
 
 
 
 
 
21
 
22
  mx.test_utils.download('https://s3.amazonaws.com/model-server/inputs/kitten.jpg')
23
 
 
30
  ort_session = ort.InferenceSession("densenet-9.onnx")
31
 
32
 
33
+ def predict(pil):
34
+ input_tensor = preprocess(pil)
35
+ img_batch = input_tensor.unsqueeze(0) # create a mini-batch as expected by the model
36
 
37
  outputs = ort_session.run(
38
  None,
 
50
  description="DenseNet-121 is a convolutional neural network for classification."
51
 
52
  examples=[['apple.jpg']]
53
+ gr.Interface(predict,gr.inputs.Image(type='pil'),"label",title=title,description=description,examples=examples).launch(enable_queue=True,debug=True)