yolac commited on
Commit
5409e53
·
verified ·
1 Parent(s): 884fb83

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -66
app.py CHANGED
@@ -1,20 +1,22 @@
1
- import streamlit as st
2
  import torch
3
  from torchvision import transforms
4
  from PIL import Image
5
- import json
6
 
7
- # Load Model
8
- @st.cache_resource
9
  def load_model():
10
  model_path = "https://huggingface.co/yolac/BacterialMorphologyClassification/resolve/main/model.pth"
11
  model = torch.load(model_path, map_location=torch.device('cpu'))
12
  model.eval() # Set model to evaluation mode
13
  return model
14
 
15
- # Prediction Function
16
- def predict_image(model, image):
17
- # Transform the image
 
 
 
 
18
  transform = transforms.Compose([
19
  transforms.Resize((224, 224)),
20
  transforms.ToTensor(),
@@ -26,64 +28,29 @@ def predict_image(model, image):
26
  outputs = model(image_tensor)
27
  probabilities = torch.nn.functional.softmax(outputs[0], dim=0)
28
  predicted_class = probabilities.argmax().item()
29
- return predicted_class, probabilities.numpy()
30
-
31
- # Class Labels
32
- def get_class_labels():
33
- # Define your class labels here
34
- return {0: "Cocci", 1: "Bacilli", 2: "Spirilla"}
35
-
36
- # Streamlit App
37
- st.set_page_config(page_title="Bacterial Morphology Classification", page_icon="🦠")
38
-
39
- st.title("🦠 Bacterial Morphology Classification")
40
- st.markdown(
41
- "This app classifies bacterial morphology into **Cocci**, **Bacilli**, or **Spirilla** using a fine-tuned PyTorch model."
42
- )
43
-
44
- # Display example images from local files
45
- st.subheader("Example Images")
46
- example_image_paths = [
47
- "https://huggingface.co/datasets/yolac/BacterialMorphologyClassification/resolve/main/img%20290.jpg",
48
- "https://huggingface.co/datasets/yolac/BacterialMorphologyClassification/resolve/main/img%20565.jpg",
49
- "https://huggingface.co/datasets/yolac/BacterialMorphologyClassification/resolve/main/img%208.jpg"
50
- ]
51
-
52
- for img_path in example_image_paths:
53
- try:
54
- img = Image.open(img_path).convert("RGB")
55
- st.image(img, caption=f"Example Image: {img_path}", use_column_width=True)
56
- except FileNotFoundError:
57
- st.error(f"Example image {img_path} not found. Please ensure the image exists in the app directory.")
58
-
59
- # File Upload
60
- uploaded_file = st.file_uploader("Upload a bacterial image for classification:", type=["jpg", "jpeg", "png"])
61
-
62
- if uploaded_file:
63
- # Display the uploaded image
64
- image = Image.open(uploaded_file).convert("RGB")
65
- st.image(image, caption="Uploaded Image", use_column_width=True)
66
-
67
- # Load Model and Predict
68
- with st.spinner("Classifying..."):
69
- model = load_model()
70
- class_labels = get_class_labels()
71
- predicted_class, probabilities = predict_image(model, image)
72
- predicted_label = class_labels[predicted_class]
73
-
74
- # Display Results
75
- st.success(f"Prediction: **{predicted_label}**")
76
- st.write("Class Probabilities:")
77
- st.json({class_labels[i]: f"{prob:.2%}" for i, prob in enumerate(probabilities)})
78
-
79
- # Sidebar Info
80
- st.sidebar.title("Classifies bacterial images into cocci, bacilli, or spirilla")
81
- st.sidebar.markdown(
82
- """
83
- - **Author**: Yola Charara
84
- - **Dataset**: [Bacterial Morphology Classification](https://huggingface.co/datasets/yolac/BacterialMorphologyClassification)
85
- - **Model**: [MobileNetV2-based Classifier](https://huggingface.co/yolac/BacterialMorphologyClassification)
86
- """
87
- )
88
 
89
 
 
1
+ import gradio as gr
2
  import torch
3
  from torchvision import transforms
4
  from PIL import Image
 
5
 
6
+ # Load the model
 
7
  def load_model():
8
  model_path = "https://huggingface.co/yolac/BacterialMorphologyClassification/resolve/main/model.pth"
9
  model = torch.load(model_path, map_location=torch.device('cpu'))
10
  model.eval() # Set model to evaluation mode
11
  return model
12
 
13
+ model = load_model()
14
+
15
+ # Define class labels
16
+ class_labels = {0: "Cocci", 1: "Bacilli", 2: "Spirilla"}
17
+
18
+ # Prediction function
19
+ def predict_image(image):
20
  transform = transforms.Compose([
21
  transforms.Resize((224, 224)),
22
  transforms.ToTensor(),
 
28
  outputs = model(image_tensor)
29
  probabilities = torch.nn.functional.softmax(outputs[0], dim=0)
30
  predicted_class = probabilities.argmax().item()
31
+ return class_labels[predicted_class], {class_labels[i]: float(prob) for i, prob in enumerate(probabilities)}
32
+
33
+ # Example images (these should be in the same directory as your app.py)
34
+ example_images = ["https://huggingface.co/datasets/yolac/BacterialMorphologyClassification/resolve/main/img%20290.jpg", "https://huggingface.co/datasets/yolac/BacterialMorphologyClassification/resolve/main/img%20565.jpg", "https://huggingface.co/datasets/yolac/BacterialMorphologyClassification/resolve/main/img%208.jpg"]
35
+
36
+ # Create a Gradio interface
37
+ def create_gradio_interface():
38
+ iface = gr.Interface(
39
+ fn=predict_image,
40
+ inputs=gr.inputs.Image(type="pil", label="Upload an image"),
41
+ outputs=[
42
+ gr.outputs.Label(num_top_classes=3, label="Predicted Class"),
43
+ gr.outputs.JSON(label="Class Probabilities")
44
+ ],
45
+ examples=example_images,
46
+ title="Bacterial Morphology Classification",
47
+ description="This app classifies bacterial morphology into **Cocci**, **Bacilli**, or **Spirilla** using a fine-tuned PyTorch model.",
48
+ )
49
+ return iface
50
+
51
+ # Launch the app
52
+ if __name__ == "__main__":
53
+ app = create_gradio_interface()
54
+ app.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
 
56