ItsNotRohit commited on
Commit
9d8d906
·
1 Parent(s): 9e38cc3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -6
app.py CHANGED
@@ -23,9 +23,7 @@ model.load_state_dict(
23
  )
24
  )
25
 
26
-
27
  def predict(img) -> Tuple[Dict, float]:
28
-
29
  start_time = timer()
30
 
31
  preprocess = transforms.Compose([
@@ -34,16 +32,25 @@ def predict(img) -> Tuple[Dict, float]:
34
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
35
  ])
36
 
37
- img = preprocess(img).unsqueeze(0) # Add batch dimension
38
 
 
39
  model.eval()
40
- with torch.inference_mode():
41
- pred_probs = torch.softmax(model(img), dim=1)
 
42
 
43
- pred_labels_and_probs = {class_names[i]: float(pred_probs[0][i]) for i in range(len(class_names))}
 
 
 
44
 
 
45
  pred_time = round(timer() - start_time, 5)
46
 
 
 
 
47
  return pred_labels_and_probs, pred_time
48
 
49
 
 
23
  )
24
  )
25
 
 
26
  def predict(img) -> Tuple[Dict, float]:
 
27
  start_time = timer()
28
 
29
  preprocess = transforms.Compose([
 
32
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
33
  ])
34
 
35
+ image = preprocess(img).unsqueeze(0) # Add batch dimension
36
 
37
+ # Make predictions
38
  model.eval()
39
+ with torch.no_grad():
40
+ outputs = model(image).logits
41
+ predicted_probs = torch.softmax(outputs, dim=1)
42
 
43
+ # Get the class name and its associated probability
44
+ predicted_class_idx = torch.argmax(predicted_probs).item()
45
+ predicted_class = class_names[predicted_class_idx]
46
+ predicted_probability = predicted_probs[0][predicted_class_idx].item()
47
 
48
+ # Calculate the prediction time
49
  pred_time = round(timer() - start_time, 5)
50
 
51
+ # Create a prediction label and prediction probability dictionary for the predicted class
52
+ pred_labels_and_probs = {predicted_class: predicted_probability}
53
+
54
  return pred_labels_and_probs, pred_time
55
 
56