Spravil commited on
Commit
c80591e
1 Parent(s): 2a855ee

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -4
app.py CHANGED
@@ -18,9 +18,8 @@ def predict(model_name, image):
18
  input_tensor = transform(image).unsqueeze(0)
19
  with torch.no_grad():
20
  output = model(input_tensor)
21
- output_np = output[0].numpy()
22
- class_ind = np.argmax(output_np)
23
- return class_names[class_ind]
24
 
25
  interface = gr.Interface(
26
  fn=predict,
@@ -28,7 +27,7 @@ interface = gr.Interface(
28
  gr.Dropdown(label="Select Model", value="hb_former_b36", choices=["hpx_former_s18", "hpx_former_s18_384", "hb_former_s18", "c_hpx_former_s18", "hpx_a_former_s18", "hb_a_former_s18", "hpx_former_b36", "hb_former_b36"]),
29
  gr.Image(type="pil", label="Upload Image")
30
  ],
31
- outputs=gr.Textbox(label="Predicted Class"),
32
  title="Image Classification",
33
  description="Choose a model and upload an image to predict the class."
34
  )
 
18
  input_tensor = transform(image).unsqueeze(0)
19
  with torch.no_grad():
20
  output = model(input_tensor)
21
+ output_np = torch.softmax(output)[0].numpy()
22
+ return {clsname: prob for clsname, prob in zip(class_names, output_np)}
 
23
 
24
  interface = gr.Interface(
25
  fn=predict,
 
27
  gr.Dropdown(label="Select Model", value="hb_former_b36", choices=["hpx_former_s18", "hpx_former_s18_384", "hb_former_s18", "c_hpx_former_s18", "hpx_a_former_s18", "hb_a_former_s18", "hpx_former_b36", "hb_former_b36"]),
28
  gr.Image(type="pil", label="Upload Image")
29
  ],
30
+ outputs=gr.Label(label="Prediction", num_top_classes=10),
31
  title="Image Classification",
32
  description="Choose a model and upload an image to predict the class."
33
  )