yolac commited on
Commit
fb9e929
·
verified ·
1 Parent(s): 919808e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -45
app.py CHANGED
@@ -1,78 +1,50 @@
1
  import gradio as gr
2
- import torch
3
- import torch.nn as nn
4
- from torchvision import transforms
5
  from PIL import Image
 
6
  import logging
7
 
8
  # Set up logging for debugging
9
  logging.basicConfig(level=logging.DEBUG)
10
 
11
- # Define the model architecture that matches the saved .pth file
12
- class BacterialMorphologyClassifier(nn.Module):
13
- def __init__(self):
14
- super(BacterialMorphologyClassifier, self).__init__()
15
- self.feature_extractor = nn.Sequential(
16
- nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
17
- nn.ReLU(),
18
- nn.MaxPool2d(kernel_size=2, stride=2),
19
- nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
20
- nn.ReLU(),
21
- nn.MaxPool2d(kernel_size=2, stride=2),
22
- )
23
- self.fc = nn.Sequential(
24
- nn.Flatten(),
25
- nn.Linear(64 * 56 * 56, 128),
26
- nn.ReLU(),
27
- nn.Dropout(0.5),
28
- nn.Linear(128, 3),
29
- nn.Softmax(dim=1),
30
- )
31
-
32
- def forward(self, x):
33
- x = self.feature_extractor(x)
34
- x = self.fc(x)
35
- return x
36
-
37
- # Load the model and weights
38
  MODEL_PATH = "https://huggingface.co/yolac/BacterialMorphologyClassification/resolve/main/model.keras"
39
  logging.debug("Starting model loading...")
40
  try:
41
- model = BacterialMorphologyClassifier()
42
- state_dict = torch.hub.load_state_dict_from_url(MODEL_PATH, map_location=torch.device('cpu'))
43
- model.load_state_dict(state_dict, strict=False)
44
- model.eval()
45
  logging.debug("Model loaded successfully.")
46
  except Exception as e:
47
  logging.error(f"Error loading the model: {str(e)}")
48
 
49
  # Define image preprocessing transformations
50
- transform = transforms.Compose([
51
- transforms.Resize((224, 224)),
52
- transforms.ToTensor(),
53
- transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
54
- ])
 
 
 
55
 
56
  # Define the prediction function
57
  def predict(image):
58
  try:
59
- logging.debug("Starting prediction...")
60
  # Preprocess the image
61
- image_tensor = transform(image).unsqueeze(0)
62
  logging.debug("Image preprocessing completed.")
63
 
64
  # Make prediction
65
- output = model(image_tensor)
66
- prediction = output.argmax().item()
67
 
68
  # Class mapping
69
  class_labels = {0: 'cocci', 1: 'bacilli', 2: 'spirilla'}
70
 
71
  # Log prediction details
72
- logging.debug(f"Predicted class: {class_labels[prediction]}, Confidence: {output.max().item()}")
73
 
74
  # Return prediction result
75
- return class_labels[prediction], float(output.max().item())
76
  except Exception as e:
77
  logging.error(f"Error during prediction: {str(e)}")
78
  return "Error", 0.0
 
1
  import gradio as gr
2
+ import tensorflow as tf
 
 
3
  from PIL import Image
4
+ import numpy as np
5
  import logging
6
 
7
  # Set up logging for debugging
8
  logging.basicConfig(level=logging.DEBUG)
9
 
10
+ # Load the .keras model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  MODEL_PATH = "https://huggingface.co/yolac/BacterialMorphologyClassification/resolve/main/model.keras"
12
  logging.debug("Starting model loading...")
13
  try:
14
+ model = tf.keras.models.load_model(MODEL_PATH)
 
 
 
15
  logging.debug("Model loaded successfully.")
16
  except Exception as e:
17
  logging.error(f"Error loading the model: {str(e)}")
18
 
19
  # Define image preprocessing transformations
20
+ def preprocess_image(image):
21
+ logging.debug("Preprocessing image...")
22
+ image = image.resize((224, 224)) # Resize to match model input size
23
+ image_array = np.array(image) / 255.0 # Normalize pixel values
24
+ if len(image_array.shape) == 2: # If grayscale, convert to RGB
25
+ image_array = np.stack([image_array] * 3, axis=-1)
26
+ image_array = np.expand_dims(image_array, axis=0) # Add batch dimension
27
+ return image_array
28
 
29
  # Define the prediction function
30
  def predict(image):
31
  try:
 
32
  # Preprocess the image
33
+ image_array = preprocess_image(image)
34
  logging.debug("Image preprocessing completed.")
35
 
36
  # Make prediction
37
+ predictions = model.predict(image_array)
38
+ prediction = np.argmax(predictions, axis=1)[0]
39
 
40
  # Class mapping
41
  class_labels = {0: 'cocci', 1: 'bacilli', 2: 'spirilla'}
42
 
43
  # Log prediction details
44
+ logging.debug(f"Predicted class: {class_labels[prediction]}, Confidence: {predictions[0][prediction]}")
45
 
46
  # Return prediction result
47
+ return class_labels[prediction], float(predictions[0][prediction])
48
  except Exception as e:
49
  logging.error(f"Error during prediction: {str(e)}")
50
  return "Error", 0.0