rishabh5752 commited on
Commit
5b13053
·
1 Parent(s): 2805dad

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -6
app.py CHANGED
@@ -21,17 +21,27 @@ transform = transforms.Compose([
21
  # Define the class labels
22
  class_labels = ['Normal', 'Pneumonia']
23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  # Create a function to make predictions
25
  def predict(image):
26
- # Convert the image to RGB
27
- image = image.convert('RGB')
28
-
29
  # Preprocess the image
30
- image = transform(image).unsqueeze(0)
31
-
32
  # Make the prediction
33
  with torch.no_grad():
34
- output = model(image)
35
  _, predicted_idx = torch.max(output, 1)
36
  predicted_label = class_labels[predicted_idx.item()]
37
 
 
21
  # Define the class labels
22
  class_labels = ['Normal', 'Pneumonia']
23
 
24
+ # Create a function to preprocess the image
25
+ def preprocess_image(image):
26
+ # Resize the image to match the model's input shape
27
+ image = image.resize((224, 224))
28
+
29
+ # Convert the image to a tensor
30
+ image_tensor = transform(image)
31
+
32
+ # Add a batch dimension
33
+ image_tensor = image_tensor.unsqueeze(0)
34
+
35
+ return image_tensor
36
+
37
  # Create a function to make predictions
38
  def predict(image):
 
 
 
39
  # Preprocess the image
40
+ preprocessed_image = preprocess_image(image)
41
+
42
  # Make the prediction
43
  with torch.no_grad():
44
+ output = model(preprocessed_image)
45
  _, predicted_idx = torch.max(output, 1)
46
  predicted_label = class_labels[predicted_idx.item()]
47