runaksh commited on
Commit
f9b41fc
·
verified ·
1 Parent(s): 88967f1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -3
app.py CHANGED
@@ -9,7 +9,29 @@ model_name = "runaksh/chest_xray_pneumonia_detection"
9
  model = ViTForImageClassification.from_pretrained(model_name)
10
  feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224")
11
 
12
- def classify_image(image):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  # Convert the PIL Image to a format compatible with the feature extractor
14
  image = np.array(image)
15
  # Preprocess the image and prepare it for the model
@@ -47,8 +69,8 @@ def make_block(dem):
47
  in_prompt_2 = gr.Image()
48
  out_response_2 = gr.Label()
49
  b2 = gr.Button("Enter")
50
- b1.click(predict_sentiment, inputs=in_prompt_1, outputs=out_response_1)
51
- b2.click(predict, inputs=in_prompt_2, outputs=out_response_2)
52
 
53
  if __name__ == '__main__':
54
 
 
9
  model = ViTForImageClassification.from_pretrained(model_name)
10
  feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224")
11
 
12
+ def classify_image_pneumonia(image):
13
+ # Convert the PIL Image to a format compatible with the feature extractor
14
+ image = np.array(image)
15
+ # Preprocess the image and prepare it for the model
16
+ inputs = feature_extractor(images=image, return_tensors="pt")
17
+ # Make prediction
18
+ with torch.no_grad():
19
+ outputs = model(**inputs)
20
+ logits = outputs.logits
21
+ # Retrieve the highest probability class label index
22
+ predicted_class_idx = logits.argmax(-1).item()
23
+ # Define a manual mapping of label indices to human-readable labels
24
+ index_to_label = {
25
+ 0: "NORMAL",
26
+ 1: "PNEUMONIA"
27
+ }
28
+
29
+ # Convert the index to the model's class label
30
+ label = index_to_label.get(predicted_class_idx, "Unknown Label")
31
+
32
+ return label
33
+
34
+ def classify_image_tuberculosis(image):
35
  # Convert the PIL Image to a format compatible with the feature extractor
36
  image = np.array(image)
37
  # Preprocess the image and prepare it for the model
 
69
  in_prompt_2 = gr.Image()
70
  out_response_2 = gr.Label()
71
  b2 = gr.Button("Enter")
72
+ b1.click(classify_image_pneumonia, inputs=in_prompt_1, outputs=out_response_1)
73
+ b2.click(classify_image_tuberculosis, inputs=in_prompt_2, outputs=out_response_2)
74
 
75
  if __name__ == '__main__':
76