amjadfqs committed · Commit 159dc7c · verified · 1 Parent(s): f05318e

Update app.py

Files changed (1)
  1. app.py +12 -4
app.py CHANGED
@@ -15,13 +15,21 @@ def predict(image):
     with torch.no_grad():
         outputs = model(**inputs)
     logits = outputs.logits
+    # Calculate the confidence values
+    softmax = torch.nn.functional.softmax(logits, dim=1)
+    confidences = softmax.squeeze().tolist()
     # Get the predicted class
-    predicted_class = logits.argmax(-1).item()
-    # You may need to adjust the following line based on your class labels
+    predicted_class_index = logits.argmax(-1).item()
     class_names = ["glioma", "meningioma", "notumor", "pituitary"]
-    return class_names[predicted_class]
+    predicted_class = class_names[predicted_class_index]
+    # Create a dictionary to return both the predicted class and the confidence values
+    result = {
+        "predicted_class": predicted_class,
+        "confidences": {class_names[i]: confidences[i] for i in range(len(class_names))}
+    }
+    return result
 
 # Set up the Gradio interface
 image_cp = gr.Image(type="pil", label='Brain')
-interface = gr.Interface(fn=predict, inputs=image_cp, outputs="text")
+interface = gr.Interface(fn=predict, inputs=image_cp, outputs="json")
 interface.launch()
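
For reference, a minimal sketch of how the updated predict function could be exercised outside the Gradio interface. The model/processor setup and preprocessing live in the unchanged top of app.py (not shown in this hunk), so the sample file name below is a hypothetical placeholder, not part of the commit:

    from PIL import Image

    # Hypothetical smoke test for the updated predict(); "sample_mri.png" is a
    # placeholder path, and the model/processor are assumed to be set up earlier in app.py.
    image = Image.open("sample_mri.png").convert("RGB")
    result = predict(image)

    print(result["predicted_class"])   # one of the four class names, e.g. "glioma"
    print(result["confidences"])       # per-class softmax scores that sum to ~1.0

Returning a dict of class names to softmax scores is what makes the switch from outputs="text" to outputs="json" necessary, since the Gradio text component can only display a single string.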