ItsNotRohit commited on
Commit
8458aec
·
1 Parent(s): 9d8d906

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -7
app.py CHANGED
@@ -23,6 +23,7 @@ model.load_state_dict(
23
  )
24
  )
25
 
 
26
  def predict(img) -> Tuple[Dict, float]:
27
  start_time = timer()
28
 
@@ -40,17 +41,12 @@ def predict(img) -> Tuple[Dict, float]:
40
  outputs = model(image).logits
41
  predicted_probs = torch.softmax(outputs, dim=1)
42
 
43
- # Get the class name and its associated probability
44
- predicted_class_idx = torch.argmax(predicted_probs).item()
45
- predicted_class = class_names[predicted_class_idx]
46
- predicted_probability = predicted_probs[0][predicted_class_idx].item()
47
 
48
  # Calculate the prediction time
49
  pred_time = round(timer() - start_time, 5)
50
 
51
- # Create a prediction label and prediction probability dictionary for the predicted class
52
- pred_labels_and_probs = {predicted_class: predicted_probability}
53
-
54
  return pred_labels_and_probs, pred_time
55
 
56
 
 
23
  )
24
  )
25
 
26
+
27
  def predict(img) -> Tuple[Dict, float]:
28
  start_time = timer()
29
 
 
41
  outputs = model(image).logits
42
  predicted_probs = torch.softmax(outputs, dim=1)
43
 
44
+ # Create a prediction label and prediction probability dictionary for each prediction class
45
+ pred_labels_and_probs = {class_names[i]: float(predicted_probs[0][i]) for i in range(len(class_names))}
 
 
46
 
47
  # Calculate the prediction time
48
  pred_time = round(timer() - start_time, 5)
49
 
 
 
 
50
  return pred_labels_and_probs, pred_time
51
 
52