runaksh commited on
Commit
5a131cf
·
verified ·
1 Parent(s): 33a8eb6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -11
app.py CHANGED
@@ -5,31 +5,31 @@ import torch
5
  import numpy as np
6
 
7
  # Load the pre-trained model and preprocessor (feature extractor)
8
- model_name_pneumonia = "runaksh/chest_xray_pneumonia_detection"
9
- model_pneumonia = ViTForImageClassification.from_pretrained(model_name_pneumonia)
10
  feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224")
11
 
12
  def classify_image(image):
13
  # Convert the PIL Image to a format compatible with the feature extractor
14
  image = np.array(image)
15
  # Preprocess the image and prepare it for the model
16
- inputs_pneumonia = feature_extractor(images=image, return_tensors="pt")
17
  # Make prediction
18
  with torch.no_grad():
19
- outputs_pneumonia = model_pneumonia(**inputs_pneumonia)
20
- logits_pneumonia = outputs_pneumonia.logits_pneumonia
21
  # Retrieve the highest probability class label index
22
- predicted_class_idx_pneumonia = logits_pneumonia.argmax(-1).item()
23
  # Define a manual mapping of label indices to human-readable labels
24
- index_to_label_pneumonia = {
25
  0: "NORMAL",
26
  1: "PNEUMONIA"
27
  }
28
 
29
  # Convert the index to the model's class label
30
- label_pneumonia = index_to_label_pneumonia.get(predicted_class_idx_pneumonia, "Unknown Label")
31
 
32
- return label_pneumonia
33
 
34
  # Create title, description and article strings
35
  title = "Classification Demo"
@@ -43,5 +43,4 @@ iface = gr.Interface(fn=classify_image,
43
  description=description)
44
 
45
  # Launch the app
46
- iface.launch()
47
-
 
5
  import numpy as np
6
 
7
  # Load the pre-trained model and preprocessor (feature extractor)
8
+ model_name = "runaksh/chest_xray_pneumonia_detection"
9
+ model = ViTForImageClassification.from_pretrained(model_name)
10
  feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224")
11
 
12
  def classify_image(image):
13
  # Convert the PIL Image to a format compatible with the feature extractor
14
  image = np.array(image)
15
  # Preprocess the image and prepare it for the model
16
+ inputs = feature_extractor(images=image, return_tensors="pt")
17
  # Make prediction
18
  with torch.no_grad():
19
+ outputs = model(**inputs)
20
+ logits = outputs.logits
21
  # Retrieve the highest probability class label index
22
+ predicted_class_idx = logits.argmax(-1).item()
23
  # Define a manual mapping of label indices to human-readable labels
24
+ index_to_label = {
25
  0: "NORMAL",
26
  1: "PNEUMONIA"
27
  }
28
 
29
  # Convert the index to the model's class label
30
+ label = index_to_label.get(predicted_class_idx, "Unknown Label")
31
 
32
+ return label
33
 
34
  # Create title, description and article strings
35
  title = "Classification Demo"
 
43
  description=description)
44
 
45
  # Launch the app
46
+ iface.launch()