yolac commited on
Commit
1422569
·
verified ·
1 Parent(s): b11be35

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +65 -29
app.py CHANGED
@@ -1,42 +1,78 @@
 
 
1
  import gradio as gr
2
- from tensorflow.keras.models import load_model
3
- from huggingface_hub import hf_hub_download
4
- import numpy as np
5
  from PIL import Image
 
 
6
 
7
- # Define constants
8
- MODEL_REPO = "yolac/BacterialMorphologyClassification"
9
- MODEL_FILENAME = "model.keras"
10
- MODEL_PATH = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILENAME)
 
 
 
 
 
 
 
 
 
 
 
11
 
12
  # Load the model
13
- print("Loading model...")
14
- model = load_model(MODEL_PATH)
 
 
15
 
16
- # Preprocessing function
17
- def preprocess_image(image):
18
- image = image.resize((224, 224)) # Adjust size as per your model input
19
- image_array = np.array(image) / 255.0 # Normalize to [0, 1]
20
- image_array = np.expand_dims(image_array, axis=0) # Add batch dimension
21
- return image_array
22
 
23
- # Prediction function
24
  def predict(image):
25
- image_array = preprocess_image(image)
26
- predictions = model.predict(image_array)
27
- class_names = ["Cocci", "Bacilli", "Spirilla"]
28
- predicted_class = class_names[np.argmax(predictions)]
29
- return f"Predicted Class: {predicted_class}"
30
-
31
- # Gradio Interface
32
- interface = gr.Interface(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  fn=predict,
34
- inputs=gr.Image(type="pil"),
35
  outputs="text",
36
- title="Bacterial Morphology Classification",
37
- description="Upload an image of bacteria to classify as Cocci, Bacilli, or Spirilla."
 
38
  )
39
 
40
  # Launch the app
41
- if __name__ == "__main__":
42
- interface.launch()
 
1
+ import torch
2
+ import torch.nn as nn
3
  import gradio as gr
4
+ from torchvision import transforms
 
 
5
  from PIL import Image
6
+ import requests
7
+ from io import BytesIO
8
 
9
+ # Define the PyTorch model architecture
10
+ class MyPyTorchModel(nn.Module):
11
+ def __init__(self):
12
+ super(MyPyTorchModel, self).__init__()
13
+ self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)
14
+ self.relu = nn.ReLU()
15
+ self.flatten = nn.Flatten()
16
+ self.fc1 = nn.Linear(32 * 224 * 224, 3) # Adjust output size for 3 classes
17
+
18
+ def forward(self, x):
19
+ x = self.conv1(x)
20
+ x = self.relu(x)
21
+ x = self.flatten(x)
22
+ x = self.fc1(x)
23
+ return x
24
 
25
  # Load the model
26
+ model_path = 'https://huggingface.co/yolac/BacterialMorphologyClassification/resolve/main/model.pth'
27
+ model = MyPyTorchModel()
28
+ model.load_state_dict(torch.load(model_path))
29
+ model.eval()
30
 
31
+ # Define image transformations
32
+ transform = transforms.Compose([
33
+ transforms.Resize((224, 224)),
34
+ transforms.ToTensor(),
35
+ ])
 
36
 
37
+ # Define a function to predict the class of the image
38
  def predict(image):
39
+ image = Image.fromarray(image).convert('RGB')
40
+ image = transform(image).unsqueeze(0)
41
+
42
+ with torch.no_grad():
43
+ outputs = model(image)
44
+ _, predicted = torch.max(outputs, 1)
45
+
46
+ class_labels = ['cocci', 'bacilli', 'spirilla']
47
+ predicted_label = class_labels[predicted.item()]
48
+
49
+ return predicted_label
50
+
51
+ # URLs for 3 example images
52
+ example_image_urls = [
53
+ "https://huggingface.co/datasets/yolac/BacterialMorphologyClassification/resolve/main/img%20290.jpg", # Replace with the actual URL
54
+ "https://huggingface.co/datasets/yolac/BacterialMorphologyClassification/resolve/main/img%20565.jpg", # Replace with the actual URL
55
+ "https://huggingface.co/datasets/yolac/BacterialMorphologyClassification/resolve/main/img%208.jpg", # Replace with the actual URL
56
+ ]
57
+
58
+ # Function to download and open example images
59
+ def load_example_image(url):
60
+ response = requests.get(url)
61
+ image = Image.open(BytesIO(response.content))
62
+ return image
63
+
64
+ # Load the example images
65
+ example_images = [load_example_image(url) for url in example_image_urls]
66
+
67
+ # Create a Gradio interface
68
+ iface = gr.Interface(
69
  fn=predict,
70
+ inputs=gr.inputs.Image(shape=(224, 224), label="Upload an image or use examples"),
71
  outputs="text",
72
+ title="Bacterial Morphology Classifier",
73
+ description="Classify images of bacteria into cocci, bacilli, or spirilla.",
74
+ examples=example_images # Add example images to the Gradio app
75
  )
76
 
77
  # Launch the app
78
+ iface.launch()