yolac commited on
Commit
e438cdc
·
verified ·
1 Parent(s): ca8dd76

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +72 -87
app.py CHANGED
@@ -1,101 +1,86 @@
 
1
  import torch
2
- import torch.nn as nn
3
  from torchvision import transforms
4
  from PIL import Image
5
- import requests
6
- import gradio as gr
7
- import os
8
 
9
- # Define the model architecture
10
- class BacterialMorphologyClassifier(nn.Module):
11
- def __init__(self):
12
- super(BacterialMorphologyClassifier, self).__init__()
13
- self.feature_extractor = nn.Sequential(
14
- nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
15
- nn.ReLU(),
16
- nn.MaxPool2d(kernel_size=2, stride=2),
17
- nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
18
- nn.ReLU(),
19
- nn.MaxPool2d(kernel_size=2, stride=2),
20
- )
21
- self.fc = nn.Sequential(
22
- nn.Flatten(),
23
- nn.Linear(64 * 56 * 56, 128),
24
- nn.ReLU(),
25
- nn.Dropout(0.5),
26
- nn.Linear(128, 3),
27
- nn.Softmax(dim=1),
28
- )
29
-
30
- def forward(self, x):
31
- x = self.feature_extractor(x)
32
- x = self.fc(x)
33
- return x
34
 
35
- # Load the model
36
- MODEL_PATH = "model.pth"
37
- model = BacterialMorphologyClassifier()
 
 
 
 
 
 
38
 
39
- try:
40
- # Download the model if it doesn't exist
41
- if not os.path.exists(MODEL_PATH):
42
- print("Downloading the model...")
43
- url = "https://huggingface.co/yolac/BacterialMorphologyClassification/resolve/main/model.pth"
44
- response = requests.get(url)
45
- with open(MODEL_PATH, "wb") as f:
46
- f.write(response.content)
47
- # Load the model weights
48
- model.load_state_dict(torch.load(MODEL_PATH, map_location=torch.device('cpu')))
49
- model.eval()
50
- print("Model loaded successfully.")
51
- except Exception as e:
52
- print(f"Error loading the model: {e}")
53
 
54
- # Define image preprocessing to match training preprocessing
55
- transform = transforms.Compose([
56
- transforms.Resize((224, 224)), # Resize to match model input size
57
- transforms.ToTensor(), # Convert to a tensor
58
- transforms.Normalize(mean=[0, 0, 0], std=[1/255, 1/255, 1/255]), # Scale pixel values to [0, 1]
59
- ])
60
 
61
- # Prediction function
62
- def predict(image):
63
- try:
64
- # Convert the image to a tensor
65
- image_tensor = transform(image).unsqueeze(0)
66
-
67
- # Perform prediction
68
- with torch.no_grad(): # Ensure no gradients are calculated
69
- output = model(image_tensor)
70
-
71
- # Class mapping
72
- class_labels = {0: 'cocci', 1: 'bacilli', 2: 'spirilla'}
73
-
74
- # Return the predicted class and confidence
75
- predicted_class = class_labels[output.argmax().item()]
76
- confidence = output.max().item() # Softmax value as confidence
77
- return f"Predicted Class: {predicted_class}\nConfidence: {confidence:.2f}"
78
- except Exception as e:
79
- return f"Error: {str(e)}"
80
 
81
- # Define example images
82
- examples = [
83
- ["https://huggingface.co/datasets/yolac/BacterialMorphologyClassification/resolve/main/img%20290.jpg"],
84
- ["https://huggingface.co/datasets/yolac/BacterialMorphologyClassification/resolve/main/img%20565.jpg"],
85
- ["https://huggingface.co/datasets/yolac/BacterialMorphologyClassification/resolve/main/img%208.jpg"],
 
 
 
 
 
 
86
  ]
87
 
88
- # Set up Gradio interface
89
- interface = gr.Interface(
90
- fn=predict,
91
- inputs=gr.Image(type="pil"),
92
- outputs=gr.Text(label="Prediction"),
93
- title="Bacterial Morphology Classification",
94
- description="Upload an image of bacteria to classify it as cocci, bacilli, or spirilla.",
95
- examples=examples,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
  )
97
 
98
- # Launch the app
99
- if __name__ == "__main__":
100
- interface.launch()
101
 
 
1
+ import streamlit as st
2
  import torch
 
3
  from torchvision import transforms
4
  from PIL import Image
5
+ import json
 
 
6
 
7
+ # Load Model
8
+ @st.cache_resource
9
+ def load_model():
10
+ model_path = "https://huggingface.co/yolac/BacterialMorphologyClassification/resolve/main/model.pth"
11
+ model = torch.load(model_path, map_location=torch.device('cpu'))
12
+ model.eval() # Set model to evaluation mode
13
+ return model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
+ # Prediction Function
16
+ def predict_image(model, image):
17
+ # Transform the image
18
+ transform = transforms.Compose([
19
+ transforms.Resize((224, 224)),
20
+ transforms.ToTensor(),
21
+ transforms.Lambda(lambda x: x / 255.0) # Rescale pixel values to [0, 1]
22
+ ])
23
+ image_tensor = transform(image).unsqueeze(0) # Add batch dimension
24
 
25
+ with torch.no_grad():
26
+ outputs = model(image_tensor)
27
+ probabilities = torch.nn.functional.softmax(outputs[0], dim=0)
28
+ predicted_class = probabilities.argmax().item()
29
+ return predicted_class, probabilities.numpy()
 
 
 
 
 
 
 
 
 
30
 
31
+ # Class Labels
32
+ def get_class_labels():
33
+ # Define your class labels here
34
+ return {0: "Cocci", 1: "Bacilli", 2: "Spirilla"}
 
 
35
 
36
+ # Streamlit App
37
+ st.set_page_config(page_title="Bacterial Morphology Classification", page_icon="🦠")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
+ st.title("🦠 Bacterial Morphology Classification")
40
+ st.markdown(
41
+ "This app classifies bacterial morphology into **Cocci**, **Bacilli**, or **Spirilla** using a fine-tuned PyTorch model."
42
+ )
43
+
44
+ # Example Images
45
+ st.subheader("Example Images")
46
+ example_images = [
47
+ "https://huggingface.co/datasets/yolac/BacterialMorphologyClassification/resolve/main/img%20290.jpg", # Replace with actual paths to example images
48
+ "https://huggingface.co/datasets/yolac/BacterialMorphologyClassification/resolve/main/img%20565.jpg",
49
+ "https://huggingface.co/datasets/yolac/BacterialMorphologyClassification/resolve/main/img%208.jpg"
50
  ]
51
 
52
+ for img_path in example_images:
53
+ img = Image.open(img_path).convert("RGB")
54
+ st.image(img, caption=f"Example Image: {img_path}", use_column_width=True)
55
+
56
+ # File Upload
57
+ uploaded_file = st.file_uploader("Upload a bacterial image for classification:", type=["jpg", "jpeg", "png"])
58
+
59
+ if uploaded_file:
60
+ # Display the uploaded image
61
+ image = Image.open(uploaded_file).convert("RGB")
62
+ st.image(image, caption="Uploaded Image", use_column_width=True)
63
+
64
+ # Load Model and Predict
65
+ with st.spinner("Classifying..."):
66
+ model = load_model()
67
+ class_labels = get_class_labels()
68
+ predicted_class, probabilities = predict_image(model, image)
69
+ predicted_label = class_labels[predicted_class]
70
+
71
+ # Display Results
72
+ st.success(f"Prediction: **{predicted_label}**")
73
+ st.write("Class Probabilities:")
74
+ st.json({class_labels[i]: f"{prob:.2%}" for i, prob in enumerate(probabilities)})
75
+
76
+ # Sidebar Info
77
+ st.sidebar.title("Classifies bacterial images into cocci, bacilli, or spirilla")
78
+ st.sidebar.markdown(
79
+ """
80
+ - **Author**: Yola Charara
81
+ - **Dataset**: [Bacterial Morphology Classification](https://huggingface.co/datasets/yolac/BacterialMorphologyClassification)
82
+ - **Model**: [MobileNetV2-based Classifier](https://huggingface.co/yolac/BacterialMorphologyClassification)
83
+ """
84
  )
85
 
 
 
 
86