runaksh commited on
Commit
a55adf1
·
verified ·
1 Parent(s): 8b854a3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +67 -54
app.py CHANGED
@@ -1,3 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
  from transformers import ViTForImageClassification, ViTFeatureExtractor
3
  from PIL import Image
@@ -5,75 +50,43 @@ import torch
5
  import numpy as np
6
 
7
  # Load the pre-trained model and preprocessor (feature extractor)
8
- model_name = "runaksh/chest_xray_pneumonia_detection"
9
- model = ViTForImageClassification.from_pretrained(model_name)
10
  feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224")
11
 
12
- def classify_image_pneumonia(image):
13
  # Convert the PIL Image to a format compatible with the feature extractor
14
- image = np.array(image)
15
  # Preprocess the image and prepare it for the model
16
- inputs = feature_extractor(images=image, return_tensors="pt")
17
  # Make prediction
18
  with torch.no_grad():
19
- outputs = model(**inputs)
20
- logits = outputs.logits
21
  # Retrieve the highest probability class label index
22
- predicted_class_idx = logits.argmax(-1).item()
23
  # Define a manual mapping of label indices to human-readable labels
24
- index_to_label = {
25
  0: "NORMAL",
26
  1: "PNEUMONIA"
27
  }
28
 
29
  # Convert the index to the model's class label
30
- label = index_to_label.get(predicted_class_idx, "Unknown Label")
31
 
32
- return label
33
 
34
- def classify_image_tuberculosis(image):
35
- # Convert the PIL Image to a format compatible with the feature extractor
36
- image = np.array(image)
37
- # Preprocess the image and prepare it for the model
38
- inputs = feature_extractor(images=image, return_tensors="pt")
39
- # Make prediction
40
- with torch.no_grad():
41
- outputs = model(**inputs)
42
- logits = outputs.logits
43
- # Retrieve the highest probability class label index
44
- predicted_class_idx = logits.argmax(-1).item()
45
- # Define a manual mapping of label indices to human-readable labels
46
- index_to_label = {
47
- 0: "PNEUMONIA = NO",
48
- 1: "PNEUMONIA = YES"
49
- }
50
 
51
- # Convert the index to the model's class label
52
- label = index_to_label.get(predicted_class_idx, "Unknown Label")
 
 
 
 
53
 
54
- return label
 
55
 
56
- # Create Gradio interface
57
- def make_block(dem):
58
- with dem:
59
- gr.Markdown("Medical - Lungs Disease Prediction")
60
- with gr.Tabs():
61
- with gr.TabItem("Pneumonia Detection"):
62
- with gr.Row():
63
- in_prompt_1 = gr.Image()
64
- out_response_1 = gr.Label()
65
- b1 = gr.Button("Enter")
66
-
67
- with gr.TabItem("Tuberculosis Detection"):
68
- with gr.Row():
69
- in_prompt_2 = gr.Image()
70
- out_response_2 = gr.Label()
71
- b2 = gr.Button("Enter")
72
- b1.Block(classify_image_pneumonia, inputs=in_prompt_1, outputs=out_response_1)
73
- b2.Block(classify_image_tuberculosis, inputs=in_prompt_2, outputs=out_response_2)
74
-
75
- if __name__ == '__main__':
76
-
77
- demo = gr.Blocks()
78
- make_block(demo)
79
- demo.launch()
 
1
+ Hugging Face's logo
2
+ Hugging Face
3
+ Search models, datasets, users...
4
+ Models
5
+ Datasets
6
+ Spaces
7
+ Posts
8
+ Docs
9
+ Solutions
10
+ Pricing
11
+
12
+
13
+
14
+ Spaces:
15
+
16
+ runaksh
17
+ /
18
+ chest_xray_pneumonia_detection
19
+
20
+
21
+ like
22
+ 0
23
+
24
+ Logs
25
+ App
26
+ Files
27
+ Community
28
+ Settings
29
+ chest_xray_pneumonia_detection
30
+ /
31
+ app.py
32
+
33
+ runaksh's picture
34
+ runaksh
35
+ Update app.py
36
+ 3e586d9
37
+ VERIFIED
38
+ 10 days ago
39
+ raw
40
+ history
41
+ blame
42
+ edit
43
+ delete
44
+ No virus
45
+ 1.58 kB
46
  import gradio as gr
47
  from transformers import ViTForImageClassification, ViTFeatureExtractor
48
  from PIL import Image
 
50
  import numpy as np
51
 
52
  # Load the pre-trained model and preprocessor (feature extractor)
53
+ model_name_pneumonia = "runaksh/chest_xray_pneumonia_detection"
54
+ model_pneumonia = ViTForImageClassification.from_pretrained(model_name_pneumonia)
55
  feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224")
56
 
57
+ def classify_image(image):
58
  # Convert the PIL Image to a format compatible with the feature extractor
59
+ image_pneumonia = np.array(image)
60
  # Preprocess the image and prepare it for the model
61
+ inputs_pneumonia = feature_extractor(images=image, return_tensors="pt")
62
  # Make prediction
63
  with torch.no_grad():
64
+ outputs_pneumonia = model_pneumonia(**inputs_pneumonia)
65
+ logits_pneumonia = outputs.logits
66
  # Retrieve the highest probability class label index
67
+ predicted_class_idx_pneumonia = logits_pneumonia.argmax(-1).item()
68
  # Define a manual mapping of label indices to human-readable labels
69
+ index_to_label_pneumonia = {
70
  0: "NORMAL",
71
  1: "PNEUMONIA"
72
  }
73
 
74
  # Convert the index to the model's class label
75
+ label_pneumonia = index_to_label_pneumonia.get(predicted_class_idx_pneumonia, "Unknown Label")
76
 
77
+ return label_pneumonia
78
 
79
+ # Create title, description and article strings
80
+ title = "Classification Demo"
81
+ description = "XRay classification"
 
 
 
 
 
 
 
 
 
 
 
 
 
82
 
83
+ # Create Gradio interface
84
+ iface = gr.Interface(fn=classify_image,
85
+ inputs=gr.Image(), # Accepts image of any size
86
+ outputs=gr.label_pneumonia(),
87
+ title=title,
88
+ description=description)
89
 
90
+ # Launch the app
91
+ iface.launch()
92