yolac commited on
Commit
007a13c
·
verified ·
1 Parent(s): 853350a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -11
app.py CHANGED
@@ -3,7 +3,6 @@ import torch
3
  import torch.nn as nn
4
  from torchvision import transforms
5
  from PIL import Image
6
- import io
7
  import logging
8
 
9
  # Set up logging for debugging
@@ -38,11 +37,14 @@ class BacterialMorphologyClassifier(nn.Module):
38
  # Load the model and weights
39
  MODEL_PATH = "https://huggingface.co/yolac/BacterialMorphologyClassification/resolve/main/model.pth"
40
  logging.debug("Starting model loading...")
41
- model = BacterialMorphologyClassifier()
42
- state_dict = torch.hub.load_state_dict_from_url(MODEL_PATH, map_location=torch.device('cpu'))
43
- model.load_state_dict(state_dict, strict=False)
44
- model.eval()
45
- logging.debug("Model loaded successfully.")
 
 
 
46
 
47
  # Define image preprocessing transformations
48
  transform = transforms.Compose([
@@ -54,8 +56,10 @@ transform = transforms.Compose([
54
  # Define the prediction function
55
  def predict(image):
56
  try:
 
57
  # Preprocess the image
58
  image_tensor = transform(image).unsqueeze(0)
 
59
 
60
  # Make prediction
61
  output = model(image_tensor)
@@ -74,8 +78,13 @@ def predict(image):
74
  return "Error", 0.0
75
 
76
  # Create a Gradio interface
77
- inputs = gr.Image(type="pil", label="Upload an image")
78
- outputs = gr.Label(num_top_classes=3, label="Predicted Class")
79
-
80
- # Launch the Gradio app
81
- gr.Interface(fn=predict, inputs=inputs, outputs=outputs, live=True).launch(server_name="0.0.0.0", server_port=7861, debug=True)
 
 
 
 
 
 
3
  import torch.nn as nn
4
  from torchvision import transforms
5
  from PIL import Image
 
6
  import logging
7
 
8
  # Set up logging for debugging
 
37
  # Load the model and weights
38
  MODEL_PATH = "https://huggingface.co/yolac/BacterialMorphologyClassification/resolve/main/model.pth"
39
  logging.debug("Starting model loading...")
40
+ try:
41
+ model = BacterialMorphologyClassifier()
42
+ state_dict = torch.hub.load_state_dict_from_url(MODEL_PATH, map_location=torch.device('cpu'))
43
+ model.load_state_dict(state_dict, strict=False)
44
+ model.eval()
45
+ logging.debug("Model loaded successfully.")
46
+ except Exception as e:
47
+ logging.error(f"Error loading the model: {str(e)}")
48
 
49
  # Define image preprocessing transformations
50
  transform = transforms.Compose([
 
56
  # Define the prediction function
57
  def predict(image):
58
  try:
59
+ logging.debug("Starting prediction...")
60
  # Preprocess the image
61
  image_tensor = transform(image).unsqueeze(0)
62
+ logging.debug("Image preprocessing completed.")
63
 
64
  # Make prediction
65
  output = model(image_tensor)
 
78
  return "Error", 0.0
79
 
80
  # Create a Gradio interface
81
+ gr.Interface(
82
+ fn=predict,
83
+ inputs=gr.Image(type="pil", label="Upload an image"),
84
+ outputs=["text", "number"],
85
+ examples=[
86
+ ["https://huggingface.co/datasets/yolac/BacterialMorphologyClassification/resolve/main/img%20290.jpg"],
87
+ ["https://huggingface.co/datasets/yolac/BacterialMorphologyClassification/resolve/main/img%20565.jpg"],
88
+ ["https://huggingface.co/datasets/yolac/BacterialMorphologyClassification/resolve/main/img%208.jpg"]
89
+ ]
90
+ ).launch(server_name="0.0.0.0", server_port=7862, debug=True)