akhaliq HF Staff commited on
Commit
608aad9
·
1 Parent(s): cfdbae0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -7
app.py CHANGED
@@ -25,10 +25,7 @@ model_path = 'shufflenet-v2-10.onnx'
25
  model = onnx.load(model_path)
26
  session = ort.InferenceSession(model.SerializeToString())
27
 
28
- def get_image(path):
29
- with Image.open(path) as img:
30
- img = np.array(img.convert('RGB'))
31
- return img
32
 
33
  preprocess = transforms.Compose([
34
  transforms.Resize(256),
@@ -40,8 +37,7 @@ preprocess = transforms.Compose([
40
 
41
 
42
 
43
- def predict(path):
44
- img = get_image(path)
45
  input_tensor = preprocess(img)
46
  img = input_tensor.unsqueeze(0)
47
  ort_inputs = {session.get_inputs()[0].name: img.cpu().detach().numpy()}
@@ -58,4 +54,4 @@ title="ShuffleNet-v2"
58
  description="ShuffleNet is a deep convolutional network for image classification. ShuffleNetV2 is an improved architecture that is the state-of-the-art in terms of speed and accuracy tradeoff used for image classification."
59
 
60
  examples=[['kitten.jpg']]
61
- gr.Interface(predict,gr.inputs.Image(type='filepath'),"label",title=title,description=description,examples=examples).launch(enable_queue=True,debug=True)
 
25
  model = onnx.load(model_path)
26
  session = ort.InferenceSession(model.SerializeToString())
27
 
28
+
 
 
 
29
 
30
  preprocess = transforms.Compose([
31
  transforms.Resize(256),
 
37
 
38
 
39
 
40
+ def predict(img):
 
41
  input_tensor = preprocess(img)
42
  img = input_tensor.unsqueeze(0)
43
  ort_inputs = {session.get_inputs()[0].name: img.cpu().detach().numpy()}
 
54
  description="ShuffleNet is a deep convolutional network for image classification. ShuffleNetV2 is an improved architecture that is the state-of-the-art in terms of speed and accuracy tradeoff used for image classification."
55
 
56
  examples=[['kitten.jpg']]
57
+ gr.Interface(predict,gr.inputs.Image(type='pil'),"label",title=title,description=description,examples=examples).launch(enable_queue=True,debug=True)