runaksh commited on
Commit
312c423
·
verified ·
1 Parent(s): be4bd47

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -7
app.py CHANGED
@@ -5,8 +5,10 @@ import torch
5
  import numpy as np
6
 
7
  # Load the pre-trained model and preprocessor (feature extractor)
8
- model_name = "runaksh/chest_xray_pneumonia_detection"
9
- model = ViTForImageClassification.from_pretrained(model_name)
 
 
10
  feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224")
11
 
12
  def classify_image(image):
@@ -16,10 +18,10 @@ def classify_image(image):
16
  inputs = feature_extractor(images=image, return_tensors="pt")
17
  # Make prediction
18
  with torch.no_grad():
19
- outputs = model(**inputs)
20
- logits1 = outputs.logits
21
  # Retrieve the highest probability class label index
22
- predicted_class_idx = logits1.argmax(-1).item()
23
  # Define a manual mapping of label indices to human-readable labels
24
  index_to_label = {
25
  0: "NORMAL",
@@ -27,9 +29,9 @@ def classify_image(image):
27
  }
28
 
29
  # Convert the index to the model's class label
30
- label = index_to_label.get(predicted_class_idx, "Unknown Label")
31
 
32
- return label
33
 
34
  # Create title, description and article strings
35
  title = "Classification Demo"
 
5
  import numpy as np
6
 
7
  # Load the pre-trained model and preprocessor (feature extractor)
8
+ model_name_pneumonia = "runaksh/chest_xray_pneumonia_detection"
9
+ model_name_tuberculosis = "runaksh/chest_xray_tuberculosis_detection"
10
+ model_pneumonia = ViTForImageClassification.from_pretrained(model_name_pneumonia)
11
+ model_tuberculosis = ViTForImageClassification.from_pretrained(model_name_tuberculosis)
12
  feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224")
13
 
14
  def classify_image(image):
 
18
  inputs = feature_extractor(images=image, return_tensors="pt")
19
  # Make prediction
20
  with torch.no_grad():
21
+ outputs_pneumonia = model_pneumonia(**inputs)
22
+ logits_pneumonia = outputs_pneumonia.logits
23
  # Retrieve the highest probability class label index
24
+ predicted_class_idx_pneumonia = logits_pneumonia.argmax(-1).item()
25
  # Define a manual mapping of label indices to human-readable labels
26
  index_to_label = {
27
  0: "NORMAL",
 
29
  }
30
 
31
  # Convert the index to the model's class label
32
+ label_pneumonia = index_to_label.get(predicted_class_idx_pneumonia, "Unknown Label")
33
 
34
+ return label_pneumonia
35
 
36
  # Create title, description and article strings
37
  title = "Classification Demo"