runaksh committed on
Commit
3e586d9
·
verified ·
1 Parent(s): 02040fb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -20
app.py CHANGED
@@ -1,29 +1,46 @@
1
- from transformers import ViTFeatureExtractor, ViTModel
 
2
  from PIL import Image
3
- import requests
 
4
 
5
- feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k')
6
- loaded_model = ViTModel.from_pretrained("runaksh/chest_xray_pneumonia_detection")
7
- #inputs = feature_extractor(images=image, return_tensors="pt")
 
8
 
9
- def predict(img):
10
- #inputs = feature_extractor(images=image, return_tensors="pt")
11
- pipe = pipeline('image-classification', model=model_name, device=0)
12
- pred = pipe(image)
13
- return pred
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
  # Create title, description and article strings
16
  title = "Classification Demo"
17
  description = "XRay classification"
18
 
19
- # Create the Gradio demo
20
- demo = gr.Interface(fn=predict, # mapping function from input to output
21
- inputs=gr.Image(type='filepath'), # what are the inputs?
22
- outputs=[gr.Label(label="Predictions"), # what are the outputs?
23
- gr.Number(label="Prediction time (s)")], # our fn has two outputs, therefore we have two outputs
24
- examples=example_list,
25
- title=title,
26
- description=description,)
27
 
28
- # Launch the demo!
29
- demo.launch()
 
1
+ import gradio as gr
2
+ from transformers import ViTForImageClassification, ViTFeatureExtractor
3
  from PIL import Image
4
+ import torch
5
+ import numpy as np
6
 
7
+ # Load the pre-trained model and preprocessor (feature extractor)
8
+ model_name = "runaksh/chest_xray_pneumonia_detection"
9
+ model = ViTForImageClassification.from_pretrained(model_name)
10
+ feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224")
11
 
12
+ def classify_image(image):
13
+ # Convert the PIL Image to a format compatible with the feature extractor
14
+ image = np.array(image)
15
+ # Preprocess the image and prepare it for the model
16
+ inputs = feature_extractor(images=image, return_tensors="pt")
17
+ # Make prediction
18
+ with torch.no_grad():
19
+ outputs = model(**inputs)
20
+ logits = outputs.logits
21
+ # Retrieve the highest probability class label index
22
+ predicted_class_idx = logits.argmax(-1).item()
23
+ # Define a manual mapping of label indices to human-readable labels
24
+ index_to_label = {
25
+ 0: "NORMAL",
26
+ 1: "PNEUMONIA"
27
+ }
28
+
29
+ # Convert the index to the model's class label
30
+ label = index_to_label.get(predicted_class_idx, "Unknown Label")
31
+
32
+ return label
33
 
34
# Text shown in the web UI header.
title = "Classification Demo"
description = "XRay classification"

# Wire the classifier into a simple image-in / label-out web form.
iface = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(),  # accepts an image of any size
    outputs=gr.Label(),
    title=title,
    description=description,
)

# Start the web server.
iface.launch()